Skip to main content

ide_assists/handlers/
convert_match_to_let_else.rs

1use ide_db::defs::{Definition, NameRefClass};
2use syntax::{
3    AstNode, SyntaxNode,
4    ast::{self, HasName, Name, edit::AstNodeEdit},
5    syntax_editor::SyntaxEditor,
6};
7
8use crate::{
9    AssistId,
10    assist_context::{AssistContext, Assists},
11};
12
13// Assist: convert_match_to_let_else
14//
15// Converts let statement with match initializer to let-else statement.
16//
17// ```
18// # //- minicore: option
19// fn foo(opt: Option<()>) {
20//     let val$0 = match opt {
21//         Some(it) => it,
22//         None => return,
23//     };
24// }
25// ```
26// ->
27// ```
28// fn foo(opt: Option<()>) {
29//     let Some(val) = opt else { return };
30// }
31// ```
32pub(crate) fn convert_match_to_let_else(
33    acc: &mut Assists,
34    ctx: &AssistContext<'_, '_>,
35) -> Option<()> {
36    let let_stmt: ast::LetStmt = ctx.find_node_at_offset()?;
37    let pat = let_stmt.pat()?;
38    if ctx.offset() > pat.syntax().text_range().end() {
39        return None;
40    }
41
42    let Some(ast::Expr::MatchExpr(initializer)) = let_stmt.initializer() else { return None };
43    let initializer_expr = initializer.expr()?;
44
45    let (extracting_arm, diverging_arm) = find_arms(ctx, &initializer)?;
46    if extracting_arm.guard().is_some() {
47        cov_mark::hit!(extracting_arm_has_guard);
48        return None;
49    }
50
51    let diverging_arm_expr = match diverging_arm.expr()?.dedent(1.into()) {
52        ast::Expr::BlockExpr(block) if block.modifier().is_none() && block.label().is_none() => {
53            block.to_string()
54        }
55        other => format!("{{ {other} }}"),
56    };
57    let extracting_arm_pat = extracting_arm.pat()?;
58    let extracted_variable_positions = find_extracted_variable(ctx, &extracting_arm)?;
59
60    acc.add(
61        AssistId::refactor_rewrite("convert_match_to_let_else"),
62        "Convert match to let-else",
63        let_stmt.syntax().text_range(),
64        |builder| {
65            let extracting_arm_pat =
66                rename_variable(&extracting_arm_pat, &extracted_variable_positions, pat);
67            let (open_paren, close_paren) = if ast::OrPat::can_cast(extracting_arm_pat.kind()) {
68                // Or patterns cannot put put directly under let statements.
69                // FIXME: Do this with `SyntaxEditor` in `rename_variable()`, it's just difficult right now
70                // since it re-roots nodes.
71                ("(", ")")
72            } else {
73                ("", "")
74            };
75            builder.replace(
76                let_stmt.syntax().text_range(),
77                format!("let {open_paren}{extracting_arm_pat}{close_paren} = {initializer_expr} else {diverging_arm_expr};"),
78            )
79        },
80    )
81}
82
83// Given a match expression, find extracting and diverging arms.
84fn find_arms(
85    ctx: &AssistContext<'_, '_>,
86    match_expr: &ast::MatchExpr,
87) -> Option<(ast::MatchArm, ast::MatchArm)> {
88    let arms = match_expr.match_arm_list()?.arms().collect::<Vec<_>>();
89    if arms.len() != 2 {
90        return None;
91    }
92
93    let mut extracting = None;
94    let mut diverging = None;
95    for arm in arms {
96        if ctx.sema.type_of_expr(&arm.expr()?)?.original().is_never() {
97            diverging = Some(arm);
98        } else {
99            extracting = Some(arm);
100        }
101    }
102
103    match (extracting, diverging) {
104        (Some(extracting), Some(diverging)) => Some((extracting, diverging)),
105        _ => {
106            cov_mark::hit!(non_diverging_match);
107            None
108        }
109    }
110}
111
112// Given an extracting arm, find the extracted variable.
113fn find_extracted_variable(ctx: &AssistContext<'_, '_>, arm: &ast::MatchArm) -> Option<Vec<Name>> {
114    match arm.expr()? {
115        ast::Expr::PathExpr(path) => {
116            let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?;
117            match NameRefClass::classify(&ctx.sema, &name_ref)? {
118                NameRefClass::Definition(Definition::Local(local), _) => {
119                    let source =
120                        local.sources(ctx.db()).into_iter().map(|x| x.into_ident_pat()?.name());
121                    source.collect()
122                }
123                _ => None,
124            }
125        }
126        _ => {
127            cov_mark::hit!(extracting_arm_is_not_an_identity_expr);
128            None
129        }
130    }
131}
132
133// Rename `extracted` with `binding` in `pat`.
134fn rename_variable(pat: &ast::Pat, extracted: &[Name], binding: ast::Pat) -> SyntaxNode {
135    let (editor, syntax) = SyntaxEditor::new(pat.syntax().clone());
136    let make = editor.make();
137    let extracted = extracted
138        .iter()
139        .map(|e| e.syntax().text_range() - pat.syntax().text_range().start())
140        .map(|r| syntax.covering_element(r))
141        .collect::<Vec<_>>();
142    for extracted_syntax in extracted {
143        // If `extracted` variable is a record field, we should rename it to `binding`,
144        // otherwise we just need to replace `extracted` with `binding`.
145        if let Some(record_pat_field) =
146            extracted_syntax.ancestors().find_map(ast::RecordPatField::cast)
147        {
148            if let Some(name_ref) = record_pat_field.field_name() {
149                editor.replace(
150                    record_pat_field.syntax(),
151                    make.record_pat_field(make.name_ref(&name_ref.text()), binding.clone())
152                        .syntax(),
153                );
154            }
155        } else {
156            editor.replace(extracted_syntax, binding.syntax());
157        }
158    }
159    let new_node = editor.finish().new_root().clone();
160    if let Some(pat) = ast::Pat::cast(new_node.clone()) {
161        pat.dedent(1.into()).syntax().clone()
162    } else {
163        new_node
164    }
165}
166
// Tests for the `convert_match_to_let_else` assist. Each fixture marks the cursor with `$0`;
// `check_assist` compares the rewritten file, `check_assist_not_applicable` asserts no edit.
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    // Neither arm has type `!`, so there is no arm that can become the `else` block.
    #[test]
    fn should_not_be_applicable_for_non_diverging_match() {
        cov_mark::check!(non_diverging_match);
        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) => it,
        None => (),
    };
}
"#,
        );
    }

    // A variable bound in every alternative of an or-pattern is renamed in each alternative.
    #[test]
    fn or_pattern_multiple_binding() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let va$0lue = match opt {
        Some(Foo::A(it) | Foo::B(it)) => it,
        _ => return Err(()),
    };
}
    "#,
            r#"
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let Some(Foo::A(value) | Foo::B(value)) = opt else { return Err(()) };
}
    "#,
        );
    }

    // Multi-line patterns and block bodies are re-indented to the `let` statement's level.
    #[test]
    fn indent_level() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let mut state = 2;
    let va$0lue = match opt {
        Some(
            Foo::A(it)
            | Foo::B(it)
        ) => it,
        _ => {
            state = 3;
            return Err(())
        },
    };
}
    "#,
            r#"
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let mut state = 2;
    let Some(
        Foo::A(value)
        | Foo::B(value)
    ) = opt else {
        state = 3;
        return Err(())
    };
}
    "#,
        );
    }

    // The extracting arm must return its binding unchanged; computed values and blocks with
    // extra statements are rejected.
    #[test]
    fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() {
        cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2);
        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<i32>) {
    let val$0 = match opt {
        Some(it) => it + 1,
        None => return,
    };
}
"#,
        );

        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) => {
            let _ = 1 + 1;
            it
        },
        None => return,
    };
}
"#,
        );
    }

    // A guard on the extracting arm has no equivalent in a let-else pattern.
    #[test]
    fn should_not_be_applicable_if_extracting_arm_has_guard() {
        cov_mark::check!(extracting_arm_has_guard);
        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) if 2 > 1 => it,
        None => return,
    };
}
"#,
        );
    }

    // The example from the assist documentation.
    #[test]
    fn basic_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) => it,
        None => return,
    };
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    let Some(val) = opt else { return };
}
    "#,
        );
    }

    // `ref`/`mut` on the `let` binding are carried into the new pattern.
    #[test]
    fn keeps_modifiers() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let ref mut val$0 = match opt {
        Some(it) => it,
        None => return,
    };
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    let Some(ref mut val) = opt else { return };
}
    "#,
        );
    }

    // A binding nested several patterns deep is renamed in place.
    #[test]
    fn nested_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option, result
