1use either::Either;
2use syntax::{
3 AstNode,
4 algo::find_node_at_range,
5 ast::{self, syntax_factory::SyntaxFactory},
6 syntax_editor::SyntaxEditor,
7};
8
9use crate::{
10 AssistId,
11 assist_context::{AssistContext, Assists},
12};
13
14pub(crate) fn pull_assignment_up(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
42 let assign_expr = ctx.find_node_at_offset::<ast::BinExpr>()?;
43
44 let op_kind = assign_expr.op_kind()?;
45 if op_kind != (ast::BinaryOp::Assignment { op: None }) {
46 cov_mark::hit!(test_cant_pull_non_assignments);
47 return None;
48 }
49
50 let mut collector = AssignmentsCollector {
51 sema: &ctx.sema,
52 common_lhs: assign_expr.lhs()?,
53 assignments: Vec::new(),
54 };
55
56 let node: Either<ast::IfExpr, ast::MatchExpr> = ctx.find_node_at_offset()?;
57 let tgt: ast::Expr = if let Either::Left(if_expr) = node {
58 let if_expr = std::iter::successors(Some(if_expr), |it| {
59 it.syntax().parent().and_then(ast::IfExpr::cast)
60 })
61 .last()?;
62 collector.collect_if(&if_expr)?;
63 if_expr.into()
64 } else if let Either::Right(match_expr) = node {
65 collector.collect_match(&match_expr)?;
66 match_expr.into()
67 } else {
68 return None;
69 };
70
71 if let Some(parent) = tgt.syntax().parent()
72 && matches!(parent.kind(), syntax::SyntaxKind::BIN_EXPR | syntax::SyntaxKind::LET_STMT)
73 {
74 return None;
75 }
76 let target = tgt.syntax().text_range();
77
78 let edit_tgt = tgt.syntax().clone_subtree();
79 let assignments: Vec<_> = collector
80 .assignments
81 .into_iter()
82 .filter_map(|(stmt, rhs)| {
83 Some((
84 find_node_at_range::<ast::BinExpr>(
85 &edit_tgt,
86 stmt.syntax().text_range() - target.start(),
87 )?,
88 find_node_at_range::<ast::Expr>(
89 &edit_tgt,
90 rhs.syntax().text_range() - target.start(),
91 )?,
92 ))
93 })
94 .collect();
95
96 let mut editor = SyntaxEditor::new(edit_tgt);
97 for (stmt, rhs) in assignments {
98 let mut stmt = stmt.syntax().clone();
99 if let Some(parent) = stmt.parent()
100 && ast::ExprStmt::cast(parent.clone()).is_some()
101 {
102 stmt = parent.clone();
103 }
104 editor.replace(stmt, rhs.syntax());
105 }
106 let new_tgt_root = editor.finish().new_root().clone();
107 let new_tgt = ast::Expr::cast(new_tgt_root)?;
108 acc.add(
109 AssistId::refactor_extract("pull_assignment_up"),
110 "Pull assignment up",
111 target,
112 move |edit| {
113 let make = SyntaxFactory::with_mappings();
114 let mut editor = edit.make_editor(tgt.syntax());
115 let assign_expr = make.expr_assignment(collector.common_lhs, new_tgt.clone());
116 let assign_stmt = make.expr_stmt(assign_expr.into());
117
118 editor.replace(tgt.syntax(), assign_stmt.syntax());
119 editor.add_mappings(make.finish_with_mappings());
120 edit.add_file_edits(ctx.vfs_file_id(), editor);
121 },
122 )
123}
124
125struct AssignmentsCollector<'a> {
126 sema: &'a hir::Semantics<'a, ide_db::RootDatabase>,
127 common_lhs: ast::Expr,
128 assignments: Vec<(ast::BinExpr, ast::Expr)>,
129}
130
131impl AssignmentsCollector<'_> {
132 fn collect_match(&mut self, match_expr: &ast::MatchExpr) -> Option<()> {
133 for arm in match_expr.match_arm_list()?.arms() {
134 match arm.expr()? {
135 ast::Expr::BlockExpr(block) => self.collect_block(&block)?,
136 ast::Expr::BinExpr(expr) => self.collect_expr(&expr)?,
137 _ => return None,
138 }
139 }
140
141 Some(())
142 }
143 fn collect_if(&mut self, if_expr: &ast::IfExpr) -> Option<()> {
144 let then_branch = if_expr.then_branch()?;
145 self.collect_block(&then_branch)?;
146
147 match if_expr.else_branch()? {
148 ast::ElseBranch::Block(block) => self.collect_block(&block),
149 ast::ElseBranch::IfExpr(expr) => {
150 cov_mark::hit!(test_pull_assignment_up_chained_if);
151 self.collect_if(&expr)
152 }
153 }
154 }
155 fn collect_block(&mut self, block: &ast::BlockExpr) -> Option<()> {
156 let last_expr = block.tail_expr().or_else(|| match block.statements().last()? {
157 ast::Stmt::ExprStmt(stmt) => stmt.expr(),
158 ast::Stmt::Item(_) | ast::Stmt::LetStmt(_) => None,
159 })?;
160
161 if let ast::Expr::BinExpr(expr) = last_expr {
162 return self.collect_expr(&expr);
163 }
164
165 None
166 }
167
168 fn collect_expr(&mut self, expr: &ast::BinExpr) -> Option<()> {
169 if expr.op_kind()? == (ast::BinaryOp::Assignment { op: None })
170 && is_equivalent(self.sema, &expr.lhs()?, &self.common_lhs)
171 {
172 self.assignments.push((expr.clone(), expr.rhs()?));
173 return Some(());
174 }
175 None
176 }
177}
178
179fn is_equivalent(
180 sema: &hir::Semantics<'_, ide_db::RootDatabase>,
181 expr0: &ast::Expr,
182 expr1: &ast::Expr,
183) -> bool {
184 match (expr0, expr1) {
185 (ast::Expr::FieldExpr(field_expr0), ast::Expr::FieldExpr(field_expr1)) => {
186 cov_mark::hit!(test_pull_assignment_up_field_assignment);
187 sema.resolve_field(field_expr0) == sema.resolve_field(field_expr1)
188 }
189 (ast::Expr::PathExpr(path0), ast::Expr::PathExpr(path1)) => {
190 let path0 = path0.path();
191 let path1 = path1.path();
192 if let (Some(path0), Some(path1)) = (path0, path1) {
193 sema.resolve_path(&path0) == sema.resolve_path(&path1)
194 } else {
195 false
196 }
197 }
198 (ast::Expr::PrefixExpr(prefix0), ast::Expr::PrefixExpr(prefix1))
199 if prefix0.op_kind() == Some(ast::UnaryOp::Deref)
200 && prefix1.op_kind() == Some(ast::UnaryOp::Deref) =>
201 {
202 cov_mark::hit!(test_pull_assignment_up_deref);
203 if let (Some(prefix0), Some(prefix1)) = (prefix0.expr(), prefix1.expr()) {
204 is_equivalent(sema, &prefix0, &prefix1)
205 } else {
206 false
207 }
208 }
209 _ => false,
210 }
211}
212
#[cfg(test)]
mod tests {
    // Tests for the `pull_assignment_up` assist. In every fixture `$0` marks the cursor
    // position. `check_assist` takes a before/after fixture pair; `check_assist_not_applicable`
    // takes a single fixture for which the assist must not be offered.
    use super::*;

