// ide_assists/handlers/convert_iter_for_each_to_for.rs

1use hir::{Name, sym};
2use ide_db::famous_defs::FamousDefs;
3use stdx::format_to;
4use syntax::{
5    AstNode,
6    ast::{self, HasArgList, HasLoopBody, edit_in_place::Indent, syntax_factory::SyntaxFactory},
7};
8
9use crate::{AssistContext, AssistId, Assists};
10
// Assist: convert_iter_for_each_to_for
//
// Converts an `Iterator::for_each` call into a `for` loop.
//
// ```
// # //- minicore: iterators
// # use core::iter;
// fn main() {
//     let iter = iter::repeat((9, 2));
//     iter.for_each$0(|(x, y)| {
//         println!("x: {}, y: {}", x, y);
//     });
// }
// ```
// ->
// ```
// # use core::iter;
// fn main() {
//     let iter = iter::repeat((9, 2));
//     for (x, y) in iter {
//         println!("x: {}, y: {}", x, y);
//     }
// }
// ```
35pub(crate) fn convert_iter_for_each_to_for(
36    acc: &mut Assists,
37    ctx: &AssistContext<'_>,
38) -> Option<()> {
39    let method = ctx.find_node_at_offset::<ast::MethodCallExpr>()?;
40
41    let closure = match method.arg_list()?.args().next()? {
42        ast::Expr::ClosureExpr(expr) => expr,
43        _ => return None,
44    };
45
46    let (method, receiver) = validate_method_call_expr(ctx, method)?;
47
48    let param_list = closure.param_list()?;
49    let param = param_list.params().next()?.pat()?;
50    let body = closure.body()?;
51
52    let stmt = method.syntax().parent().and_then(ast::ExprStmt::cast);
53    let range = stmt.as_ref().map_or(method.syntax(), AstNode::syntax).text_range();
54
55    acc.add(
56        AssistId::refactor_rewrite("convert_iter_for_each_to_for"),
57        "Replace this `Iterator::for_each` with a for loop",
58        range,
59        |builder| {
60            let make = SyntaxFactory::with_mappings();
61            let indent =
62                stmt.as_ref().map_or_else(|| method.indent_level(), ast::ExprStmt::indent_level);
63
64            let block = match body {
65                ast::Expr::BlockExpr(block) => block.clone_for_update(),
66                _ => make.block_expr(Vec::new(), Some(body)),
67            };
68            block.reindent_to(indent);
69
70            let expr_for_loop = make.expr_for_loop(param, receiver, block);
71
72            let target_node = stmt.as_ref().map_or(method.syntax(), AstNode::syntax);
73            let mut editor = builder.make_editor(target_node);
74            editor.replace(target_node, expr_for_loop.syntax());
75            builder.add_file_edits(ctx.vfs_file_id(), editor);
76        },
77    )
78}
79
// Assist: convert_for_loop_with_for_each
//
// Converts a `for` loop into an `Iterator::for_each` call.
//
// ```
// fn main() {
//     let x = vec![1, 2, 3];
//     for$0 v in x {
//         let y = v * 2;
//     }
// }
// ```
// ->
// ```
// fn main() {
//     let x = vec![1, 2, 3];
//     x.into_iter().for_each(|v| {
//         let y = v * 2;
//     });
// }
// ```
101pub(crate) fn convert_for_loop_with_for_each(
102    acc: &mut Assists,
103    ctx: &AssistContext<'_>,
104) -> Option<()> {
105    let for_loop = ctx.find_node_at_offset::<ast::ForExpr>()?;
106    let iterable = for_loop.iterable()?;
107    let pat = for_loop.pat()?;
108    let body = for_loop.loop_body()?;
109    if body.syntax().text_range().start() < ctx.offset() {
110        cov_mark::hit!(not_available_in_body);
111        return None;
112    }
113
114    acc.add(
115        AssistId::refactor_rewrite("convert_for_loop_with_for_each"),
116        "Replace this for loop with `Iterator::for_each`",
117        for_loop.syntax().text_range(),
118        |builder| {
119            let mut buf = String::new();
120
121            if let Some((expr_behind_ref, method, krate)) =
122                is_ref_and_impls_iter_method(&ctx.sema, &iterable)
123            {
124                // We have either "for x in &col" and col implements a method called iter
125                //             or "for x in &mut col" and col implements a method called iter_mut
126                format_to!(
127                    buf,
128                    "{expr_behind_ref}.{}()",
129                    method.display(ctx.db(), krate.edition(ctx.db()))
130                );
131            } else if let ast::Expr::RangeExpr(..) = iterable {
132                // range expressions need to be parenthesized for the syntax to be correct
133                format_to!(buf, "({iterable})");
134            } else if impls_core_iter(&ctx.sema, &iterable) {
135                format_to!(buf, "{iterable}");
136            } else if let ast::Expr::RefExpr(_) = iterable {
137                format_to!(buf, "({iterable}).into_iter()");
138            } else {
139                format_to!(buf, "{iterable}.into_iter()");
140            }
141
142            format_to!(buf, ".for_each(|{pat}| {body});");
143
144            builder.replace(for_loop.syntax().text_range(), buf)
145        },
146    )
147}
148
149/// If iterable is a reference where the expression behind the reference implements a method
150/// returning an Iterator called iter or iter_mut (depending on the type of reference) then return
151/// the expression behind the reference and the method name
152fn is_ref_and_impls_iter_method(
153    sema: &hir::Semantics<'_, ide_db::RootDatabase>,
154    iterable: &ast::Expr,
155) -> Option<(ast::Expr, hir::Name, hir::Crate)> {
156    let ref_expr = match iterable {
157        ast::Expr::RefExpr(r) => r,
158        _ => return None,
159    };
160    let wanted_method = Name::new_symbol_root(if ref_expr.mut_token().is_some() {
161        sym::iter_mut
162    } else {
163        sym::iter
164    });
165    let expr_behind_ref = ref_expr.expr()?;
166    let ty = sema.type_of_expr(&expr_behind_ref)?.adjusted();
167    let scope = sema.scope(iterable.syntax())?;
168    let krate = scope.krate();
169    let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?;
170
171    let has_wanted_method = ty
172        .iterate_method_candidates(sema.db, &scope, Some(&wanted_method), |func| {
173            if func.ret_type(sema.db).impls_trait(sema.db, iter_trait, &[]) {
174                return Some(());
175            }
176            None
177        })
178        .is_some();
179    if !has_wanted_method {
180        return None;
181    }
182
183    Some((expr_behind_ref, wanted_method, krate))
184}
185
186/// Whether iterable implements core::Iterator
187fn impls_core_iter(sema: &hir::Semantics<'_, ide_db::RootDatabase>, iterable: &ast::Expr) -> bool {
188    (|| {
189        let it_typ = sema.type_of_expr(iterable)?.adjusted();
190
191        let module = sema.scope(iterable.syntax())?.module();
192
193        let krate = module.krate(sema.db);
194        let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?;
195        cov_mark::hit!(test_already_impls_iterator);
196        Some(it_typ.impls_trait(sema.db, iter_trait, &[]))
197    })()
198    .unwrap_or(false)
199}
200
201fn validate_method_call_expr(
202    ctx: &AssistContext<'_>,
203    expr: ast::MethodCallExpr,
204) -> Option<(ast::Expr, ast::Expr)> {
205    let name_ref = expr.name_ref()?;
206    if !name_ref.syntax().text_range().contains_range(ctx.selection_trimmed()) {
207        cov_mark::hit!(test_for_each_not_applicable_invalid_cursor_pos);
208        return None;
209    }
210    if name_ref.text() != "for_each" {
211        return None;
212    }
213
214    let sema = &ctx.sema;
215
216    let receiver = expr.receiver()?;
217    let expr = ast::Expr::MethodCallExpr(expr);
218
219    let it_type = sema.type_of_expr(&receiver)?.adjusted();
220    let module = sema.scope(receiver.syntax())?.module();
221    let krate = module.krate(ctx.db());
222
223    let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?;
224    it_type.impls_trait(sema.db, iter_trait, &[]).then_some((expr, receiver))
225}
226
#[cfg(test)]
mod tests {
    // Fixture-based tests for both assists; `$0` in a fixture marks the cursor position.
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    // `for_each` used as an expression statement: the trailing `;` goes away with the call.
    #[test]
    fn test_for_each_in_method_stmt() {
        check_assist(
            convert_iter_for_each_to_for,
            r#"
//- minicore: iterators
fn main() {
    let it = core::iter::repeat(92);
    it.$0for_each(|(x, y)| {
        println!("x: {}, y: {}", x, y);
    });
}
"#,
            r#"
fn main() {
    let it = core::iter::repeat(92);
    for (x, y) in it {
        println!("x: {}, y: {}", x, y);
    }
}
"#,
        )
    }

