Skip to main content

ide_assists/handlers/
convert_iter_for_each_to_for.rs

1use hir::{Name, sym};
2use ide_db::famous_defs::FamousDefs;
3use syntax::{
4    AstNode,
5    ast::{self, HasArgList, HasLoopBody, edit::AstNodeEdit},
6};
7
8use crate::{AssistContext, AssistId, Assists, utils::wrap_paren};
9
10// Assist: convert_iter_for_each_to_for
11//
// Converts an `Iterator::for_each` call into a `for` loop.
13//
14// ```
15// # //- minicore: iterators
16// # use core::iter;
17// fn main() {
18//     let iter = iter::repeat((9, 2));
19//     iter.for_each$0(|(x, y)| {
20//         println!("x: {}, y: {}", x, y);
21//     });
22// }
23// ```
24// ->
25// ```
26// # use core::iter;
27// fn main() {
28//     let iter = iter::repeat((9, 2));
29//     for (x, y) in iter {
30//         println!("x: {}, y: {}", x, y);
31//     }
32// }
33// ```
34pub(crate) fn convert_iter_for_each_to_for(
35    acc: &mut Assists,
36    ctx: &AssistContext<'_, '_>,
37) -> Option<()> {
38    let method = ctx.find_node_at_offset::<ast::MethodCallExpr>()?;
39
40    let closure = match method.arg_list()?.args().next()? {
41        ast::Expr::ClosureExpr(expr) => expr,
42        _ => return None,
43    };
44
45    let (method, receiver) = validate_method_call_expr(ctx, method)?;
46
47    let param_list = closure.param_list()?;
48    let param = param_list.params().next()?.pat()?;
49    let body = closure.body()?;
50
51    let stmt = method.syntax().parent().and_then(ast::ExprStmt::cast);
52    let range = stmt.as_ref().map_or(method.syntax(), AstNode::syntax).text_range();
53
54    acc.add(
55        AssistId::refactor_rewrite("convert_iter_for_each_to_for"),
56        "Replace this `Iterator::for_each` with a for loop",
57        range,
58        |builder| {
59            let target_node = stmt.as_ref().map_or(method.syntax(), AstNode::syntax);
60            let editor = builder.make_editor(target_node);
61            let make = editor.make();
62            let indent =
63                stmt.as_ref().map_or_else(|| method.indent_level(), ast::ExprStmt::indent_level);
64
65            let block = match body {
66                ast::Expr::BlockExpr(block) => block.reset_indent(),
67                _ => make.block_expr(Vec::new(), Some(body.reset_indent().indent(1.into()))),
68            }
69            .indent(indent);
70
71            let expr_for_loop = make.expr_for_loop(param, receiver, block);
72            editor.replace(target_node, expr_for_loop.syntax());
73            builder.add_file_edits(ctx.vfs_file_id(), editor);
74        },
75    )
76}
77
78// Assist: convert_for_loop_with_for_each
79//
// Converts a `for` loop into a call to `Iterator::for_each`.
81//
82// ```
83// fn main() {
84//     let x = vec![1, 2, 3];
85//     for$0 v in x {
86//         let y = v * 2;
87//     }
88// }
89// ```
90// ->
91// ```
92// fn main() {
93//     let x = vec![1, 2, 3];
94//     x.into_iter().for_each(|v| {
95//         let y = v * 2;
96//     });
97// }
98// ```
99pub(crate) fn convert_for_loop_with_for_each(
100    acc: &mut Assists,
101    ctx: &AssistContext<'_, '_>,
102) -> Option<()> {
103    let for_loop = ctx.find_node_at_offset::<ast::ForExpr>()?;
104    let iterable = for_loop.iterable()?;
105    let pat = for_loop.pat()?;
106    let body = for_loop.loop_body()?;
107    if body.syntax().text_range().start() < ctx.offset() {
108        cov_mark::hit!(not_available_in_body);
109        return None;
110    }
111
112    acc.add(
113        AssistId::refactor_rewrite("convert_for_loop_with_for_each"),
114        "Replace this for loop with `Iterator::for_each`",
115        for_loop.syntax().text_range(),
116        |builder| {
117            let editor = builder.make_editor(for_loop.syntax());
118            let make = editor.make();
119
120            let mut receiver = iterable.clone();
121
122            let iter_method = if let Some((expr_behind_ref, method, krate)) =
123                is_ref_and_impls_iter_method(&ctx.sema, &iterable)
124            {
125                receiver = expr_behind_ref;
126                // We have either "for x in &col" and col implements a method called iter
127                //             or "for x in &mut col" and col implements a method called iter_mut
128                method.display(ctx.db(), krate.edition(ctx.db())).to_string()
129            } else {
130                "into_iter".to_owned()
131            };
132
133            receiver = wrap_paren(receiver, make, ast::prec::ExprPrecedence::Postfix);
134
135            if !impls_core_iter(&ctx.sema, &iterable) {
136                receiver = make
137                    .expr_method_call(receiver, make.name_ref(&iter_method), make.arg_list([]))
138                    .into();
139            }
140
141            let loop_arg = make.expr_closure([make.untyped_param(pat)], body.into());
142            let for_each = make.expr_method_call(
143                receiver,
144                make.name_ref("for_each"),
145                make.arg_list([loop_arg.into()]),
146            );
147            let for_each = make.expr_stmt(for_each.into());
148
149            editor.replace(for_loop.syntax(), for_each.syntax());
150            builder.add_file_edits(ctx.vfs_file_id(), editor);
151        },
152    )
153}
154
155/// If iterable is a reference where the expression behind the reference implements a method
156/// returning an Iterator called iter or iter_mut (depending on the type of reference) then return
157/// the expression behind the reference and the method name
158fn is_ref_and_impls_iter_method(
159    sema: &hir::Semantics<'_, ide_db::RootDatabase>,
160    iterable: &ast::Expr,
161) -> Option<(ast::Expr, hir::Name, hir::Crate)> {
162    let ref_expr = match iterable {
163        ast::Expr::RefExpr(r) => r,
164        _ => return None,
165    };
166    let wanted_method = Name::new_symbol_root(if ref_expr.mut_token().is_some() {
167        sym::iter_mut
168    } else {
169        sym::iter
170    });
171    let expr_behind_ref = ref_expr.expr()?;
172    let ty = sema.type_of_expr(&expr_behind_ref)?.adjusted();
173    let scope = sema.scope(iterable.syntax())?;
174    let krate = scope.krate();
175    let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?;
176
177    let has_wanted_method = ty
178        .iterate_method_candidates(sema.db, &scope, Some(&wanted_method), |func| {
179            if func.ret_type(sema.db).impls_trait(sema.db, iter_trait, &[]) {
180                return Some(());
181            }
182            None
183        })
184        .is_some();
185    if !has_wanted_method {
186        return None;
187    }
188
189    Some((expr_behind_ref, wanted_method, krate))
190}
191
192/// Whether iterable implements core::Iterator
193fn impls_core_iter(sema: &hir::Semantics<'_, ide_db::RootDatabase>, iterable: &ast::Expr) -> bool {
194    (|| {
195        let it_typ = sema.type_of_expr(iterable)?.adjusted();
196
197        let module = sema.scope(iterable.syntax())?.module();
198
199        let krate = module.krate(sema.db);
200        let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?;
201        cov_mark::hit!(test_already_impls_iterator);
202        Some(it_typ.impls_trait(sema.db, iter_trait, &[]))
203    })()
204    .unwrap_or(false)
205}
206
207fn validate_method_call_expr(
208    ctx: &AssistContext<'_, '_>,
209    expr: ast::MethodCallExpr,
210) -> Option<(ast::Expr, ast::Expr)> {
211    let name_ref = expr.name_ref()?;
212    if !name_ref.syntax().text_range().contains_range(ctx.selection_trimmed()) {
213        cov_mark::hit!(test_for_each_not_applicable_invalid_cursor_pos);
214        return None;
215    }
216    if name_ref.text() != "for_each" {
217        return None;
218    }
219
220    let sema = &ctx.sema;
221
222    let receiver = expr.receiver()?;
223    let expr = ast::Expr::MethodCallExpr(expr);
224
225    let it_type = sema.type_of_expr(&receiver)?.adjusted();
226    let module = sema.scope(receiver.syntax())?.module();
227    let krate = module.krate(ctx.db());
228
229    let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?;
230    it_type.impls_trait(sema.db, iter_trait, &[]).then_some((expr, receiver))
231}
232
#[cfg(test)]
mod tests {
    // Each test feeds a fixture (`$0` marks the cursor) to one of the two assists and checks
    // either the rewritten source or that the assist is not offered.
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    // `for_each` used as a statement: the whole statement, trailing `;` included, is replaced
    // by the `for` loop.
    #[test]
    fn test_for_each_in_method_stmt() {
        check_assist(
            convert_iter_for_each_to_for,
            r#"
//- minicore: iterators
fn main() {
    let it = core::iter::repeat(92);
    it.$0for_each(|(x, y)| {
        println!("x: {}, y: {}", x, y);
    });
}
"#,
            r#"
fn main() {
    let it = core::iter::repeat(92);
    for (x, y) in it {
        println!("x: {}, y: {}", x, y);
    }
}
"#,
        )
    }

