// ide_assists/handlers/convert_two_arm_bool_match_to_matches_macro.rs

1use hir::Semantics;
2use ide_db::RootDatabase;
3use stdx::format_to;
4use syntax::ast::{self, AstNode};
5
6use crate::{AssistContext, AssistId, Assists};
7
8// Assist: convert_two_arm_bool_match_to_matches_macro
9//
10// Convert 2-arm match that evaluates to a boolean into the equivalent matches! invocation.
11//
12// ```
13// fn main() {
14//     match scrutinee$0 {
15//         Some(val) if val.cond() => true,
16//         _ => false,
17//     }
18// }
19// ```
20// ->
21// ```
22// fn main() {
23//     matches!(scrutinee, Some(val) if val.cond())
24// }
25// ```
26pub(crate) fn convert_two_arm_bool_match_to_matches_macro(
27    acc: &mut Assists,
28    ctx: &AssistContext<'_>,
29) -> Option<()> {
30    use ArmBodyExpression::*;
31    let match_expr = ctx.find_node_at_offset::<ast::MatchExpr>()?;
32    let match_arm_list = match_expr.match_arm_list()?;
33    let mut arms = match_arm_list.arms();
34    let first_arm = arms.next()?;
35    let second_arm = arms.next()?;
36    if arms.next().is_some() {
37        cov_mark::hit!(non_two_arm_match);
38        return None;
39    }
40    let first_arm_expr = first_arm.expr()?;
41    let second_arm_expr = second_arm.expr()?;
42    let first_arm_body = is_bool_literal_expr(&ctx.sema, &first_arm_expr)?;
43    let second_arm_body = is_bool_literal_expr(&ctx.sema, &second_arm_expr)?;
44
45    if !matches!(
46        (&first_arm_body, &second_arm_body),
47        (Literal(true), Literal(false))
48            | (Literal(false), Literal(true))
49            | (Expression(_), Literal(false))
50    ) {
51        cov_mark::hit!(non_invert_bool_literal_arms);
52        return None;
53    }
54
55    let target_range = ctx.sema.original_range(match_expr.syntax()).range;
56    let expr = match_expr.expr()?;
57
58    acc.add(
59        AssistId::refactor_rewrite("convert_two_arm_bool_match_to_matches_macro"),
60        "Convert to matches!",
61        target_range,
62        |builder| {
63            let mut arm_str = String::new();
64            if let Some(pat) = &first_arm.pat() {
65                format_to!(arm_str, "{pat}");
66            }
67            if let Some(guard) = &first_arm.guard() {
68                arm_str += &format!(" {guard}");
69            }
70
71            let replace_with = match (first_arm_body, second_arm_body) {
72                (Literal(true), Literal(false)) => {
73                    format!("matches!({expr}, {arm_str})")
74                }
75                (Literal(false), Literal(true)) => {
76                    format!("!matches!({expr}, {arm_str})")
77                }
78                (Expression(body_expr), Literal(false)) => {
79                    arm_str.push_str(match &first_arm.guard() {
80                        Some(_) => " && ",
81                        _ => " if ",
82                    });
83                    format!("matches!({expr}, {arm_str}{body_expr})")
84                }
85                _ => {
86                    unreachable!()
87                }
88            };
89            builder.replace(target_range, replace_with);
90        },
91    )
92}
93
94enum ArmBodyExpression {
95    Literal(bool),
96    Expression(ast::Expr),
97}
98
99fn is_bool_literal_expr(
100    sema: &Semantics<'_, RootDatabase>,
101    expr: &ast::Expr,
102) -> Option<ArmBodyExpression> {
103    if let ast::Expr::Literal(lit) = expr
104        && let ast::LiteralKind::Bool(b) = lit.kind()
105    {
106        return Some(ArmBodyExpression::Literal(b));
107    }
108
109    if !sema.type_of_expr(expr)?.original.is_bool() {
110        return None;
111    }
112
113    Some(ArmBodyExpression::Expression(expr.clone()))
114}
115
116#[cfg(test)]
117mod tests {
118    use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
119
120    use super::convert_two_arm_bool_match_to_matches_macro;
121
122    #[test]
123    fn not_applicable_outside_of_range_left() {
124        check_assist_not_applicable(
125            convert_two_arm_bool_match_to_matches_macro,
126            r#"
127fn foo(a: Option<u32>) -> bool {
128    $0 match a {
129        Some(_val) => true,
130        _ => false
131    }
132}
133        "#,
134        );
135    }
136
137    #[test]
138    fn not_applicable_non_two_arm_match() {
139        cov_mark::check!(non_two_arm_match);
140        check_assist_not_applicable(
141            convert_two_arm_bool_match_to_matches_macro,
142            r#"
143fn foo(a: Option<u32>) -> bool {
144    match a$0 {
145        Some(3) => true,
146        Some(4) => true,
147        _ => false
148    }
149}
150        "#,
151        );
152    }
153
154    #[test]
155    fn not_applicable_both_false_arms() {
156        cov_mark::check!(non_invert_bool_literal_arms);
157        check_assist_not_applicable(
158            convert_two_arm_bool_match_to_matches_macro,
159            r#"
160fn foo(a: Option<u32>) -> bool {
161    match a$0 {
162        Some(val) => false,
163        _ => false
164    }
165}
166        "#,
167        );
168    }
169
170    #[test]
171    fn not_applicable_both_true_arms() {
172        cov_mark::check!(non_invert_bool_literal_arms);
173        check_assist_not_applicable(
174            convert_two_arm_bool_match_to_matches_macro,
175            r#"
176fn foo(a: Option<u32>) -> bool {
177    match a$0 {
178        Some(val) => true,
179        _ => true
180    }
181}
182        "#,
183        );
184    }
185
186    #[test]
187    fn convert_simple_case() {
188        check_assist(
189            convert_two_arm_bool_match_to_matches_macro,
190            r#"
191fn foo(a: Option<u32>) -> bool {
192    match a$0 {
193        Some(_val) => true,
194        _ => false
195    }
196}
197"#,
198            r#"
199fn foo(a: Option<u32>) -> bool {
200    matches!(a, Some(_val))
201}
202"#,
203        );
204    }
205
206    #[test]
207    fn convert_simple_invert_case() {
208        check_assist(
209            convert_two_arm_bool_match_to_matches_macro,
210            r#"
211fn foo(a: Option<u32>) -> bool {
212    match a$0 {
213        Some(_val) => false,
214        _ => true
215    }
216}
217"#,
218            r#"
219fn foo(a: Option<u32>) -> bool {
220    !matches!(a, Some(_val))
221}
222"#,
223        );
224    }
225
226    #[test]
227    fn convert_with_guard_case() {
228        check_assist(
229            convert_two_arm_bool_match_to_matches_macro,
230            r#"
231fn foo(a: Option<u32>) -> bool {
232    match a$0 {
233        Some(val) if val > 3 => true,
234        _ => false
235    }
236}
237"#,
238            r#"
239fn foo(a: Option<u32>) -> bool {
240    matches!(a, Some(val) if val > 3)
241}
242"#,
243        );
244    }
245
246    #[test]
247    fn convert_enum_match_cases() {
248        check_assist(
249            convert_two_arm_bool_match_to_matches_macro,
250            r#"
251enum X { A, B }
252
253fn foo(a: X) -> bool {
254    match a$0 {
255        X::A => true,
256        _ => false
257    }
258}
259"#,
260            r#"
261enum X { A, B }
262
263fn foo(a: X) -> bool {
264    matches!(a, X::A)
265}
266"#,
267        );
268    }
269
270    #[test]
271    fn convert_target_simple() {
272        check_assist_target(
273            convert_two_arm_bool_match_to_matches_macro,
274            r#"
275fn foo(a: Option<u32>) -> bool {
276    match a$0 {
277        Some(val) => true,
278        _ => false
279    }
280}
281"#,
282            r#"match a {
283        Some(val) => true,
284        _ => false
285    }"#,
286        );
287    }
288
289    #[test]
290    fn convert_target_complex() {
291        check_assist_target(
292            convert_two_arm_bool_match_to_matches_macro,
293            r#"
294enum E { X, Y }
295
296fn main() {
297    match E::X$0 {
298        E::X => true,
299        _ => false,
300    }
301}
302"#,
303            "match E::X {
304        E::X => true,
305        _ => false,
306    }",
307        );
308    }
309
310    #[test]
311    fn convert_non_literal_bool() {
312        check_assist(
313            convert_two_arm_bool_match_to_matches_macro,
314            r#"
315fn main() {
316    match 0$0 {
317        a @ 0..15 => a == 0,
318        _ => false,
319    }
320}
321"#,
322            r#"
323fn main() {
324    matches!(0, a @ 0..15 if a == 0)
325}
326"#,
327        );
328        check_assist(
329            convert_two_arm_bool_match_to_matches_macro,
330            r#"
331fn main() {
332    match 0$0 {
333        a @ 0..15 if thing() => a == 0,
334        _ => false,
335    }
336}
337"#,
338            r#"
339fn main() {
340    matches!(0, a @ 0..15 if thing() && a == 0)
341}
342"#,
343        );
344    }
345}