1use ide_db::defs::{Definition, NameRefClass};
2use syntax::{
3 AstNode, SyntaxNode,
4 ast::{self, HasName, Name, edit::AstNodeEdit, syntax_factory::SyntaxFactory},
5 syntax_editor::SyntaxEditor,
6};
7
8use crate::{
9 AssistId,
10 assist_context::{AssistContext, Assists},
11};
12
13pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
33 let let_stmt: ast::LetStmt = ctx.find_node_at_offset()?;
34 let pat = let_stmt.pat()?;
35 if ctx.offset() > pat.syntax().text_range().end() {
36 return None;
37 }
38
39 let Some(ast::Expr::MatchExpr(initializer)) = let_stmt.initializer() else { return None };
40 let initializer_expr = initializer.expr()?;
41
42 let (extracting_arm, diverging_arm) = find_arms(ctx, &initializer)?;
43 if extracting_arm.guard().is_some() {
44 cov_mark::hit!(extracting_arm_has_guard);
45 return None;
46 }
47
48 let diverging_arm_expr = match diverging_arm.expr()?.dedent(1.into()) {
49 ast::Expr::BlockExpr(block) if block.modifier().is_none() && block.label().is_none() => {
50 block.to_string()
51 }
52 other => format!("{{ {other} }}"),
53 };
54 let extracting_arm_pat = extracting_arm.pat()?;
55 let extracted_variable_positions = find_extracted_variable(ctx, &extracting_arm)?;
56
57 acc.add(
58 AssistId::refactor_rewrite("convert_match_to_let_else"),
59 "Convert match to let-else",
60 let_stmt.syntax().text_range(),
61 |builder| {
62 let extracting_arm_pat =
63 rename_variable(&extracting_arm_pat, &extracted_variable_positions, pat);
64 builder.replace(
65 let_stmt.syntax().text_range(),
66 format!("let {extracting_arm_pat} = {initializer_expr} else {diverging_arm_expr};"),
67 )
68 },
69 )
70}
71
72fn find_arms(
74 ctx: &AssistContext<'_>,
75 match_expr: &ast::MatchExpr,
76) -> Option<(ast::MatchArm, ast::MatchArm)> {
77 let arms = match_expr.match_arm_list()?.arms().collect::<Vec<_>>();
78 if arms.len() != 2 {
79 return None;
80 }
81
82 let mut extracting = None;
83 let mut diverging = None;
84 for arm in arms {
85 if ctx.sema.type_of_expr(&arm.expr()?)?.original().is_never() {
86 diverging = Some(arm);
87 } else {
88 extracting = Some(arm);
89 }
90 }
91
92 match (extracting, diverging) {
93 (Some(extracting), Some(diverging)) => Some((extracting, diverging)),
94 _ => {
95 cov_mark::hit!(non_diverging_match);
96 None
97 }
98 }
99}
100
101fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option<Vec<Name>> {
103 match arm.expr()? {
104 ast::Expr::PathExpr(path) => {
105 let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?;
106 match NameRefClass::classify(&ctx.sema, &name_ref)? {
107 NameRefClass::Definition(Definition::Local(local), _) => {
108 let source =
109 local.sources(ctx.db()).into_iter().map(|x| x.into_ident_pat()?.name());
110 source.collect()
111 }
112 _ => None,
113 }
114 }
115 _ => {
116 cov_mark::hit!(extracting_arm_is_not_an_identity_expr);
117 None
118 }
119 }
120}
121
122fn rename_variable(pat: &ast::Pat, extracted: &[Name], binding: ast::Pat) -> SyntaxNode {
124 let syntax = pat.syntax().clone_subtree();
125 let mut editor = SyntaxEditor::new(syntax.clone());
126 let make = SyntaxFactory::with_mappings();
127 let extracted = extracted
128 .iter()
129 .map(|e| e.syntax().text_range() - pat.syntax().text_range().start())
130 .map(|r| syntax.covering_element(r))
131 .collect::<Vec<_>>();
132 for extracted_syntax in extracted {
133 if let Some(record_pat_field) =
136 extracted_syntax.ancestors().find_map(ast::RecordPatField::cast)
137 {
138 if let Some(name_ref) = record_pat_field.field_name() {
139 editor.replace(
140 record_pat_field.syntax(),
141 make.record_pat_field(
142 make.name_ref(&name_ref.text()),
143 binding.clone_for_update(),
144 )
145 .syntax(),
146 );
147 }
148 } else {
149 editor.replace(extracted_syntax, binding.syntax().clone_for_update());
150 }
151 }
152 editor.add_mappings(make.finish_with_mappings());
153 let new_node = editor.finish().new_root().clone();
154 if let Some(pat) = ast::Pat::cast(new_node.clone()) {
155 pat.dedent(1.into()).syntax().clone()
156 } else {
157 new_node
158 }
159}
160
// Fixture-based tests for the `convert_match_to_let_else` assist. Each `check_assist` call
// pairs an input containing the `$0` cursor marker with the expected rewritten source.
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    // `None => ()` does not have type `!`, so neither arm diverges and no `else` block exists.
    #[test]
    fn should_not_be_applicable_for_non_diverging_match() {
        cov_mark::check!(non_diverging_match);
        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) => it,
        None => (),
    };
}
"#,
        );
    }

    // The or-pattern binds `it` in each alternative; every occurrence is renamed to `value`.
    #[test]
    fn or_pattern_multiple_binding() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let va$0lue = match opt {
        Some(Foo::A(it) | Foo::B(it)) => it,
        _ => return Err(()),
    };
}
    "#,
            r#"
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let Some(Foo::A(value) | Foo::B(value)) = opt else { return Err(()) };
}
    "#,
        );
    }

    // Multi-line arm pattern and diverging block are both dedented one level to sit under
    // the `let`.
    #[test]
    fn indent_level() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let mut state = 2;
    let va$0lue = match opt {
        Some(
            Foo::A(it)
            | Foo::B(it)
        ) => it,
        _ => {
            state = 3;
            return Err(())
        },
    };
}
    "#,
            r#"
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let mut state = 2;
    let Some(
        Foo::A(value)
        | Foo::B(value)
    ) = opt else {
        state = 3;
        return Err(())
    };
}
    "#,
        );
    }

    // The extracting arm must return the bound variable unchanged. Both fixtures hit the
    // bail-out, which the mark count of 2 verifies.
    #[test]
    fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() {
        cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2);
        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<i32>) {
    let val$0 = match opt {
        Some(it) => it + 1,
        None => return,
    };
}
"#,
        );

        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) => {
            let _ = 1 + 1;
            it
        },
        None => return,
    };
}
"#,
        );
    }

    // A match guard cannot be expressed in a let-else pattern.
    #[test]
    fn should_not_be_applicable_if_extracting_arm_has_guard() {
        cov_mark::check!(extracting_arm_has_guard);
        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) if 2 > 1 => it,
        None => return,
    };
}
"#,
        );
    }

    // Simplest case: `Some(it) => it` / `None => return`.
    #[test]
    fn basic_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) => it,
        None => return,
    };
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    let Some(val) = opt else { return };
}
    "#,
        );
    }

    // Binding modifiers (`ref mut`) on the `let` pattern are carried into the new pattern.
    #[test]
    fn keeps_modifiers() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let ref mut val$0 = match opt {
        Some(it) => it,
        None => return,
    };
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    let Some(ref mut val) = opt else { return };
}
    "#,
        );
    }

    // A binding nested inside several tuple-struct patterns is renamed in place.
    #[test]
    fn nested_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option, result
