1use ide_db::defs::{Definition, NameRefClass};
2use syntax::{
3 AstNode, SyntaxNode,
4 ast::{self, HasName, Name, edit::AstNodeEdit},
5 syntax_editor::SyntaxEditor,
6};
7
8use crate::{
9 AssistId,
10 assist_context::{AssistContext, Assists},
11};
12
13pub(crate) fn convert_match_to_let_else(
33 acc: &mut Assists,
34 ctx: &AssistContext<'_, '_>,
35) -> Option<()> {
36 let let_stmt: ast::LetStmt = ctx.find_node_at_offset()?;
37 let pat = let_stmt.pat()?;
38 if ctx.offset() > pat.syntax().text_range().end() {
39 return None;
40 }
41
42 let Some(ast::Expr::MatchExpr(initializer)) = let_stmt.initializer() else { return None };
43 let initializer_expr = initializer.expr()?;
44
45 let (extracting_arm, diverging_arm) = find_arms(ctx, &initializer)?;
46 if extracting_arm.guard().is_some() {
47 cov_mark::hit!(extracting_arm_has_guard);
48 return None;
49 }
50
51 let diverging_arm_expr = match diverging_arm.expr()?.dedent(1.into()) {
52 ast::Expr::BlockExpr(block) if block.modifier().is_none() && block.label().is_none() => {
53 block.to_string()
54 }
55 other => format!("{{ {other} }}"),
56 };
57 let extracting_arm_pat = extracting_arm.pat()?;
58 let extracted_variable_positions = find_extracted_variable(ctx, &extracting_arm)?;
59
60 acc.add(
61 AssistId::refactor_rewrite("convert_match_to_let_else"),
62 "Convert match to let-else",
63 let_stmt.syntax().text_range(),
64 |builder| {
65 let extracting_arm_pat =
66 rename_variable(&extracting_arm_pat, &extracted_variable_positions, pat);
67 let (open_paren, close_paren) = if ast::OrPat::can_cast(extracting_arm_pat.kind()) {
68 ("(", ")")
72 } else {
73 ("", "")
74 };
75 builder.replace(
76 let_stmt.syntax().text_range(),
77 format!("let {open_paren}{extracting_arm_pat}{close_paren} = {initializer_expr} else {diverging_arm_expr};"),
78 )
79 },
80 )
81}
82
83fn find_arms(
85 ctx: &AssistContext<'_, '_>,
86 match_expr: &ast::MatchExpr,
87) -> Option<(ast::MatchArm, ast::MatchArm)> {
88 let arms = match_expr.match_arm_list()?.arms().collect::<Vec<_>>();
89 if arms.len() != 2 {
90 return None;
91 }
92
93 let mut extracting = None;
94 let mut diverging = None;
95 for arm in arms {
96 if ctx.sema.type_of_expr(&arm.expr()?)?.original().is_never() {
97 diverging = Some(arm);
98 } else {
99 extracting = Some(arm);
100 }
101 }
102
103 match (extracting, diverging) {
104 (Some(extracting), Some(diverging)) => Some((extracting, diverging)),
105 _ => {
106 cov_mark::hit!(non_diverging_match);
107 None
108 }
109 }
110}
111
112fn find_extracted_variable(ctx: &AssistContext<'_, '_>, arm: &ast::MatchArm) -> Option<Vec<Name>> {
114 match arm.expr()? {
115 ast::Expr::PathExpr(path) => {
116 let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?;
117 match NameRefClass::classify(&ctx.sema, &name_ref)? {
118 NameRefClass::Definition(Definition::Local(local), _) => {
119 let source =
120 local.sources(ctx.db()).into_iter().map(|x| x.into_ident_pat()?.name());
121 source.collect()
122 }
123 _ => None,
124 }
125 }
126 _ => {
127 cov_mark::hit!(extracting_arm_is_not_an_identity_expr);
128 None
129 }
130 }
131}
132
133fn rename_variable(pat: &ast::Pat, extracted: &[Name], binding: ast::Pat) -> SyntaxNode {
135 let (editor, syntax) = SyntaxEditor::new(pat.syntax().clone());
136 let make = editor.make();
137 let extracted = extracted
138 .iter()
139 .map(|e| e.syntax().text_range() - pat.syntax().text_range().start())
140 .map(|r| syntax.covering_element(r))
141 .collect::<Vec<_>>();
142 for extracted_syntax in extracted {
143 if let Some(record_pat_field) =
146 extracted_syntax.ancestors().find_map(ast::RecordPatField::cast)
147 {
148 if let Some(name_ref) = record_pat_field.field_name() {
149 editor.replace(
150 record_pat_field.syntax(),
151 make.record_pat_field(make.name_ref(&name_ref.text()), binding.clone())
152 .syntax(),
153 );
154 }
155 } else {
156 editor.replace(extracted_syntax, binding.syntax());
157 }
158 }
159 let new_node = editor.finish().new_root().clone();
160 if let Some(pat) = ast::Pat::cast(new_node.clone()) {
161 pat.dedent(1.into()).syntax().clone()
162 } else {
163 new_node
164 }
165}
166
// Fixture-based tests for `convert_match_to_let_else`. `$0` marks the cursor
// position in the "before" fixture.
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    // Neither arm diverges (`None => ()`), so the assist must not be offered;
    // also asserts that the `non_diverging_match` coverage mark is hit.
    #[test]
    fn should_not_be_applicable_for_non_diverging_match() {
        cov_mark::check!(non_diverging_match);
        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) => it,
        None => (),
    };
}
"#,
        );
    }

    // The extracted variable `it` is bound in both alternatives of a nested
    // or-pattern; both occurrences are expected to become `value`.
    #[test]
    fn or_pattern_multiple_binding() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let va$0lue = match opt {
        Some(Foo::A(it) | Foo::B(it)) => it,
        _ => return Err(()),
    };
}
    "#,
            r#"
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let Some(Foo::A(value) | Foo::B(value)) = opt else { return Err(()) };
}
    "#,
        );
    }

    // A multi-line arm pattern and a block-bodied diverging arm: the expected
    // output has both moved one indentation level to the left.
    #[test]
    fn indent_level() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let mut state = 2;
    let va$0lue = match opt {
        Some(
            Foo::A(it)
            | Foo::B(it)
        ) => it,
        _ => {
            state = 3;
            return Err(())
        },
    };
}
    "#,
            r#"
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let mut state = 2;
    let Some(
        Foo::A(value)
        | Foo::B(value)
    ) = opt else {
        state = 3;
        return Err(())
    };
}
    "#,
        );
    }

    // The extracting arm must return the bound variable as-is. Both fixtures
    // (`it + 1` and a block ending in `it`) are rejected, and the coverage
    // mark is expected to be hit once per fixture.
    #[test]
    fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() {
        cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2);
        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<i32>) {
    let val$0 = match opt {
        Some(it) => it + 1,
        None => return,
    };
}
"#,
        );

        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) => {
            let _ = 1 + 1;
            it
        },
        None => return,
    };
}
"#,
        );
    }

    // A guard on the extracting arm makes the assist inapplicable.
    #[test]
    fn should_not_be_applicable_if_extracting_arm_has_guard() {
        cov_mark::check!(extracting_arm_has_guard);
        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) if 2 > 1 => it,
        None => return,
    };
}
"#,
        );
    }

    // The simplest case: `Some(it) => it` with a `return` in the other arm.
    #[test]
    fn basic_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) => it,
        None => return,
    };
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    let Some(val) = opt else { return };
}
    "#,
        );
    }

    // The `ref mut` modifiers of the `let` binding are carried into the new pattern.
    #[test]
    fn keeps_modifiers() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let ref mut val$0 = match opt {
        Some(it) => it,
        None => return,
    };
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    let Some(ref mut val) = opt else { return };
}
    "#,
        );
    }

    // The variable is bound two levels deep (`Some(Ok(it))`) and the
    // diverging arm uses a wildcard pattern.
    #[test]
    fn nested_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option, result
