1use ide_db::assists::AssistId;
2use itertools::Itertools;
3use syntax::{
4 AstNode, SyntaxElement,
5 SyntaxKind::WHITESPACE,
6 T,
7 algo::previous_non_trivia_token,
8 ast::{
9 self, HasArgList, HasLoopBody, HasName, RangeItem, edit::AstNodeEdit,
10 syntax_factory::SyntaxFactory,
11 },
12 syntax_editor::{Element, Position, SyntaxEditor},
13};
14
15use crate::assist_context::{AssistContext, Assists};
16
17pub(crate) fn convert_range_for_to_while(
39 acc: &mut Assists,
40 ctx: &AssistContext<'_, '_>,
41) -> Option<()> {
42 let (editor, _) = SyntaxEditor::new(ctx.source_file().syntax().clone());
43 let make = editor.make();
44 let for_kw = ctx.find_token_syntax_at_offset(T![for])?;
45 let for_ = ast::ForExpr::cast(for_kw.parent()?)?;
46 let ast::Pat::IdentPat(pat) = for_.pat()? else { return None };
47 let iterable = for_.iterable()?;
48 let (start, end, step, inclusive) = extract_range(&iterable, make)?;
49 let name = pat.name()?;
50 let body = for_.loop_body()?.stmt_list()?;
51 let label = for_.label();
52
53 let description = if end.is_some() {
54 "Replace with while expression"
55 } else {
56 "Replace with loop expression"
57 };
58 acc.add(
59 AssistId::refactor("convert_range_for_to_while"),
60 description,
61 for_.syntax().text_range(),
62 |builder| {
63 let make = editor.make();
64 let indent = for_.indent_level();
65 let pat = make.ident_pat(pat.ref_token().is_some(), true, name.clone());
66 let let_stmt = make.let_stmt(pat.into(), None, Some(start));
67 editor.insert_all(
68 Position::before(for_.syntax()),
69 vec![
70 let_stmt.syntax().syntax_element(),
71 make.whitespace(&format!("\n{}", indent)).syntax_element(),
72 ],
73 );
74
75 let mut elements = vec![];
76
77 let var_expr = make.expr_path(make.ident_path(&name.text()));
78 let op = ast::BinaryOp::CmpOp(ast::CmpOp::Ord {
79 ordering: ast::Ordering::Less,
80 strict: !inclusive,
81 });
82 if let Some(end) = end {
83 elements.extend([
84 make.token(T![while]).syntax_element(),
85 make.whitespace(" ").syntax_element(),
86 make.expr_bin(var_expr.clone(), op, end).syntax().syntax_element(),
87 ]);
88 } else {
89 elements.push(make.token(T![loop]).syntax_element());
90 }
91
92 editor.replace_all(
93 for_kw.syntax_element()..=iterable.syntax().syntax_element(),
94 elements,
95 );
96
97 let op = ast::BinaryOp::Assignment { op: Some(ast::ArithOp::Add) };
98 let incrementer = vec![
99 make.whitespace(&format!("\n{}", indent + 1)).syntax_element(),
100 make.expr_bin(var_expr, op, step).syntax().syntax_element(),
101 make.token(T![;]).syntax_element(),
102 ];
103 process_loop_body(body, label, &editor, incrementer);
104 builder.add_file_edits(ctx.vfs_file_id(), editor);
105 },
106 )
107}
108
109fn extract_range(
110 iterable: &ast::Expr,
111 make: &SyntaxFactory,
112) -> Option<(ast::Expr, Option<ast::Expr>, ast::Expr, bool)> {
113 Some(match iterable {
114 ast::Expr::ParenExpr(expr) => extract_range(&expr.expr()?, make)?,
115 ast::Expr::RangeExpr(range) => {
116 let inclusive = range.op_kind()? == ast::RangeOp::Inclusive;
117 (range.start()?, range.end(), make.expr_literal("1").into(), inclusive)
118 }
119 ast::Expr::MethodCallExpr(call) if call.name_ref()?.text() == "step_by" => {
120 let [step] = Itertools::collect_array(call.arg_list()?.args())?;
121 let (start, end, _, inclusive) = extract_range(&call.receiver()?, make)?;
122 (start, end, step, inclusive)
123 }
124 _ => return None,
125 })
126}
127
128fn process_loop_body(
129 body: ast::StmtList,
130 label: Option<ast::Label>,
131 editor: &SyntaxEditor,
132 incrementer: Vec<SyntaxElement>,
133) -> Option<()> {
134 let make = editor.make();
135 let last = previous_non_trivia_token(body.r_curly_token()?)?.syntax_element();
136
137 let new_body = body.indent(1.into());
138 let mut continues = vec![];
139 collect_continue_to(
140 &mut continues,
141 &label.and_then(|it| it.lifetime()),
142 new_body.syntax(),
143 false,
144 );
145
146 if continues.is_empty() {
147 editor.insert_all(Position::after(last), incrementer);
148 return Some(());
149 }
150
151 let mut children = body
152 .syntax()
153 .children_with_tokens()
154 .filter(|it| !matches!(it.kind(), WHITESPACE | T!['{'] | T!['}']));
155 let first = children.next()?;
156 let block_content = first.clone()..=children.last().unwrap_or(first);
157
158 let continue_label = make.lifetime("'cont");
159 let break_expr = make.expr_break(Some(continue_label.clone()), None);
160 let (new_edit, _) = SyntaxEditor::new(new_body.syntax().clone());
161 for continue_expr in &continues {
162 new_edit.replace(continue_expr.syntax(), break_expr.syntax());
163 }
164 let new_body = new_edit.finish().new_root().clone();
165 let elements = itertools::chain(
166 [
167 continue_label.syntax().syntax_element(),
168 make.token(T![:]).syntax_element(),
169 make.whitespace(" ").syntax_element(),
170 new_body.syntax_element(),
171 ],
172 incrementer,
173 );
174 editor.replace_all(block_content, elements.collect());
175
176 Some(())
177}
178
179fn collect_continue_to(
180 acc: &mut Vec<ast::ContinueExpr>,
181 label: &Option<ast::Lifetime>,
182 node: &syntax::SyntaxNode,
183 only_label: bool,
184) {
185 let match_label = |it: &Option<ast::Lifetime>, label: &Option<ast::Lifetime>| match (it, label)
186 {
187 (None, _) => !only_label,
188 (Some(a), Some(b)) if a.text() == b.text() => true,
189 _ => false,
190 };
191 if let Some(expr) = ast::ContinueExpr::cast(node.clone())
192 && match_label(&expr.lifetime(), label)
193 {
194 acc.push(expr);
195 } else if let Some(any_loop) = ast::AnyHasLoopBody::cast(node.clone()) {
196 if match_label(label, &any_loop.label().and_then(|it| it.lifetime())) {
197 return;
198 }
199 for children in node.children() {
200 collect_continue_to(acc, label, &children, true);
201 }
202 } else {
203 for children in node.children() {
204 collect_continue_to(acc, label, &children, only_label);
205 }
206 }
207}
208
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    // Exclusive range: the header becomes `while i < end` and `i += 1;` ends the body.
    #[test]
    fn test_convert_range_for_to_while() {
        check_assist(
            convert_range_for_to_while,
            "
fn foo() {
    $0for i in 3..7 {
        foo(i);
    }
}
            ",
            "
fn foo() {
    let mut i = 3;
    while i < 7 {
        foo(i);
        i += 1;
    }
}
            ",
        );
    }

    // Open-ended range: there is no condition, so the loop becomes a plain `loop`.
    #[test]
    fn test_convert_range_for_to_while_no_end_bound() {
        check_assist(
            convert_range_for_to_while,
            "
fn foo() {
    $0for i in 3.. {
        foo(i);
    }
}
            ",
            "
fn foo() {
    let mut i = 3;
    loop {
        foo(i);
        i += 1;
    }
}
            ",
        );
    }

    // An existing `mut` binding still yields a single `let mut`.
    #[test]
    fn test_convert_range_for_to_while_with_mut_binding() {
        check_assist(
            convert_range_for_to_while,
            "
fn foo() {
    $0for mut i in 3..7 {
        foo(i);
    }
}
            ",
            "
fn foo() {
    let mut i = 3;
    while i < 7 {
        foo(i);
        i += 1;
    }
}
            ",
        );
    }

    // The loop label stays on the loop; the `let` is inserted before the label.
    #[test]
    fn test_convert_range_for_to_while_with_label() {
        check_assist(
            convert_range_for_to_while,
            "
fn foo() {
    'a: $0for mut i in 3..7 {
        foo(i);
    }
}
            ",
            "
fn foo() {
    let mut i = 3;
    'a: while i < 7 {
        foo(i);
        i += 1;
    }
}
            ",
        );
    }

    // `continue`s targeting the converted loop become `break 'cont` out of a labeled
    // block, so the increment after that block still runs. `continue`s belonging to
    // nested loops (unlabeled, or shadowed by a nested loop with the same label) are kept.
    #[test]
    fn test_convert_range_for_to_while_with_continue() {
        check_assist(
            convert_range_for_to_while,
            "
fn foo() {
    $0for mut i in 3..7 {
        foo(i);
        continue;
        loop { break; continue }
        bar(i);
    }
}
            ",
            "
fn foo() {
    let mut i = 3;
    while i < 7 {
        'cont: {
            foo(i);
            break 'cont;
            loop { break; continue }
            bar(i);
        }
        i += 1;
    }
}
            ",
        );

        check_assist(
            convert_range_for_to_while,
            "
fn foo() {
    'x: $0for mut i in 3..7 {
        foo(i);
        continue 'x;
        loop { break; continue 'x }
        'x: loop { continue 'x }
        bar(i);
    }
}
            ",
            "
fn foo() {
    let mut i = 3;
    'x: while i < 7 {
        'cont: {
            foo(i);
            break 'cont;
            loop { break; break 'cont }
            'x: loop { continue 'x }
            bar(i);
        }
        i += 1;
    }
}
            ",
        );
    }

    // `.step_by(n)` on a range turns into `i += n;`.
    #[test]
    fn test_convert_range_for_to_while_step_by() {
        check_assist(
            convert_range_for_to_while,
            "
fn foo() {
    $0for mut i in (3..7).step_by(2) {
        foo(i);
    }
}
            ",
            "
fn foo() {
    let mut i = 3;
    while i < 7 {
        foo(i);
        i += 2;
    }
}
            ",
        );
    }

    // Only syntactic ranges are recognized; a variable holding a range is not converted.
    #[test]
    fn test_convert_range_for_to_while_not_applicable_non_range() {
        check_assist_not_applicable(
            convert_range_for_to_while,
            "
fn foo() {
    let ident = 3..7;
    $0for mut i in ident {
        foo(i);
    }
}
            ",
        );
    }
}