fn foo(opt: Option<Result<()>>) {
    let val$0 = match opt {
        Some(Ok(it)) => it,
        _ => return,
    };
}
    "#,
            r#"
fn foo(opt: Option<Result<()>>) {
    let Some(Ok(val)) = opt else { return };
}
    "#,
        );
    }

    // `break`, `continue` and calls to a `-> !` function all count as diverging.
    #[test]
    fn works_with_any_diverging_block() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    loop {
        let val$0 = match opt {
            Some(it) => it,
            None => break,
        };
    }
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    loop {
        let Some(val) = opt else { break };
    }
}
    "#,
        );

        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    loop {
        let val$0 = match opt {
            Some(it) => it,
            None => continue,
        };
    }
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    loop {
        let Some(val) = opt else { continue };
    }
}
    "#,
        );

        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn panic() -> ! {}

fn foo(opt: Option<()>) {
    loop {
        let val$0 = match opt {
            Some(it) => it,
            None => panic(),
        };
    }
}
    "#,
            r#"
fn panic() -> ! {}

fn foo(opt: Option<()>) {
    loop {
        let Some(val) = opt else { panic() };
    }
}
    "#,
        );
    }

    // A shorthand record field binding (`y`) is expanded to `y: <binding>`.
    #[test]
    fn struct_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
