Skip to main content

ide_assists/handlers/
convert_range_for_to_while.rs

1use ide_db::assists::AssistId;
2use itertools::Itertools;
3use syntax::{
4    AstNode, SyntaxElement,
5    SyntaxKind::WHITESPACE,
6    T,
7    algo::previous_non_trivia_token,
8    ast::{
9        self, HasArgList, HasLoopBody, HasName, RangeItem, edit::AstNodeEdit,
10        syntax_factory::SyntaxFactory,
11    },
12    syntax_editor::{Element, Position, SyntaxEditor},
13};
14
15use crate::assist_context::{AssistContext, Assists};
16
17// Assist: convert_range_for_to_while
18//
19// Convert for each range into while loop.
20//
21// ```
22// fn foo() {
23//     $0for i in 3..7 {
24//         foo(i);
25//     }
26// }
27// ```
28// ->
29// ```
30// fn foo() {
31//     let mut i = 3;
32//     while i < 7 {
33//         foo(i);
34//         i += 1;
35//     }
36// }
37// ```
38pub(crate) fn convert_range_for_to_while(
39    acc: &mut Assists,
40    ctx: &AssistContext<'_, '_>,
41) -> Option<()> {
42    let (editor, _) = SyntaxEditor::new(ctx.source_file().syntax().clone());
43    let make = editor.make();
44    let for_kw = ctx.find_token_syntax_at_offset(T![for])?;
45    let for_ = ast::ForExpr::cast(for_kw.parent()?)?;
46    let ast::Pat::IdentPat(pat) = for_.pat()? else { return None };
47    let iterable = for_.iterable()?;
48    let (start, end, step, inclusive) = extract_range(&iterable, make)?;
49    let name = pat.name()?;
50    let body = for_.loop_body()?.stmt_list()?;
51    let label = for_.label();
52
53    let description = if end.is_some() {
54        "Replace with while expression"
55    } else {
56        "Replace with loop expression"
57    };
58    acc.add(
59        AssistId::refactor("convert_range_for_to_while"),
60        description,
61        for_.syntax().text_range(),
62        |builder| {
63            let make = editor.make();
64            let indent = for_.indent_level();
65            let pat = make.ident_pat(pat.ref_token().is_some(), true, name.clone());
66            let let_stmt = make.let_stmt(pat.into(), None, Some(start));
67            editor.insert_all(
68                Position::before(for_.syntax()),
69                vec![
70                    let_stmt.syntax().syntax_element(),
71                    make.whitespace(&format!("\n{}", indent)).syntax_element(),
72                ],
73            );
74
75            let mut elements = vec![];
76
77            let var_expr = make.expr_path(make.ident_path(&name.text()));
78            let op = ast::BinaryOp::CmpOp(ast::CmpOp::Ord {
79                ordering: ast::Ordering::Less,
80                strict: !inclusive,
81            });
82            if let Some(end) = end {
83                elements.extend([
84                    make.token(T![while]).syntax_element(),
85                    make.whitespace(" ").syntax_element(),
86                    make.expr_bin(var_expr.clone(), op, end).syntax().syntax_element(),
87                ]);
88            } else {
89                elements.push(make.token(T![loop]).syntax_element());
90            }
91
92            editor.replace_all(
93                for_kw.syntax_element()..=iterable.syntax().syntax_element(),
94                elements,
95            );
96
97            let op = ast::BinaryOp::Assignment { op: Some(ast::ArithOp::Add) };
98            let incrementer = vec![
99                make.whitespace(&format!("\n{}", indent + 1)).syntax_element(),
100                make.expr_bin(var_expr, op, step).syntax().syntax_element(),
101                make.token(T![;]).syntax_element(),
102            ];
103            process_loop_body(body, label, &editor, incrementer);
104            builder.add_file_edits(ctx.vfs_file_id(), editor);
105        },
106    )
107}
108
109fn extract_range(
110    iterable: &ast::Expr,
111    make: &SyntaxFactory,
112) -> Option<(ast::Expr, Option<ast::Expr>, ast::Expr, bool)> {
113    Some(match iterable {
114        ast::Expr::ParenExpr(expr) => extract_range(&expr.expr()?, make)?,
115        ast::Expr::RangeExpr(range) => {
116            let inclusive = range.op_kind()? == ast::RangeOp::Inclusive;
117            (range.start()?, range.end(), make.expr_literal("1").into(), inclusive)
118        }
119        ast::Expr::MethodCallExpr(call) if call.name_ref()?.text() == "step_by" => {
120            let [step] = Itertools::collect_array(call.arg_list()?.args())?;
121            let (start, end, _, inclusive) = extract_range(&call.receiver()?, make)?;
122            (start, end, step, inclusive)
123        }
124        _ => return None,
125    })
126}
127
128fn process_loop_body(
129    body: ast::StmtList,
130    label: Option<ast::Label>,
131    editor: &SyntaxEditor,
132    incrementer: Vec<SyntaxElement>,
133) -> Option<()> {
134    let make = editor.make();
135    let last = previous_non_trivia_token(body.r_curly_token()?)?.syntax_element();
136
137    let new_body = body.indent(1.into());
138    let mut continues = vec![];
139    collect_continue_to(
140        &mut continues,
141        &label.and_then(|it| it.lifetime()),
142        new_body.syntax(),
143        false,
144    );
145
146    if continues.is_empty() {
147        editor.insert_all(Position::after(last), incrementer);
148        return Some(());
149    }
150
151    let mut children = body
152        .syntax()
153        .children_with_tokens()
154        .filter(|it| !matches!(it.kind(), WHITESPACE | T!['{'] | T!['}']));
155    let first = children.next()?;
156    let block_content = first.clone()..=children.last().unwrap_or(first);
157
158    let continue_label = make.lifetime("'cont");
159    let break_expr = make.expr_break(Some(continue_label.clone()), None);
160    let (new_edit, _) = SyntaxEditor::new(new_body.syntax().clone());
161    for continue_expr in &continues {
162        new_edit.replace(continue_expr.syntax(), break_expr.syntax());
163    }
164    let new_body = new_edit.finish().new_root().clone();
165    let elements = itertools::chain(
166        [
167            continue_label.syntax().syntax_element(),
168            make.token(T![:]).syntax_element(),
169            make.whitespace(" ").syntax_element(),
170            new_body.syntax_element(),
171        ],
172        incrementer,
173    );
174    editor.replace_all(block_content, elements.collect());
175
176    Some(())
177}
178
179fn collect_continue_to(
180    acc: &mut Vec<ast::ContinueExpr>,
181    label: &Option<ast::Lifetime>,
182    node: &syntax::SyntaxNode,
183    only_label: bool,
184) {
185    let match_label = |it: &Option<ast::Lifetime>, label: &Option<ast::Lifetime>| match (it, label)
186    {
187        (None, _) => !only_label,
188        (Some(a), Some(b)) if a.text() == b.text() => true,
189        _ => false,
190    };
191    if let Some(expr) = ast::ContinueExpr::cast(node.clone())
192        && match_label(&expr.lifetime(), label)
193    {
194        acc.push(expr);
195    } else if let Some(any_loop) = ast::AnyHasLoopBody::cast(node.clone()) {
196        if match_label(label, &any_loop.label().and_then(|it| it.lifetime())) {
197            return;
198        }
199        for children in node.children() {
200            collect_continue_to(acc, label, &children, true);
201        }
202    } else {
203        for children in node.children() {
204            collect_continue_to(acc, label, &children, only_label);
205        }
206    }
207}
208
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    /// A half-open range becomes a `let mut` counter plus a `while` with a strict comparison.
    #[test]
    fn test_convert_range_for_to_while() {
        check_assist(
            convert_range_for_to_while,
            r#"
fn foo() {
    $0for i in 3..7 {
        foo(i);
    }
}
            "#,
            r#"
fn foo() {
    let mut i = 3;
    while i < 7 {
        foo(i);
        i += 1;
    }
}
            "#,
        );
    }

    /// A range without an upper bound becomes a bare `loop`.
    #[test]
    fn test_convert_range_for_to_while_no_end_bound() {
        check_assist(
            convert_range_for_to_while,
            r#"
fn foo() {
    $0for i in 3.. {
        foo(i);
    }
}
            "#,
            r#"
fn foo() {
    let mut i = 3;
    loop {
        foo(i);
        i += 1;
    }
}
            "#,
        );
    }

    /// An already-`mut` loop binding yields the same `let mut` declaration.
    #[test]
    fn test_convert_range_for_to_while_with_mut_binding() {
        check_assist(
            convert_range_for_to_while,
            r#"
fn foo() {
    $0for mut i in 3..7 {
        foo(i);
    }
}
            "#,
            r#"
fn foo() {
    let mut i = 3;
    while i < 7 {
        foo(i);
        i += 1;
    }
}
            "#,
        );
    }

    /// The loop label stays attached to the generated `while`; the `let` goes before it.
    #[test]
    fn test_convert_range_for_to_while_with_label() {
        check_assist(
            convert_range_for_to_while,
            r#"
fn foo() {
    'a: $0for mut i in 3..7 {
        foo(i);
    }
}
            "#,
            r#"
fn foo() {
    let mut i = 3;
    'a: while i < 7 {
        foo(i);
        i += 1;
    }
}
            "#,
        );
    }

    /// `continue`s that target the converted loop become `break 'cont` inside a labeled
    /// block, while those that belong to nested loops are left alone.
    #[test]
    fn test_convert_range_for_to_while_with_continue() {
        // Unlabeled outer loop: the `continue` inside the nested `loop` is untouched.
        check_assist(
            convert_range_for_to_while,
            r#"
fn foo() {
    $0for mut i in 3..7 {
        foo(i);
        continue;
        loop { break; continue }
        bar(i);
    }
}
            "#,
            r#"
fn foo() {
    let mut i = 3;
    while i < 7 {
        'cont: {
            foo(i);
            break 'cont;
            loop { break; continue }
            bar(i);
        }
        i += 1;
    }
}
            "#,
        );

        // Labeled outer loop: `continue 'x` is rewritten even inside a nested loop, except
        // where a nested loop reuses the label `'x`.
        check_assist(
            convert_range_for_to_while,
            r#"
fn foo() {
    'x: $0for mut i in 3..7 {
        foo(i);
        continue 'x;
        loop { break; continue 'x }
        'x: loop { continue 'x }
        bar(i);
    }
}
            "#,
            r#"
fn foo() {
    let mut i = 3;
    'x: while i < 7 {
        'cont: {
            foo(i);
            break 'cont;
            loop { break; break 'cont }
            'x: loop { continue 'x }
            bar(i);
        }
        i += 1;
    }
}
            "#,
        );
    }

    /// The argument of `step_by` on a parenthesised range is used as the increment.
    #[test]
    fn test_convert_range_for_to_while_step_by() {
        check_assist(
            convert_range_for_to_while,
            r#"
fn foo() {
    $0for mut i in (3..7).step_by(2) {
        foo(i);
    }
}
            "#,
            r#"
fn foo() {
    let mut i = 3;
    while i < 7 {
        foo(i);
        i += 2;
    }
}
            "#,
        );
    }

    /// An iterable that is not syntactically a range (here a variable) is rejected.
    #[test]
    fn test_convert_range_for_to_while_not_applicable_non_range() {
        check_assist_not_applicable(
            convert_range_for_to_while,
            r#"
fn foo() {
    let ident = 3..7;
    $0for mut i in ident {
        foo(i);
    }
}
            "#,
        );
    }
}