// ide_assists/handlers/pull_assignment_up.rs

use either::Either;
use syntax::{AstNode, algo::find_node_at_range, ast, syntax_editor::SyntaxEditor};

use crate::{
    AssistId,
    assist_context::{AssistContext, Assists},
};

9// Assist: pull_assignment_up
10//
11// Extracts variable assignment to outside an if or match statement.
12//
13// ```
14// fn main() {
15//     let mut foo = 6;
16//
17//     if true {
18//         $0foo = 5;
19//     } else {
20//         foo = 4;
21//     }
22// }
23// ```
24// ->
25// ```
26// fn main() {
27//     let mut foo = 6;
28//
29//     foo = if true {
30//         5
31//     } else {
32//         4
33//     };
34// }
35// ```
36pub(crate) fn pull_assignment_up(acc: &mut Assists, ctx: &AssistContext<'_, '_>) -> Option<()> {
37    let assign_expr = ctx.find_node_at_offset::<ast::BinExpr>()?;
38
39    let op_kind = assign_expr.op_kind()?;
40    if op_kind != (ast::BinaryOp::Assignment { op: None }) {
41        cov_mark::hit!(test_cant_pull_non_assignments);
42        return None;
43    }
44
45    let mut collector = AssignmentsCollector {
46        sema: &ctx.sema,
47        common_lhs: assign_expr.lhs()?,
48        assignments: Vec::new(),
49    };
50
51    let node: Either<ast::IfExpr, ast::MatchExpr> = ctx.find_node_at_offset()?;
52    let tgt: ast::Expr = if let Either::Left(if_expr) = node {
53        let if_expr = std::iter::successors(Some(if_expr), |it| {
54            it.syntax().parent().and_then(ast::IfExpr::cast)
55        })
56        .last()?;
57        collector.collect_if(&if_expr)?;
58        if_expr.into()
59    } else if let Either::Right(match_expr) = node {
60        collector.collect_match(&match_expr)?;
61        match_expr.into()
62    } else {
63        return None;
64    };
65
66    if let Some(parent) = tgt.syntax().parent()
67        && matches!(parent.kind(), syntax::SyntaxKind::BIN_EXPR | syntax::SyntaxKind::LET_STMT)
68    {
69        return None;
70    }
71    let target = tgt.syntax().text_range();
72
73    let (editor, edit_tgt) = SyntaxEditor::new(tgt.syntax().clone());
74    let assignments: Vec<_> = collector
75        .assignments
76        .into_iter()
77        .filter_map(|(stmt, rhs)| {
78            Some((
79                find_node_at_range::<ast::BinExpr>(
80                    &edit_tgt,
81                    stmt.syntax().text_range() - target.start(),
82                )?,
83                find_node_at_range::<ast::Expr>(
84                    &edit_tgt,
85                    rhs.syntax().text_range() - target.start(),
86                )?,
87            ))
88        })
89        .collect();
90
91    for (stmt, rhs) in assignments {
92        let mut stmt = stmt.syntax().clone();
93        if let Some(parent) = stmt.parent()
94            && ast::ExprStmt::cast(parent.clone()).is_some()
95        {
96            stmt = parent.clone();
97        }
98        editor.replace(stmt, rhs.syntax());
99    }
100    let new_tgt_root = editor.finish().new_root().clone();
101    let new_tgt = ast::Expr::cast(new_tgt_root)?;
102    acc.add(
103        AssistId::refactor_extract("pull_assignment_up"),
104        "Pull assignment up",
105        target,
106        move |edit| {
107            let editor = edit.make_editor(tgt.syntax());
108            let make = editor.make();
109            let assign_expr = make.expr_assignment(collector.common_lhs, new_tgt.clone());
110            let assign_stmt = make.expr_stmt(assign_expr.into());
111
112            editor.replace(tgt.syntax(), assign_stmt.syntax());
113            edit.add_file_edits(ctx.vfs_file_id(), editor);
114        },
115    )
116}
117
118struct AssignmentsCollector<'a, 'db> {
119    sema: &'a hir::Semantics<'db, ide_db::RootDatabase>,
120    common_lhs: ast::Expr,
121    assignments: Vec<(ast::BinExpr, ast::Expr)>,
122}
123
124impl AssignmentsCollector<'_, '_> {
125    fn collect_match(&mut self, match_expr: &ast::MatchExpr) -> Option<()> {
126        for arm in match_expr.match_arm_list()?.arms() {
127            match arm.expr()? {
128                ast::Expr::BlockExpr(block) => self.collect_block(&block)?,
129                ast::Expr::BinExpr(expr) => self.collect_expr(&expr)?,
130                _ => return None,
131            }
132        }
133
134        Some(())
135    }
136    fn collect_if(&mut self, if_expr: &ast::IfExpr) -> Option<()> {
137        let then_branch = if_expr.then_branch()?;
138        self.collect_block(&then_branch)?;
139
140        match if_expr.else_branch()? {
141            ast::ElseBranch::Block(block) => self.collect_block(&block),
142            ast::ElseBranch::IfExpr(expr) => {
143                cov_mark::hit!(test_pull_assignment_up_chained_if);
144                self.collect_if(&expr)
145            }
146        }
147    }
148    fn collect_block(&mut self, block: &ast::BlockExpr) -> Option<()> {
149        let last_expr = block.tail_expr().or_else(|| match block.statements().last()? {
150            ast::Stmt::ExprStmt(stmt) => stmt.expr(),
151            ast::Stmt::Item(_) | ast::Stmt::LetStmt(_) => None,
152        })?;
153
154        if let ast::Expr::BinExpr(expr) = last_expr {
155            return self.collect_expr(&expr);
156        }
157
158        None
159    }
160
161    fn collect_expr(&mut self, expr: &ast::BinExpr) -> Option<()> {
162        if expr.op_kind()? == (ast::BinaryOp::Assignment { op: None })
163            && is_equivalent(self.sema, &expr.lhs()?, &self.common_lhs)
164        {
165            self.assignments.push((expr.clone(), expr.rhs()?));
166            return Some(());
167        }
168        None
169    }
170}
171
172fn is_equivalent(
173    sema: &hir::Semantics<'_, ide_db::RootDatabase>,
174    expr0: &ast::Expr,
175    expr1: &ast::Expr,
176) -> bool {
177    match (expr0, expr1) {
178        (ast::Expr::FieldExpr(field_expr0), ast::Expr::FieldExpr(field_expr1)) => {
179            cov_mark::hit!(test_pull_assignment_up_field_assignment);
180            sema.resolve_field(field_expr0) == sema.resolve_field(field_expr1)
181        }
182        (ast::Expr::PathExpr(path0), ast::Expr::PathExpr(path1)) => {
183            let path0 = path0.path();
184            let path1 = path1.path();
185            if let (Some(path0), Some(path1)) = (path0, path1) {
186                sema.resolve_path(&path0) == sema.resolve_path(&path1)
187            } else {
188                false
189            }
190        }
191        (ast::Expr::PrefixExpr(prefix0), ast::Expr::PrefixExpr(prefix1))
192            if prefix0.op_kind() == Some(ast::UnaryOp::Deref)
193                && prefix1.op_kind() == Some(ast::UnaryOp::Deref) =>
194        {
195            cov_mark::hit!(test_pull_assignment_up_deref);
196            if let (Some(prefix0), Some(prefix1)) = (prefix0.expr(), prefix1.expr()) {
197                is_equivalent(sema, &prefix0, &prefix1)
198            } else {
199                false
200            }
201        }
202        _ => false,
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    use crate::tests::{check_assist, check_assist_not_applicable};

    // Plain `if`/`else` where both branches end in an assignment statement.
    #[test]
    fn test_pull_assignment_up_if() {
        let before = r#"
fn foo() {
    let mut a = 1;

    if true {
        $0a = 2;
    } else {
        a = 3;
    }
}"#;
        let after = r#"
fn foo() {
    let mut a = 1;

    a = if true {
        2
    } else {
        3
    };
}"#;
        check_assist(pull_assignment_up, before, after);
    }

    // Cursor inside an `else if` link: the whole chain is rewritten.
    #[test]
    fn test_pull_assignment_up_inner_if() {
        let before = r#"
fn foo() {
    let mut a = 1;

    if true {
        a = 2;
    } else if true {
        $0a = 3;
    } else {
        a = 4;
    }
}"#;
        let after = r#"
fn foo() {
    let mut a = 1;

    a = if true {
        2
    } else if true {
        3
    } else {
        4
    };
}"#;
        check_assist(pull_assignment_up, before, after);
    }

    // `match` whose arms are blocks ending in assignment statements.
    #[test]
    fn test_pull_assignment_up_match() {
        let before = r#"
fn foo() {
    let mut a = 1;

    match 1 {
        1 => {
            $0a = 2;
        },
        2 => {
            a = 3;
        },
        3 => {
            a = 4;
        }
    }
}"#;
        let after = r#"
fn foo() {
    let mut a = 1;

    a = match 1 {
        1 => {
            2
        },
        2 => {
            3
        },
        3 => {
            4
        }
    };
}"#;
        check_assist(pull_assignment_up, before, after);
    }

    // The innermost `match` is rewritten even when it is the tail of an `if` block.
    #[test]
    fn test_pull_assignment_up_match_in_if_expr() {
        let before = r#"
fn foo() {
    let x;
    if true {
        match true {
            true => $0x = 2,
            false => x = 3,
        }
    }
}"#;
        let after = r#"
fn foo() {
    let x;
    if true {
        x = match true {
            true => 2,
            false => 3,
        };
    }
}"#;
        check_assist(pull_assignment_up, before, after);
    }

    // Mix of statement, bare-expression and tail-expression assignments in match arms.
    #[test]
    fn test_pull_assignment_up_assignment_expressions() {
        let before = r#"
fn foo() {
    let mut a = 1;

    match 1 {
        1 => { $0a = 2; },
        2 => a = 3,
        3 => {
            a = 4
        }
    }
}"#;
        let after = r#"
fn foo() {
    let mut a = 1;

    a = match 1 {
        1 => { 2 },
        2 => 3,
        3 => {
            4
        }
    };
}"#;
        check_assist(pull_assignment_up, before, after);
    }

    // The assignment is not the last thing in its branch.
    #[test]
    fn test_pull_assignment_up_not_last_not_applicable() {
        let before = r#"
fn foo() {
    let mut a = 1;

    if true {
        $0a = 2;
        b = a;
    } else {
        a = 3;
    }
}"#;
        check_assist_not_applicable(pull_assignment_up, before);
    }

    #[test]
    fn test_pull_assignment_up_chained_if() {
        cov_mark::check!(test_pull_assignment_up_chained_if);
        let before = r#"
fn foo() {
    let mut a = 1;

    if true {
        $0a = 2;
    } else if false {
        a = 3;
    } else {
        a = 4;
    }
}"#;
        let after = r#"
fn foo() {
    let mut a = 1;

    a = if true {
        2
    } else if false {
        3
    } else {
        4
    };
}"#;
        check_assist(pull_assignment_up, before, after);
    }

    // Statements preceding the assignment stay in their branch.
    #[test]
    fn test_pull_assignment_up_retains_stmts() {
        let before = r#"
fn foo() {
    let mut a = 1;

    if true {
        let b = 2;
        $0a = 2;
    } else {
        let b = 3;
        a = 3;
    }
}"#;
        let after = r#"
fn foo() {
    let mut a = 1;

    a = if true {
        let b = 2;
        2
    } else {
        let b = 3;
        3
    };
}"#;
        check_assist(pull_assignment_up, before, after);
    }

    #[test]
    fn pull_assignment_up_let_stmt_not_applicable() {
        let before = r#"
fn foo() {
    let mut a = 1;

    let b = if true {
        $0a = 2
    } else {
        a = 3
    };
}"#;
        check_assist_not_applicable(pull_assignment_up, before);
    }

    #[test]
    fn pull_assignment_up_if_missing_assignment_not_applicable() {
        let before = r#"
fn foo() {
    let mut a = 1;

    if true {
        $0a = 2;
    } else {}
}"#;
        check_assist_not_applicable(pull_assignment_up, before);
    }

    #[test]
    fn pull_assignment_up_match_missing_assignment_not_applicable() {
        let before = r#"
fn foo() {
    let mut a = 1;

    match 1 {
        1 => {
            $0a = 2;
        },
        2 => {
            a = 3;
        },
        3 => {},
    }
}"#;
        check_assist_not_applicable(pull_assignment_up, before);
    }

    #[test]
    fn test_pull_assignment_up_field_assignment() {
        cov_mark::check!(test_pull_assignment_up_field_assignment);
        let before = r#"
struct A(usize);

fn foo() {
    let mut a = A(1);

    if true {
        $0a.0 = 2;
    } else {
        a.0 = 3;
    }
}"#;
        let after = r#"
struct A(usize);

fn foo() {
    let mut a = A(1);

    a.0 = if true {
        2
    } else {
        3
    };
}"#;
        check_assist(pull_assignment_up, before, after);
    }

    #[test]
    fn test_pull_assignment_up_deref() {
        cov_mark::check!(test_pull_assignment_up_deref);
        let before = r#"
fn foo() {
    let mut a = 1;
    let b = &mut a;

    if true {
        $0*b = 2;
    } else {
        *b = 3;
    }
}
"#;
        let after = r#"
fn foo() {
    let mut a = 1;
    let b = &mut a;

    *b = if true {
        2
    } else {
        3
    };
}
"#;
        check_assist(pull_assignment_up, before, after);
    }

    // `*b + 2` is a binary expression but not an assignment.
    #[test]
    fn test_cant_pull_non_assignments() {
        cov_mark::check!(test_cant_pull_non_assignments);
        let before = r#"
fn foo() {
    let mut a = 1;
    let b = &mut a;

    if true {
        $0*b + 2;
    } else {
        *b + 3;
    }
}
"#;
        check_assist_not_applicable(pull_assignment_up, before);
    }
}