fn foo(opt: Option<Result<()>>) {
    let val$0 = match opt {
        Some(Ok(it)) => it,
        _ => return,
    };
}
    "#,
            r#"
fn foo(opt: Option<Result<()>>) {
    let Some(Ok(val)) = opt else { return };
}
    "#,
        );
    }

    // The diverging arm may be `break`, `continue`, or a call to a function
    // returning `!`.
    #[test]
    fn works_with_any_diverging_block() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    loop {
        let val$0 = match opt {
            Some(it) => it,
            None => break,
        };
    }
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    loop {
        let Some(val) = opt else { break };
    }
}
    "#,
        );

        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    loop {
        let val$0 = match opt {
            Some(it) => it,
            None => continue,
        };
    }
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    loop {
        let Some(val) = opt else { continue };
    }
}
    "#,
        );

        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn panic() -> ! {}

fn foo(opt: Option<()>) {
    loop {
        let val$0 = match opt {
            Some(it) => it,
            None => panic(),
        };
    }
}
    "#,
            r#"
fn panic() -> ! {}

fn foo(opt: Option<()>) {
    loop {
        let Some(val) = opt else { panic() };
    }
}
    "#,
        );
    }

    // A shorthand record field (`y`) is expanded to `y: val`.
    #[test]
    fn struct_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
