ide_assists/handlers/
pull_assignment_up.rs

1use either::Either;
2use syntax::{
3    AstNode,
4    algo::find_node_at_range,
5    ast::{self, syntax_factory::SyntaxFactory},
6    syntax_editor::SyntaxEditor,
7};
8
9use crate::{
10    AssistId,
11    assist_context::{AssistContext, Assists},
12};
13
14// Assist: pull_assignment_up
15//
16// Extracts variable assignment to outside an if or match statement.
17//
18// ```
19// fn main() {
20//     let mut foo = 6;
21//
22//     if true {
23//         $0foo = 5;
24//     } else {
25//         foo = 4;
26//     }
27// }
28// ```
29// ->
30// ```
31// fn main() {
32//     let mut foo = 6;
33//
34//     foo = if true {
35//         5
36//     } else {
37//         4
38//     };
39// }
40// ```
41pub(crate) fn pull_assignment_up(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
42    let assign_expr = ctx.find_node_at_offset::<ast::BinExpr>()?;
43
44    let op_kind = assign_expr.op_kind()?;
45    if op_kind != (ast::BinaryOp::Assignment { op: None }) {
46        cov_mark::hit!(test_cant_pull_non_assignments);
47        return None;
48    }
49
50    let mut collector = AssignmentsCollector {
51        sema: &ctx.sema,
52        common_lhs: assign_expr.lhs()?,
53        assignments: Vec::new(),
54    };
55
56    let node: Either<ast::IfExpr, ast::MatchExpr> = ctx.find_node_at_offset()?;
57    let tgt: ast::Expr = if let Either::Left(if_expr) = node {
58        let if_expr = std::iter::successors(Some(if_expr), |it| {
59            it.syntax().parent().and_then(ast::IfExpr::cast)
60        })
61        .last()?;
62        collector.collect_if(&if_expr)?;
63        if_expr.into()
64    } else if let Either::Right(match_expr) = node {
65        collector.collect_match(&match_expr)?;
66        match_expr.into()
67    } else {
68        return None;
69    };
70
71    if let Some(parent) = tgt.syntax().parent()
72        && matches!(parent.kind(), syntax::SyntaxKind::BIN_EXPR | syntax::SyntaxKind::LET_STMT)
73    {
74        return None;
75    }
76    let target = tgt.syntax().text_range();
77
78    let edit_tgt = tgt.syntax().clone_subtree();
79    let assignments: Vec<_> = collector
80        .assignments
81        .into_iter()
82        .filter_map(|(stmt, rhs)| {
83            Some((
84                find_node_at_range::<ast::BinExpr>(
85                    &edit_tgt,
86                    stmt.syntax().text_range() - target.start(),
87                )?,
88                find_node_at_range::<ast::Expr>(
89                    &edit_tgt,
90                    rhs.syntax().text_range() - target.start(),
91                )?,
92            ))
93        })
94        .collect();
95
96    let mut editor = SyntaxEditor::new(edit_tgt);
97    for (stmt, rhs) in assignments {
98        let mut stmt = stmt.syntax().clone();
99        if let Some(parent) = stmt.parent()
100            && ast::ExprStmt::cast(parent.clone()).is_some()
101        {
102            stmt = parent.clone();
103        }
104        editor.replace(stmt, rhs.syntax());
105    }
106    let new_tgt_root = editor.finish().new_root().clone();
107    let new_tgt = ast::Expr::cast(new_tgt_root)?;
108    acc.add(
109        AssistId::refactor_extract("pull_assignment_up"),
110        "Pull assignment up",
111        target,
112        move |edit| {
113            let make = SyntaxFactory::with_mappings();
114            let mut editor = edit.make_editor(tgt.syntax());
115            let assign_expr = make.expr_assignment(collector.common_lhs, new_tgt.clone());
116            let assign_stmt = make.expr_stmt(assign_expr.into());
117
118            editor.replace(tgt.syntax(), assign_stmt.syntax());
119            editor.add_mappings(make.finish_with_mappings());
120            edit.add_file_edits(ctx.vfs_file_id(), editor);
121        },
122    )
123}
124
125struct AssignmentsCollector<'a> {
126    sema: &'a hir::Semantics<'a, ide_db::RootDatabase>,
127    common_lhs: ast::Expr,
128    assignments: Vec<(ast::BinExpr, ast::Expr)>,
129}
130
131impl AssignmentsCollector<'_> {
132    fn collect_match(&mut self, match_expr: &ast::MatchExpr) -> Option<()> {
133        for arm in match_expr.match_arm_list()?.arms() {
134            match arm.expr()? {
135                ast::Expr::BlockExpr(block) => self.collect_block(&block)?,
136                ast::Expr::BinExpr(expr) => self.collect_expr(&expr)?,
137                _ => return None,
138            }
139        }
140
141        Some(())
142    }
143    fn collect_if(&mut self, if_expr: &ast::IfExpr) -> Option<()> {
144        let then_branch = if_expr.then_branch()?;
145        self.collect_block(&then_branch)?;
146
147        match if_expr.else_branch()? {
148            ast::ElseBranch::Block(block) => self.collect_block(&block),
149            ast::ElseBranch::IfExpr(expr) => {
150                cov_mark::hit!(test_pull_assignment_up_chained_if);
151                self.collect_if(&expr)
152            }
153        }
154    }
155    fn collect_block(&mut self, block: &ast::BlockExpr) -> Option<()> {
156        let last_expr = block.tail_expr().or_else(|| match block.statements().last()? {
157            ast::Stmt::ExprStmt(stmt) => stmt.expr(),
158            ast::Stmt::Item(_) | ast::Stmt::LetStmt(_) => None,
159        })?;
160
161        if let ast::Expr::BinExpr(expr) = last_expr {
162            return self.collect_expr(&expr);
163        }
164
165        None
166    }
167
168    fn collect_expr(&mut self, expr: &ast::BinExpr) -> Option<()> {
169        if expr.op_kind()? == (ast::BinaryOp::Assignment { op: None })
170            && is_equivalent(self.sema, &expr.lhs()?, &self.common_lhs)
171        {
172            self.assignments.push((expr.clone(), expr.rhs()?));
173            return Some(());
174        }
175        None
176    }
177}
178
179fn is_equivalent(
180    sema: &hir::Semantics<'_, ide_db::RootDatabase>,
181    expr0: &ast::Expr,
182    expr1: &ast::Expr,
183) -> bool {
184    match (expr0, expr1) {
185        (ast::Expr::FieldExpr(field_expr0), ast::Expr::FieldExpr(field_expr1)) => {
186            cov_mark::hit!(test_pull_assignment_up_field_assignment);
187            sema.resolve_field(field_expr0) == sema.resolve_field(field_expr1)
188        }
189        (ast::Expr::PathExpr(path0), ast::Expr::PathExpr(path1)) => {
190            let path0 = path0.path();
191            let path1 = path1.path();
192            if let (Some(path0), Some(path1)) = (path0, path1) {
193                sema.resolve_path(&path0) == sema.resolve_path(&path1)
194            } else {
195                false
196            }
197        }
198        (ast::Expr::PrefixExpr(prefix0), ast::Expr::PrefixExpr(prefix1))
199            if prefix0.op_kind() == Some(ast::UnaryOp::Deref)
200                && prefix1.op_kind() == Some(ast::UnaryOp::Deref) =>
201        {
202            cov_mark::hit!(test_pull_assignment_up_deref);
203            if let (Some(prefix0), Some(prefix1)) = (prefix0.expr(), prefix1.expr()) {
204                is_equivalent(sema, &prefix0, &prefix1)
205            } else {
206                false
207            }
208        }
209        _ => false,
210    }
211}
212
213#[cfg(test)]
214mod tests {
215    use super::*;
216
217    use crate::tests::{check_assist, check_assist_not_applicable};
218
219    #[test]
220    fn test_pull_assignment_up_if() {
221        check_assist(
222            pull_assignment_up,
223            r#"
224fn foo() {
225    let mut a = 1;
226
227    if true {
228        $0a = 2;
229    } else {
230        a = 3;
231    }
232}"#,
233            r#"
234fn foo() {
235    let mut a = 1;
236
237    a = if true {
238        2
239    } else {
240        3
241    };
242}"#,
243        );
244    }
245
246    #[test]
247    fn test_pull_assignment_up_inner_if() {
248        check_assist(
249            pull_assignment_up,
250            r#"
251fn foo() {
252    let mut a = 1;
253
254    if true {
255        a = 2;
256    } else if true {
257        $0a = 3;
258    } else {
259        a = 4;
260    }
261}"#,
262            r#"
263fn foo() {
264    let mut a = 1;
265
266    a = if true {
267        2
268    } else if true {
269        3
270    } else {
271        4
272    };
273}"#,
274        );
275    }
276
277    #[test]
278    fn test_pull_assignment_up_match() {
279        check_assist(
280            pull_assignment_up,
281            r#"
282fn foo() {
283    let mut a = 1;
284
285    match 1 {
286        1 => {
287            $0a = 2;
288        },
289        2 => {
290            a = 3;
291        },
292        3 => {
293            a = 4;
294        }
295    }
296}"#,
297            r#"
298fn foo() {
299    let mut a = 1;
300
301    a = match 1 {
302        1 => {
303            2
304        },
305        2 => {
306            3
307        },
308        3 => {
309            4
310        }
311    };
312}"#,
313        );
314    }
315
316    #[test]
317    fn test_pull_assignment_up_match_in_if_expr() {
318        check_assist(
319            pull_assignment_up,
320            r#"
321fn foo() {
322    let x;
323    if true {
324        match true {
325            true => $0x = 2,
326            false => x = 3,
327        }
328    }
329}"#,
330            r#"
331fn foo() {
332    let x;
333    if true {
334        x = match true {
335            true => 2,
336            false => 3,
337        };
338    }
339}"#,
340        );
341    }
342
343    #[test]
344    fn test_pull_assignment_up_assignment_expressions() {
345        check_assist(
346            pull_assignment_up,
347            r#"
348fn foo() {
349    let mut a = 1;
350
351    match 1 {
352        1 => { $0a = 2; },
353        2 => a = 3,
354        3 => {
355            a = 4
356        }
357    }
358}"#,
359            r#"
360fn foo() {
361    let mut a = 1;
362
363    a = match 1 {
364        1 => { 2 },
365        2 => 3,
366        3 => {
367            4
368        }
369    };
370}"#,
371        );
372    }
373
374    #[test]
375    fn test_pull_assignment_up_not_last_not_applicable() {
376        check_assist_not_applicable(
377            pull_assignment_up,
378            r#"
379fn foo() {
380    let mut a = 1;
381
382    if true {
383        $0a = 2;
384        b = a;
385    } else {
386        a = 3;
387    }
388}"#,
389        )
390    }
391
392    #[test]
393    fn test_pull_assignment_up_chained_if() {
394        cov_mark::check!(test_pull_assignment_up_chained_if);
395        check_assist(
396            pull_assignment_up,
397            r#"
398fn foo() {
399    let mut a = 1;
400
401    if true {
402        $0a = 2;
403    } else if false {
404        a = 3;
405    } else {
406        a = 4;
407    }
408}"#,
409            r#"
410fn foo() {
411    let mut a = 1;
412
413    a = if true {
414        2
415    } else if false {
416        3
417    } else {
418        4
419    };
420}"#,
421        );
422    }
423
424    #[test]
425    fn test_pull_assignment_up_retains_stmts() {
426        check_assist(
427            pull_assignment_up,
428            r#"
429fn foo() {
430    let mut a = 1;
431
432    if true {
433        let b = 2;
434        $0a = 2;
435    } else {
436        let b = 3;
437        a = 3;
438    }
439}"#,
440            r#"
441fn foo() {
442    let mut a = 1;
443
444    a = if true {
445        let b = 2;
446        2
447    } else {
448        let b = 3;
449        3
450    };
451}"#,
452        )
453    }
454
455    #[test]
456    fn pull_assignment_up_let_stmt_not_applicable() {
457        check_assist_not_applicable(
458            pull_assignment_up,
459            r#"
460fn foo() {
461    let mut a = 1;
462
463    let b = if true {
464        $0a = 2
465    } else {
466        a = 3
467    };
468}"#,
469        )
470    }
471
472    #[test]
473    fn pull_assignment_up_if_missing_assignment_not_applicable() {
474        check_assist_not_applicable(
475            pull_assignment_up,
476            r#"
477fn foo() {
478    let mut a = 1;
479
480    if true {
481        $0a = 2;
482    } else {}
483}"#,
484        )
485    }
486
487    #[test]
488    fn pull_assignment_up_match_missing_assignment_not_applicable() {
489        check_assist_not_applicable(
490            pull_assignment_up,
491            r#"
492fn foo() {
493    let mut a = 1;
494
495    match 1 {
496        1 => {
497            $0a = 2;
498        },
499        2 => {
500            a = 3;
501        },
502        3 => {},
503    }
504}"#,
505        )
506    }
507
508    #[test]
509    fn test_pull_assignment_up_field_assignment() {
510        cov_mark::check!(test_pull_assignment_up_field_assignment);
511        check_assist(
512            pull_assignment_up,
513            r#"
514struct A(usize);
515
516fn foo() {
517    let mut a = A(1);
518
519    if true {
520        $0a.0 = 2;
521    } else {
522        a.0 = 3;
523    }
524}"#,
525            r#"
526struct A(usize);
527
528fn foo() {
529    let mut a = A(1);
530
531    a.0 = if true {
532        2
533    } else {
534        3
535    };
536}"#,
537        )
538    }
539
540    #[test]
541    fn test_pull_assignment_up_deref() {
542        cov_mark::check!(test_pull_assignment_up_deref);
543        check_assist(
544            pull_assignment_up,
545            r#"
546fn foo() {
547    let mut a = 1;
548    let b = &mut a;
549
550    if true {
551        $0*b = 2;
552    } else {
553        *b = 3;
554    }
555}
556"#,
557            r#"
558fn foo() {
559    let mut a = 1;
560    let b = &mut a;
561
562    *b = if true {
563        2
564    } else {
565        3
566    };
567}
568"#,
569        )
570    }
571
572    #[test]
573    fn test_cant_pull_non_assignments() {
574        cov_mark::check!(test_cant_pull_non_assignments);
575        check_assist_not_applicable(
576            pull_assignment_up,
577            r#"
578fn foo() {
579    let mut a = 1;
580    let b = &mut a;
581
582    if true {
583        $0*b + 2;
584    } else {
585        *b + 3;
586    }
587}
588"#,
589        )
590    }
591}