    // `for_each` as the function's tail expression (no trailing `;`).
    #[test]
    fn test_for_each_in_method() {
        check_assist(
            convert_iter_for_each_to_for,
            r#"
//- minicore: iterators
fn main() {
    let it = core::iter::repeat(92);
    it.$0for_each(|(x, y)| {
        println!("x: {}, y: {}", x, y);
    })
}
"#,
            r#"
fn main() {
    let it = core::iter::repeat(92);
    for (x, y) in it {
        println!("x: {}, y: {}", x, y);
    }
}
"#,
        )
    }

    // A closure body that is not a block is wrapped in a new block for the loop body.
    #[test]
    fn test_for_each_without_braces_stmt() {
        check_assist(
            convert_iter_for_each_to_for,
            r#"
//- minicore: iterators
fn main() {
    let it = core::iter::repeat(92);
    it.$0for_each(|(x, y)| println!("x: {}, y: {}", x, y));
}
"#,
            r#"
fn main() {
    let it = core::iter::repeat(92);
    for (x, y) in it {
        println!("x: {}, y: {}", x, y)
    }
}
"#,
        )
    }

    // The receiver `()` does not implement `Iterator`, so the assist is not offered.
    #[test]
    fn test_for_each_not_applicable() {
        check_assist_not_applicable(
            convert_iter_for_each_to_for,
            r#"
//- minicore: iterators
fn main() {
    ().$0for_each(|x| println!("{}", x));
}"#,
        )
    }