    // `for_each` as the tail expression of a block (no trailing `;`).
    #[test]
    fn test_for_each_in_method() {
        check_assist(
            convert_iter_for_each_to_for,
            r#"
//- minicore: iterators
fn main() {
    let it = core::iter::repeat(92);
    it.$0for_each(|(x, y)| {
        println!("x: {}, y: {}", x, y);
    })
}
"#,
            r#"
fn main() {
    let it = core::iter::repeat(92);
    for (x, y) in it {
        println!("x: {}, y: {}", x, y);
    }
}
"#,
        )
    }

    // A closure body that is not a block gets wrapped in braces and re-indented to the
    // nesting level of the call.
    #[test]
    fn test_for_each_without_braces_stmt() {
        check_assist(
            convert_iter_for_each_to_for,
            r#"
//- minicore: iterators
fn main() {
    {
        let it = core::iter::repeat(92);
        it.$0for_each(|param| match param {
            (x, y) => println!("x: {}, y: {}", x, y),
        });
    }
}
"#,
            r#"
fn main() {
    {
        let it = core::iter::repeat(92);
        for param in it {
            match param {
                (x, y) => println!("x: {}, y: {}", x, y),
            }
        }
    }
}
"#,
        )
    }

    // The receiver `()` does not implement `Iterator`, so the assist is not offered.
    #[test]
    fn test_for_each_not_applicable() {
        check_assist_not_applicable(
            convert_iter_for_each_to_for,
            r#"
//- minicore: iterators
fn main() {
    ().$0for_each(|x| println!("{}", x));
}"#,
        )
    }

