ide_assists/handlers/
convert_match_to_let_else.rs

1use ide_db::defs::{Definition, NameRefClass};
2use syntax::{
3    AstNode, SyntaxNode,
4    ast::{self, HasName, Name, edit::AstNodeEdit, syntax_factory::SyntaxFactory},
5    syntax_editor::SyntaxEditor,
6};
7
8use crate::{
9    AssistId,
10    assist_context::{AssistContext, Assists},
11};
12
13// Assist: convert_match_to_let_else
14//
15// Converts let statement with match initializer to let-else statement.
16//
17// ```
18// # //- minicore: option
19// fn foo(opt: Option<()>) {
20//     let val$0 = match opt {
21//         Some(it) => it,
22//         None => return,
23//     };
24// }
25// ```
26// ->
27// ```
28// fn foo(opt: Option<()>) {
29//     let Some(val) = opt else { return };
30// }
31// ```
32pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
33    let let_stmt: ast::LetStmt = ctx.find_node_at_offset()?;
34    let pat = let_stmt.pat()?;
35    if ctx.offset() > pat.syntax().text_range().end() {
36        return None;
37    }
38
39    let Some(ast::Expr::MatchExpr(initializer)) = let_stmt.initializer() else { return None };
40    let initializer_expr = initializer.expr()?;
41
42    let (extracting_arm, diverging_arm) = find_arms(ctx, &initializer)?;
43    if extracting_arm.guard().is_some() {
44        cov_mark::hit!(extracting_arm_has_guard);
45        return None;
46    }
47
48    let diverging_arm_expr = match diverging_arm.expr()?.dedent(1.into()) {
49        ast::Expr::BlockExpr(block) if block.modifier().is_none() && block.label().is_none() => {
50            block.to_string()
51        }
52        other => format!("{{ {other} }}"),
53    };
54    let extracting_arm_pat = extracting_arm.pat()?;
55    let extracted_variable_positions = find_extracted_variable(ctx, &extracting_arm)?;
56
57    acc.add(
58        AssistId::refactor_rewrite("convert_match_to_let_else"),
59        "Convert match to let-else",
60        let_stmt.syntax().text_range(),
61        |builder| {
62            let extracting_arm_pat =
63                rename_variable(&extracting_arm_pat, &extracted_variable_positions, pat);
64            builder.replace(
65                let_stmt.syntax().text_range(),
66                format!("let {extracting_arm_pat} = {initializer_expr} else {diverging_arm_expr};"),
67            )
68        },
69    )
70}
71
72// Given a match expression, find extracting and diverging arms.
73fn find_arms(
74    ctx: &AssistContext<'_>,
75    match_expr: &ast::MatchExpr,
76) -> Option<(ast::MatchArm, ast::MatchArm)> {
77    let arms = match_expr.match_arm_list()?.arms().collect::<Vec<_>>();
78    if arms.len() != 2 {
79        return None;
80    }
81
82    let mut extracting = None;
83    let mut diverging = None;
84    for arm in arms {
85        if ctx.sema.type_of_expr(&arm.expr()?)?.original().is_never() {
86            diverging = Some(arm);
87        } else {
88            extracting = Some(arm);
89        }
90    }
91
92    match (extracting, diverging) {
93        (Some(extracting), Some(diverging)) => Some((extracting, diverging)),
94        _ => {
95            cov_mark::hit!(non_diverging_match);
96            None
97        }
98    }
99}
100
101// Given an extracting arm, find the extracted variable.
102fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option<Vec<Name>> {
103    match arm.expr()? {
104        ast::Expr::PathExpr(path) => {
105            let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?;
106            match NameRefClass::classify(&ctx.sema, &name_ref)? {
107                NameRefClass::Definition(Definition::Local(local), _) => {
108                    let source =
109                        local.sources(ctx.db()).into_iter().map(|x| x.into_ident_pat()?.name());
110                    source.collect()
111                }
112                _ => None,
113            }
114        }
115        _ => {
116            cov_mark::hit!(extracting_arm_is_not_an_identity_expr);
117            None
118        }
119    }
120}
121
122// Rename `extracted` with `binding` in `pat`.
123fn rename_variable(pat: &ast::Pat, extracted: &[Name], binding: ast::Pat) -> SyntaxNode {
124    let syntax = pat.syntax().clone_subtree();
125    let mut editor = SyntaxEditor::new(syntax.clone());
126    let make = SyntaxFactory::with_mappings();
127    let extracted = extracted
128        .iter()
129        .map(|e| e.syntax().text_range() - pat.syntax().text_range().start())
130        .map(|r| syntax.covering_element(r))
131        .collect::<Vec<_>>();
132    for extracted_syntax in extracted {
133        // If `extracted` variable is a record field, we should rename it to `binding`,
134        // otherwise we just need to replace `extracted` with `binding`.
135        if let Some(record_pat_field) =
136            extracted_syntax.ancestors().find_map(ast::RecordPatField::cast)
137        {
138            if let Some(name_ref) = record_pat_field.field_name() {
139                editor.replace(
140                    record_pat_field.syntax(),
141                    make.record_pat_field(
142                        make.name_ref(&name_ref.text()),
143                        binding.clone_for_update(),
144                    )
145                    .syntax(),
146                );
147            }
148        } else {
149            editor.replace(extracted_syntax, binding.syntax().clone_for_update());
150        }
151    }
152    editor.add_mappings(make.finish_with_mappings());
153    let new_node = editor.finish().new_root().clone();
154    if let Some(pat) = ast::Pat::cast(new_node.clone()) {
155        pat.dedent(1.into()).syntax().clone()
156    } else {
157        new_node
158    }
159}
160
// Fixture-based tests for `convert_match_to_let_else`. In each fixture `$0` marks the cursor.
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    // Neither arm diverges (`None => ()`), so the assist must not apply; the
    // `non_diverging_match` coverage mark is expected to be hit.
    #[test]
    fn should_not_be_applicable_for_non_diverging_match() {
        cov_mark::check!(non_diverging_match);
        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) => it,
        None => (),
    };
}
"#,
        );
    }

    // A variable bound in both alternatives of an or-pattern is renamed in each alternative.
    #[test]
    fn or_pattern_multiple_binding() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let va$0lue = match opt {
        Some(Foo::A(it) | Foo::B(it)) => it,
        _ => return Err(()),
    };
}
    "#,
            r#"
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let Some(Foo::A(value) | Foo::B(value)) = opt else { return Err(()) };
}
    "#,
        );
    }

    // A multi-line pattern and a multi-line diverging block are both dedented by one level
    // in the generated let-else.
    #[test]
    fn indent_level() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let mut state = 2;
    let va$0lue = match opt {
        Some(
            Foo::A(it)
            | Foo::B(it)
        ) => it,
        _ => {
            state = 3;
            return Err(())
        },
    };
}
    "#,
            r#"
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let mut state = 2;
    let Some(
        Foo::A(value)
        | Foo::B(value)
    ) = opt else {
        state = 3;
        return Err(())
    };
}
    "#,
        );
    }

    // The extracting arm must return the bound variable as-is: an arithmetic expression and a
    // block body are both rejected, so the coverage mark is expected exactly twice.
    #[test]
    fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() {
        cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2);
        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<i32>) {
    let val$0 = match opt {
        Some(it) => it + 1,
        None => return,
    };
}
"#,
        );

        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) => {
            let _ = 1 + 1;
            it
        },
        None => return,
    };
}
"#,
        );
    }

    // A guard on the extracting arm makes the assist inapplicable.
    #[test]
    fn should_not_be_applicable_if_extracting_arm_has_guard() {
        cov_mark::check!(extracting_arm_has_guard);
        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) if 2 > 1 => it,
        None => return,
    };
}
"#,
        );
    }

    // The simplest case: `Some(it) => it` with a `return` in the other arm.
    #[test]
    fn basic_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) => it,
        None => return,
    };
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    let Some(val) = opt else { return };
}
    "#,
        );
    }

    // `ref mut` on the original binding is carried into the new pattern.
    #[test]
    fn keeps_modifiers() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let ref mut val$0 = match opt {
        Some(it) => it,
        None => return,
    };
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    let Some(ref mut val) = opt else { return };
}
    "#,
        );
    }

    // The extracted variable may sit inside nested tuple-struct patterns.
    #[test]
    fn nested_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option, result