fn foo(opt: Option<Result<()>>) {
    let val$0 = match opt {
        Some(Ok(it)) => it,
        _ => return,
    };
}
    "#,
            r#"
fn foo(opt: Option<Result<()>>) {
    let Some(Ok(val)) = opt else { return };
}
    "#,
        );
    }

    // Any arm of type `!` qualifies as the diverging arm: `break`, `continue`, and a call
    // to a function returning `!`.
    #[test]
    fn works_with_any_diverging_block() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    loop {
        let val$0 = match opt {
            Some(it) => it,
            None => break,
        };
    }
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    loop {
        let Some(val) = opt else { break };
    }
}
    "#,
        );

        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    loop {
        let val$0 = match opt {
            Some(it) => it,
            None => continue,
        };
    }
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    loop {
        let Some(val) = opt else { continue };
    }
}
    "#,
        );

        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn panic() -> ! {}

fn foo(opt: Option<()>) {
    loop {
        let val$0 = match opt {
            Some(it) => it,
            None => panic(),
        };
    }
}
    "#,
            r#"
fn panic() -> ! {}

fn foo(opt: Option<()>) {
    loop {
        let Some(val) = opt else { panic() };
    }
}
    "#,
        );
    }

    // Record-field shorthand `y` must be expanded to `y: val` rather than renamed in place.
    #[test]
    fn struct_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
struct Point {
    x: i32,
    y: i32,
}

fn foo(opt: Option<Point>) {
    let val$0 = match opt {
        Some(Point { x: 0, y }) => y,
        _ => return,
    };
}
    "#,
            r#"
struct Point {
    x: i32,
    y: i32,
}

fn foo(opt: Option<Point>) {
    let Some(Point { x: 0, y: val }) = opt else { return };
}
    "#,
        );
    }

    // An `@` binding covering the whole scrutinee keeps its sub-pattern after renaming.
    #[test]
    fn renames_whole_binding() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<i32>) -> Option<i32> {
    let val$0 = match opt {
        it @ Some(42) => it,
        _ => return None,
    };
    val
}
    "#,
            r#"
fn foo(opt: Option<i32>) -> Option<i32> {
    let val @ Some(42) = opt else { return None };
    val
}
    "#,
        );
    }

    // A destructuring `let` pattern replaces the extracted binding as a whole.
    #[test]
    fn complex_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn f() {
    let (x, y)$0 = match Some((0, 1)) {
        Some(it) => it,
        None => return,
    };
}
"#,
            r#"
fn f() {
    let Some((x, y)) = Some((0, 1)) else { return };
}
"#,
        );
    }

    // A block-bodied diverging arm is reused verbatim as the `else` block, comment included.
    #[test]
    fn diverging_block() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn f() {
    let x$0 = match Some(()) {
        Some(it) => it,
        None => {//comment
            println!("nope");
            return
        },
    };
}
"#,
            r#"
fn f() {
    let Some(x) = Some(()) else {//comment
        println!("nope");
        return
    };
}
"#,
        );
    }
}