    use crate::tests::{check_assist, check_assist_not_applicable};

    // Basic case: both branches of an `if`/`else` end in `a = ..;`. The statements become
    // tail expressions and the assignment moves in front of the `if`.
    #[test]
    fn test_pull_assignment_up_if() {
        check_assist(
            pull_assignment_up,
            r#"
fn foo() {
    let mut a = 1;

    if true {
        $0a = 2;
    } else {
        a = 3;
    }
}"#,
            r#"
fn foo() {
    let mut a = 1;

    a = if true {
        2
    } else {
        3
    };
}"#,
        );
    }

    // Cursor inside an `else if` branch: the assist climbs to the outermost `if` of the chain
    // and rewrites the whole chain.
    #[test]
    fn test_pull_assignment_up_inner_if() {
        check_assist(
            pull_assignment_up,
            r#"
fn foo() {
    let mut a = 1;

    if true {
        a = 2;
    } else if true {
        $0a = 3;
    } else {
        a = 4;
    }
}"#,
            r#"
fn foo() {
    let mut a = 1;

    a = if true {
        2
    } else if true {
        3
    } else {
        4
    };
}"#,
        );
    }

    // `match` whose arms are all blocks ending in `a = ..;`.
    #[test]
    fn test_pull_assignment_up_match() {
        check_assist(
            pull_assignment_up,
            r#"
fn foo() {
    let mut a = 1;

    match 1 {
        1 => {
            $0a = 2;
        },
        2 => {
            a = 3;
        },
        3 => {
            a = 4;
        }
    }
}"#,
            r#"
fn foo() {
    let mut a = 1;

    a = match 1 {
        1 => {
            2
        },
        2 => {
            3
        },
        3 => {
            4
        }
    };
}"#,
        );
    }

    // The innermost `if`/`match` around the cursor is the target: here the `match` (the tail
    // of the `if` block) is rewritten, while the enclosing `if` is left alone.
    #[test]
    fn test_pull_assignment_up_match_in_if_expr() {
        check_assist(
            pull_assignment_up,
            r#"
fn foo() {
    let x;
    if true {
        match true {
            true => $0x = 2,
            false => x = 3,
        }
    }
}"#,
            r#"
fn foo() {
    let x;
    if true {
        x = match true {
            true => 2,
            false => 3,
        };
    }
}"#,
        );
    }