    // The cursor must be on the `for_each` name itself, not inside the closure.
    #[test]
    fn test_for_each_not_applicable_invalid_cursor_pos() {
        cov_mark::check!(test_for_each_not_applicable_invalid_cursor_pos);
        check_assist_not_applicable(
            convert_iter_for_each_to_for,
            r#"
//- minicore: iterators
fn main() {
    core::iter::repeat(92).for_each(|(x, y)| $0println!("x: {}, y: {}", x, y));
}"#,
        )
    }

    // No `for` loop under the cursor, so the reverse assist is not offered.
    #[test]
    fn each_to_for_not_for() {
        check_assist_not_applicable(
            convert_for_loop_with_for_each,
            r"
let mut x = vec![1, 2, 3];
x.iter_mut().$0for_each(|v| *v *= 2);
        ",
        )
    }

    // An iterable that is not known to be an iterator gets an explicit `.into_iter()`.
    #[test]
    fn each_to_for_simple_for() {
        check_assist(
            convert_for_loop_with_for_each,
            r"
fn main() {
    let x = vec![1, 2, 3];
    for $0v in x {
        v *= 2;
    }
}",
            r"
fn main() {
    let x = vec![1, 2, 3];
    x.into_iter().for_each(|v| {
        v *= 2;
    });
}",
        )
    }