struct Point {
    x: i32,
    y: i32,
}

fn foo(opt: Option<Point>) {
    let val$0 = match opt {
        Some(Point { x: 0, y }) => y,
        _ => return,
    };
}
    "#,
            r#"
struct Point {
    x: i32,
    y: i32,
}

fn foo(opt: Option<Point>) {
    let Some(Point { x: 0, y: val }) = opt else { return };
}
    "#,
        );
    }

    // An `@` binding covering the whole pattern is renamed in place.
    #[test]
    fn renames_whole_binding() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<i32>) -> Option<i32> {
    let val$0 = match opt {
        it @ Some(42) => it,
        _ => return None,
    };
    val
}
    "#,
            r#"
fn foo(opt: Option<i32>) -> Option<i32> {
    let val @ Some(42) = opt else { return None };
    val
}
    "#,
        );
    }

    // The `let` binding is itself a tuple pattern, with the cursor placed
    // right at its end; the whole pattern takes the place of `it`.
    #[test]
    fn complex_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn f() {
    let (x, y)$0 = match Some((0, 1)) {
        Some(it) => it,
        None => return,
    };
}
"#,
            r#"
fn f() {
    let Some((x, y)) = Some((0, 1)) else { return };
}
"#,
        );
    }

    // A block-bodied diverging arm is reused as the `else` block, including
    // the comment that directly follows its opening brace.
    #[test]
    fn diverging_block() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn f() {
    let x$0 = match Some(()) {
        Some(it) => it,
        None => {//comment
            println!("nope");
            return
        },
    };
}
"#,
            r#"
fn f() {
    let Some(x) = Some(()) else {//comment
        println!("nope");
        return
    };
}
"#,
        );
    }

    // A top-level or-pattern is wrapped in parentheses in the generated `let`.
    #[test]
    fn top_level_or_pat() {
        check_assist(
            convert_match_to_let_else,
            r#"
enum E {
    A(u32),
    B(u32),
    C,
}

fn foo() {
    let e = E::A(0);
    let _$0 = match e {
        E::A(v) | E::B(v) => v,
        _ => return,
    };
}
    "#,
            r#"
enum E {
    A(u32),
    B(u32),
    C,
}

fn foo() {
    let e = E::A(0);
    let (E::A(_) | E::B(_)) = e else { return };
}
    "#,
        );
    }
}