    // Cursor inside the closure body instead of on the `for_each` name.
    #[test]
    fn test_for_each_not_applicable_invalid_cursor_pos() {
        cov_mark::check!(test_for_each_not_applicable_invalid_cursor_pos);
        check_assist_not_applicable(
            convert_iter_for_each_to_for,
            r#"
//- minicore: iterators
fn main() {
    core::iter::repeat(92).for_each(|(x, y)| $0println!("x: {}, y: {}", x, y));
}"#,
        )
    }

    // Cursor on a `for_each` call: there is no `for` loop for the reverse assist to convert.
    #[test]
    fn each_to_for_not_for() {
        check_assist_not_applicable(
            convert_for_loop_with_for_each,
            r"
let mut x = vec![1, 2, 3];
x.iter_mut().$0for_each(|v| *v *= 2);
        ",
        )
    }

    // A plain iterable gets an explicit `.into_iter()` before `.for_each`.
    #[test]
    fn each_to_for_simple_for() {
        check_assist(
            convert_for_loop_with_for_each,
            r"
fn main() {
    let x = vec![1, 2, 3];
    for $0v in x {
        v *= 2;
    }
}",
            r"
fn main() {
    let x = vec![1, 2, 3];
    x.into_iter().for_each(|v| {
        v *= 2;
    });
}",
        )
    }

    // Range iterables are parenthesized so `.for_each` applies to the whole range.
    #[test]
    fn each_to_for_for_in_range() {
        check_assist(
            convert_for_loop_with_for_each,
            r#"
//- minicore: range, iterators
impl<T> core::iter::Iterator for core::ops::Range<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        None
    }
}

fn main() {
    for $0x in 0..92 {
        print!("{}", x);
    }
}"#,
            r#"
impl<T> core::iter::Iterator for core::ops::Range<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        None
    }
}