fn foo(opt: Option<Result<()>>) {
    let val$0 = match opt {
        Some(Ok(it)) => it,
        _ => return,
    };
}
    "#,
            r#"
fn foo(opt: Option<Result<()>>) {
    let Some(Ok(val)) = opt else { return };
}
    "#,
        );
    }

    // `break`, `continue` and a call to a `-> !` function are all accepted as the diverging arm.
    #[test]
    fn works_with_any_diverging_block() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    loop {
        let val$0 = match opt {
            Some(it) => it,
            None => break,
        };
    }
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    loop {
        let Some(val) = opt else { break };
    }
}
    "#,
        );

        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    loop {
        let val$0 = match opt {
            Some(it) => it,
            None => continue,
        };
    }
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    loop {
        let Some(val) = opt else { continue };
    }
}
    "#,
        );

        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn panic() -> ! {}

fn foo(opt: Option<()>) {
    loop {
        let val$0 = match opt {
            Some(it) => it,
            None => panic(),
        };
    }
}
    "#,
            r#"
fn panic() -> ! {}

fn foo(opt: Option<()>) {
    loop {
        let Some(val) = opt else { panic() };
    }
}
    "#,
        );
    }

    // A shorthand record field (`y`) is expanded to `y: val` when renamed.
    #[test]
    fn struct_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
struct Point {
    x: i32,
    y: i32,
}

fn foo(opt: Option<Point>) {
    let val$0 = match opt {
        Some(Point { x: 0, y }) => y,
        _ => return,
    };
}
    "#,
            r#"
struct Point {
    x: i32,
    y: i32,
}

fn foo(opt: Option<Point>) {
    let Some(Point { x: 0, y: val }) = opt else { return };
}
    "#,
        );
    }

    // With an `@` binding, only the bound name is replaced and the sub-pattern is kept.
    #[test]
    fn renames_whole_binding() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<i32>) -> Option<i32> {
    let val$0 = match opt {
        it @ Some(42) => it,
        _ => return None,
    };
    val
}
    "#,
            r#"
fn foo(opt: Option<i32>) -> Option<i32> {
    let val @ Some(42) = opt else { return None };
    val
}
    "#,
        );
    }

    // The original `let` pattern may itself be a destructuring pattern; it is spliced in whole.
    #[test]
    fn complex_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn f() {
    let (x, y)$0 = match Some((0, 1)) {
        Some(it) => it,
        None => return,
    };
}
"#,
            r#"
fn f() {
    let Some((x, y)) = Some((0, 1)) else { return };
}
"#,
        );
    }

    // A block-bodied diverging arm is reused as the `else` block, including its comment.
    #[test]
    fn diverging_block() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn f() {
    let x$0 = match Some(()) {
        Some(it) => it,
        None => {//comment
            println!("nope");
            return
        },
    };
}
"#,
            r#"
fn f() {
    let Some(x) = Some(()) else {//comment
        println!("nope");
        return
    };
}
"#,
        );
    }
}