    // The range implements `Iterator` (via the impl in the fixture), so it is used directly,
    // parenthesized so that `.for_each` applies to the whole range.
    #[test]
    fn each_to_for_for_in_range() {
        check_assist(
            convert_for_loop_with_for_each,
            r#"
//- minicore: range, iterators
impl<T> core::iter::Iterator for core::ops::Range<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        None
    }
}

fn main() {
    for $0x in 0..92 {
        print!("{}", x);
    }
}"#,
            r#"
impl<T> core::iter::Iterator for core::ops::Range<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        None
    }
}

fn main() {
    (0..92).for_each(|x| {
        print!("{}", x);
    });
}"#,
        )
    }

    // Not offered when the cursor is inside the loop body.
    #[test]
    fn each_to_for_not_available_in_body() {
        cov_mark::check!(not_available_in_body);
        check_assist_not_applicable(
            convert_for_loop_with_for_each,
            r"
fn main() {
    let x = vec![1, 2, 3];
    for v in x {
        $0v *= 2;
    }
}",
        )
    }

    // `for v in &x`, where `x` has an `iter` method returning an iterator, becomes `x.iter()`.
    #[test]
    fn each_to_for_for_borrowed() {
        check_assist(
            convert_for_loop_with_for_each,
            r#"
//- minicore: iterators
use core::iter::{Repeat, repeat};

struct S;
impl S {
    fn iter(&self) -> Repeat<i32> { repeat(92) }
    fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
}

fn main() {
    let x = S;
    for $0v in &x {
        let a = v * 2;
    }
}
"#,
            r#"
use core::iter::{Repeat, repeat};

struct S;
impl S {
    fn iter(&self) -> Repeat<i32> { repeat(92) }
    fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
}

fn main() {
    let x = S;
    x.iter().for_each(|v| {
        let a = v * 2;
    });
}
"#,
        )
    }

    // Without such a method the reference is kept (parenthesized) and `into_iter` is used.
    #[test]
    fn each_to_for_for_borrowed_no_iter_method() {
        check_assist(
            convert_for_loop_with_for_each,
            r"
struct NoIterMethod;
fn main() {
    let x = NoIterMethod;
    for $0v in &x {
        let a = v * 2;
    }
}
",
            r"
struct NoIterMethod;
fn main() {
    let x = NoIterMethod;
    (&x).into_iter().for_each(|v| {
        let a = v * 2;
    });
}
",
        )
    }

    // `&mut x` maps to the `iter_mut` method.
    #[test]
    fn each_to_for_for_borrowed_mut() {
        check_assist(
            convert_for_loop_with_for_each,
            r#"
//- minicore: iterators
use core::iter::{Repeat, repeat};

struct S;
impl S {
    fn iter(&self) -> Repeat<i32> { repeat(92) }
    fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
}

fn main() {
    let x = S;
    for $0v in &mut x {
        let a = v * 2;
    }
}
"#,
            r#"
use core::iter::{Repeat, repeat};

struct S;
impl S {
    fn iter(&self) -> Repeat<i32> { repeat(92) }
    fn iter_mut(&mut self) -> Repeat<i32> { repeat(92) }
}

fn main() {
    let x = S;
    x.iter_mut().for_each(|v| {
        let a = v * 2;
    });
}
"#,
        )
    }

    // `y` is a variable rather than a `&mut` expression, so no `iter_mut` lookup happens and
    // the result uses `into_iter`.
    #[test]
    fn each_to_for_for_borrowed_mut_behind_var() {
        check_assist(
            convert_for_loop_with_for_each,
            r"
fn main() {
    let x = vec![1, 2, 3];
    let y = &mut x;
    for $0v in y {
        *v *= 2;
    }
}",
            r"
fn main() {
    let x = vec![1, 2, 3];
    let y = &mut x;
    y.into_iter().for_each(|v| {
        *v *= 2;
    });
}",
        )
    }

    // An iterable that already implements `Iterator` is used as the receiver directly.
    #[test]
    fn each_to_for_already_impls_iterator() {
        cov_mark::check!(test_already_impls_iterator);
        check_assist(
            convert_for_loop_with_for_each,
            r#"
//- minicore: iterators
fn main() {
    for$0 a in core::iter::repeat(92).take(1) {
        println!("{}", a);
    }
}
"#,
            r#"
fn main() {
    core::iter::repeat(92).take(1).for_each(|a| {
        println!("{}", a);
    });
}
"#,
        );
    }
}