fn main() {
    (0..92).for_each(|x| {
        print!("{}", x);
    });
}"#,
        )
    }

    // Cursor inside the loop body: the assist is only offered on the loop header.
    #[test]
    fn each_to_for_not_available_in_body() {
        cov_mark::check!(not_available_in_body);
        check_assist_not_applicable(
            convert_for_loop_with_for_each,
            r"
fn main() {
    let x = vec![1, 2, 3];
    for v in x {
        $0v *= 2;
    }
}",
        )
    }

    // `&x` whose type has an `iter` method returning an iterator becomes `x.iter()`.
    #[test]
    fn each_to_for_for_borrowed() {
        check_assist(
            convert_for_loop_with_for_each,
            r#"
//- minicore: iterators
use core::iter::{Repeat, repeat};

struct S;
impl S {
    fn iter(&self) -> Repeat<i32> { repeat(92) }
    fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
}

fn main() {
    let x = S;
    for $0v in &x {
        let a = v * 2;
    }
}
"#,
            r#"
use core::iter::{Repeat, repeat};

struct S;
impl S {
    fn iter(&self) -> Repeat<i32> { repeat(92) }
    fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
}

fn main() {
    let x = S;
    x.iter().for_each(|v| {
        let a = v * 2;
    });
}
"#,
        )
    }

    // `&x` without such a method falls back to a parenthesized `(&x).into_iter()`.
    #[test]
    fn each_to_for_for_borrowed_no_iter_method() {
        check_assist(
            convert_for_loop_with_for_each,
            r"
struct NoIterMethod;
fn main() {
    let x = NoIterMethod;
    for $0v in &x {
        let a = v * 2;
    }
}
",
            r"
struct NoIterMethod;
fn main() {
    let x = NoIterMethod;
    (&x).into_iter().for_each(|v| {
        let a = v * 2;
    });
}
",
        )
    }

    // `&mut x` maps to the `iter_mut` method.
    #[test]
    fn each_to_for_for_borrowed_mut() {
        check_assist(
            convert_for_loop_with_for_each,
            r#"
//- minicore: iterators
use core::iter::{Repeat, repeat};

struct S;
impl S {
    fn iter(&self) -> Repeat<i32> { repeat(92) }
    fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
}

fn main() {
    let x = S;
    for $0v in &mut x {
        let a = v * 2;
    }
}
"#,
            r#"
use core::iter::{Repeat, repeat};

struct S;
impl S {
    fn iter(&self) -> Repeat<i32> { repeat(92) }
    fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
}

fn main() {
    let x = S;
    x.iter_mut().for_each(|v| {
        let a = v * 2;
    });
}
"#,
        )
    }

    // A reference held in a variable is a path, not a `&` expression, so only
    // `.into_iter()` is appended.
    #[test]
    fn each_to_for_for_borrowed_mut_behind_var() {
        check_assist(
            convert_for_loop_with_for_each,
            r"
fn main() {
    let x = vec![1, 2, 3];
    let y = &mut x;
    for $0v in y {
        *v *= 2;
    }
}",
            r"
fn main() {
    let x = vec![1, 2, 3];
    let y = &mut x;
    y.into_iter().for_each(|v| {
        *v *= 2;
    });
}",
        )
    }

    // An iterable that already implements `Iterator` is used as the receiver directly.
    #[test]
    fn each_to_for_already_impls_iterator() {
        cov_mark::check!(test_already_impls_iterator);
        check_assist(
            convert_for_loop_with_for_each,
            r#"
//- minicore: iterators
fn main() {
    for$0 a in core::iter::repeat(92).take(1) {
        println!("{}", a);
    }
}
"#,
            r#"
fn main() {
    core::iter::repeat(92).take(1).for_each(|a| {
        println!("{}", a);
    });
}
"#,
        );
    }
}