    // Arms may mix the three accepted shapes: a block ending in a statement, a bare
    // assignment expression, and a block whose tail expression is the assignment.
    #[test]
    fn test_pull_assignment_up_assignment_expressions() {
        check_assist(
            pull_assignment_up,
            r#"
fn foo() {
    let mut a = 1;

    match 1 {
        1 => { $0a = 2; },
        2 => a = 3,
        3 => {
            a = 4
        }
    }
}"#,
            r#"
fn foo() {
    let mut a = 1;

    a = match 1 {
        1 => { 2 },
        2 => 3,
        3 => {
            4
        }
    };
}"#,
        );
    }

    // The then-branch ends in `b = a`, whose left side is not `a`, so the branches do not
    // share a common assignee and the assist is not offered.
    #[test]
    fn test_pull_assignment_up_not_last_not_applicable() {
        check_assist_not_applicable(
            pull_assignment_up,
            r#"
fn foo() {
    let mut a = 1;

    if true {
        $0a = 2;
        b = a;
    } else {
        a = 3;
    }
}"#,
        )
    }

    // Exercises the `ElseBranch::IfExpr` path of `collect_if` (checked via `cov_mark`).
    #[test]
    fn test_pull_assignment_up_chained_if() {
        cov_mark::check!(test_pull_assignment_up_chained_if);
        check_assist(
            pull_assignment_up,
            r#"
fn foo() {
    let mut a = 1;

    if true {
        $0a = 2;
    } else if false {
        a = 3;
    } else {
        a = 4;
    }
}"#,
            r#"
fn foo() {
    let mut a = 1;

    a = if true {
        2
    } else if false {
        3
    } else {
        4
    };
}"#,
        );
    }

    // Statements preceding the trailing assignment are kept in each branch.
    #[test]
    fn test_pull_assignment_up_retains_stmts() {
        check_assist(
            pull_assignment_up,
            r#"
fn foo() {
    let mut a = 1;

    if true {
        let b = 2;
        $0a = 2;
    } else {
        let b = 3;
        a = 3;
    }
}"#,
            r#"
fn foo() {
    let mut a = 1;

    a = if true {
        let b = 2;
        2
    } else {
        let b = 3;
        3
    };
}"#,
        )
    }

    // The `if` is the initializer of a `let` (its parent is a `LET_STMT`), so the assist
    // refuses to rewrite it.
    #[test]
    fn pull_assignment_up_let_stmt_not_applicable() {
        check_assist_not_applicable(
            pull_assignment_up,
            r#"
fn foo() {
    let mut a = 1;

    let b = if true {
        $0a = 2
    } else {
        a = 3
    };
}"#,
        )
    }

    // The `else` block is empty, so it has no trailing assignment to collect.
    #[test]
    fn pull_assignment_up_if_missing_assignment_not_applicable() {
        check_assist_not_applicable(
            pull_assignment_up,
            r#"
fn foo() {
    let mut a = 1;

    if true {
        $0a = 2;
    } else {}
}"#,
        )
    }

    // One `match` arm (`3 => {}`) has no trailing assignment, so the match is rejected.
    #[test]
    fn pull_assignment_up_match_missing_assignment_not_applicable() {
        check_assist_not_applicable(
            pull_assignment_up,
            r#"
fn foo() {
    let mut a = 1;

    match 1 {
        1 => {
            $0a = 2;
        },
        2 => {
            a = 3;
        },
        3 => {},
    }
}"#,
        )
    }

    // Field-access assignees (`a.0`) are compared through `is_equivalent`'s `FieldExpr` case
    // (checked via `cov_mark`).
    #[test]
    fn test_pull_assignment_up_field_assignment() {
        cov_mark::check!(test_pull_assignment_up_field_assignment);
        check_assist(
            pull_assignment_up,
            r#"
struct A(usize);

fn foo() {
    let mut a = A(1);

    if true {
        $0a.0 = 2;
    } else {
        a.0 = 3;
    }
}"#,
            r#"
struct A(usize);

fn foo() {
    let mut a = A(1);

    a.0 = if true {
        2
    } else {
        3
    };
}"#,
        )
    }

    // Dereference assignees (`*b`) are compared through `is_equivalent`'s `PrefixExpr` case
    // (checked via `cov_mark`).
    #[test]
    fn test_pull_assignment_up_deref() {
        cov_mark::check!(test_pull_assignment_up_deref);
        check_assist(
            pull_assignment_up,
            r#"
fn foo() {
    let mut a = 1;
    let b = &mut a;

    if true {
        $0*b = 2;
    } else {
        *b = 3;
    }
}
"#,
            r#"
fn foo() {
    let mut a = 1;
    let b = &mut a;

    *b = if true {
        2
    } else {
        3
    };
}
"#,
        )
    }

    // The binary expression under the cursor is `*b + 2`, not an assignment, so the assist
    // bails out early (checked via `cov_mark`).
    #[test]
    fn test_cant_pull_non_assignments() {
        cov_mark::check!(test_cant_pull_non_assignments);
        check_assist_not_applicable(
            pull_assignment_up,
            r#"
fn foo() {
    let mut a = 1;
    let b = &mut a;

    if true {
        $0*b + 2;
    } else {
        *b + 3;
    }
}
"#,
        )
    }
}