struct Point {
    x: i32,
    y: i32,
}

fn foo(opt: Option<Point>) {
    let val$0 = match opt {
        Some(Point { x: 0, y }) => y,
        _ => return,
    };
}
    "#,
            r#"
struct Point {
    x: i32,
    y: i32,
}

fn foo(opt: Option<Point>) {
    let Some(Point { x: 0, y: val }) = opt else { return };
}
    "#,
        );
    }

    // An `name @ subpattern` binding is renamed while keeping its sub-pattern.
    #[test]
    fn renames_whole_binding() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<i32>) -> Option<i32> {
    let val$0 = match opt {
        it @ Some(42) => it,
        _ => return None,
    };
    val
}
    "#,
            r#"
fn foo(opt: Option<i32>) -> Option<i32> {
    let val @ Some(42) = opt else { return None };
    val
}
    "#,
        );
    }

    // The `let` pattern may itself be a destructuring pattern.
    #[test]
    fn complex_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn f() {
    let (x, y)$0 = match Some((0, 1)) {
        Some(it) => it,
        None => return,
    };
}
"#,
            r#"
fn f() {
    let Some((x, y)) = Some((0, 1)) else { return };
}
"#,
        );
    }

    // A block-bodied diverging arm is reused (comments included) as the `else` block.
    #[test]
    fn diverging_block() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn f() {
    let x$0 = match Some(()) {
        Some(it) => it,
        None => {//comment
            println!("nope");
            return
        },
    };
}
"#,
            r#"
fn f() {
    let Some(x) = Some(()) else {//comment
        println!("nope");
        return
    };
}
"#,
        );
    }

    // A top-level or-pattern is parenthesized, since it cannot appear bare in a `let`.
    #[test]
    fn top_level_or_pat() {
        check_assist(
            convert_match_to_let_else,
            r#"
enum E {
    A(u32),
    B(u32),
    C,
}

fn foo() {
    let e = E::A(0);
    let _$0 = match e {
        E::A(v) | E::B(v) => v,
        _ => return,
    };
}
        "#,
            r#"
enum E {
    A(u32),
    B(u32),
    C,
}

fn foo() {
    let e = E::A(0);
    let (E::A(_) | E::B(_)) = e else { return };
}
        "#,
        );
    }
}