ide_assists/handlers/
replace_method_eager_lazy.rs

1use hir::Semantics;
2use ide_db::{RootDatabase, assists::AssistId, defs::Definition};
3use syntax::{
4    AstNode,
5    ast::{self, Expr, HasArgList, make},
6};
7
8use crate::{AssistContext, Assists};
9
10// Assist: replace_with_lazy_method
11//
// Replace `unwrap_or` with `unwrap_or_else`, `ok_or` with `ok_or_else`, and other eager methods
// (`and`, `then_some`, ...) with their lazy, closure-taking counterparts.
13//
14// ```
15// # //- minicore:option, fn
16// fn foo() {
17//     let a = Some(1);
18//     a.unwra$0p_or(2);
19// }
20// ```
21// ->
22// ```
23// fn foo() {
24//     let a = Some(1);
25//     a.unwrap_or_else(|| 2);
26// }
27// ```
28pub(crate) fn replace_with_lazy_method(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
29    let call: ast::MethodCallExpr = ctx.find_node_at_offset()?;
30    let scope = ctx.sema.scope(call.syntax())?;
31
32    let last_arg = call.arg_list()?.args().next()?;
33    let method_name = call.name_ref()?;
34
35    let callable = ctx.sema.resolve_method_call_as_callable(&call)?;
36    let (_, receiver_ty) = callable.receiver_param(ctx.sema.db)?;
37    let n_params = callable.n_params() + 1;
38
39    let method_name_lazy = lazy_method_name(&method_name.text());
40
41    receiver_ty.iterate_method_candidates_with_traits(
42        ctx.sema.db,
43        &scope,
44        &scope.visible_traits().0,
45        None,
46        |func| {
47            let valid = func.name(ctx.sema.db).as_str() == &*method_name_lazy
48                && func.num_params(ctx.sema.db) == n_params
49                && {
50                    let params = func.params_without_self(ctx.sema.db);
51                    let last_p = params.first()?;
52                    // FIXME: Check that this has the form of `() -> T` where T is the current type of the argument
53                    last_p.ty().impls_fnonce(ctx.sema.db)
54                };
55            valid.then_some(func)
56        },
57    )?;
58
59    acc.add(
60        AssistId::refactor_rewrite("replace_with_lazy_method"),
61        format!("Replace {method_name} with {method_name_lazy}"),
62        call.syntax().text_range(),
63        |builder| {
64            let closured = into_closure(&last_arg, &method_name_lazy);
65            builder.replace(method_name.syntax().text_range(), method_name_lazy);
66            builder.replace_ast(last_arg, closured);
67        },
68    )
69}
70
71fn lazy_method_name(name: &str) -> String {
72    if ends_is(name, "or") {
73        format!("{name}_else")
74    } else if ends_is(name, "and") {
75        format!("{name}_then")
76    } else if ends_is(name, "then_some") {
77        name.strip_suffix("_some").unwrap().to_owned()
78    } else {
79        format!("{name}_with")
80    }
81}
82
83fn into_closure(param: &Expr, name_lazy: &str) -> Expr {
84    (|| {
85        if let ast::Expr::CallExpr(call) = param {
86            if call.arg_list()?.args().count() == 0 { Some(call.expr()?) } else { None }
87        } else {
88            None
89        }
90    })()
91    .unwrap_or_else(|| {
92        let pats = (name_lazy == "and_then")
93            .then(|| make::untyped_param(make::ext::simple_ident_pat(make::name("it")).into()));
94        make::expr_closure(pats, param.clone()).into()
95    })
96}
97
98// Assist: replace_with_eager_method
99//
// Replace `unwrap_or_else` with `unwrap_or`, `ok_or_else` with `ok_or`, and other lazy methods
// (`and_then`, `then`, ...) with their eager counterparts.
101//
102// ```
103// # //- minicore:option, fn
104// fn foo() {
105//     let a = Some(1);
106//     a.unwra$0p_or_else(|| 2);
107// }
108// ```
109// ->
110// ```
111// fn foo() {
112//     let a = Some(1);
113//     a.unwrap_or(2);
114// }
115// ```
116pub(crate) fn replace_with_eager_method(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
117    let call: ast::MethodCallExpr = ctx.find_node_at_offset()?;
118    let scope = ctx.sema.scope(call.syntax())?;
119
120    let last_arg = call.arg_list()?.args().next()?;
121    let method_name = call.name_ref()?;
122
123    let callable = ctx.sema.resolve_method_call_as_callable(&call)?;
124    let (_, receiver_ty) = callable.receiver_param(ctx.sema.db)?;
125    let n_params = callable.n_params() + 1;
126    let params = callable.params();
127
128    // FIXME: Check that the arg is of the form `() -> T`
129    if !params.first()?.ty().impls_fnonce(ctx.sema.db) {
130        return None;
131    }
132
133    let method_name_text = method_name.text();
134    let method_name_eager = eager_method_name(&method_name_text)?;
135
136    receiver_ty.iterate_method_candidates_with_traits(
137        ctx.sema.db,
138        &scope,
139        &scope.visible_traits().0,
140        None,
141        |func| {
142            let valid = func.name(ctx.sema.db).as_str() == method_name_eager
143                && func.num_params(ctx.sema.db) == n_params;
144            valid.then_some(func)
145        },
146    )?;
147
148    acc.add(
149        AssistId::refactor_rewrite("replace_with_eager_method"),
150        format!("Replace {method_name} with {method_name_eager}"),
151        call.syntax().text_range(),
152        |builder| {
153            builder.replace(method_name.syntax().text_range(), method_name_eager);
154            let called = into_call(&last_arg, &ctx.sema);
155            builder.replace_ast(last_arg, called);
156        },
157    )
158}
159
160fn into_call(param: &Expr, sema: &Semantics<'_, RootDatabase>) -> Expr {
161    (|| {
162        if let ast::Expr::ClosureExpr(closure) = param {
163            let mut params = closure.param_list()?.params();
164            match params.next() {
165                Some(_) if params.next().is_none() => {
166                    let params = sema.resolve_expr_as_callable(param)?.params();
167                    let used_param = Definition::Local(params.first()?.as_local(sema.db)?)
168                        .usages(sema)
169                        .at_least_one();
170                    if used_param { None } else { Some(closure.body()?) }
171                }
172                None => Some(closure.body()?),
173                Some(_) => None,
174            }
175        } else {
176            None
177        }
178    })()
179    .unwrap_or_else(|| {
180        let callable = if needs_parens_in_call(param) {
181            make::expr_paren(param.clone()).into()
182        } else {
183            param.clone()
184        };
185        make::expr_call(callable, make::arg_list(Vec::new())).into()
186    })
187}
188
/// Maps a lazy (closure-taking) method name back to its eager counterpart.
///
/// `then` maps to `then_some`; otherwise the first matching suffix among `_else`, `_then` and
/// `_with` (checked in that order) is stripped. Returns `None` if no rule applies.
fn eager_method_name(name: &str) -> Option<&str> {
    const LAZY_SUFFIXES: [&str; 3] = ["_else", "_then", "_with"];

    match name {
        "then" => Some("then_some"),
        _ => LAZY_SUFFIXES.iter().find_map(|&suffix| name.strip_suffix(suffix)),
    }
}
198
/// Returns `true` if `name` ends with `end` as a whole `_`-separated segment, i.e. `name` is
/// exactly `end` or ends with `_` followed by `end` (`"unwrap_or"` matches `"or"`, `"xor"` does not).
fn ends_is(name: &str, end: &str) -> bool {
    match name.strip_suffix(end) {
        Some(prefix) => prefix.is_empty() || prefix.ends_with('_'),
        None => false,
    }
}
202
203fn needs_parens_in_call(param: &Expr) -> bool {
204    let call = make::expr_call(make::ext::expr_unit(), make::arg_list(Vec::new()));
205    let callable = call.expr().expect("invalid make call");
206    param.needs_parens_in_place_of(call.syntax(), callable.syntax())
207}
208
#[cfg(test)]
mod tests {
    use crate::tests::check_assist;

    use super::*;

    // Lazy direction: a plain literal argument is wrapped in a zero-argument closure.
    #[test]
    fn replace_or_with_or_else_simple() {
        check_assist(
            replace_with_lazy_method,
            r#"
//- minicore: option, fn
fn foo() {
    let foo = Some(1);
    return foo.unwrap_$0or(2);
}
"#,
            r#"
fn foo() {
    let foo = Some(1);
    return foo.unwrap_or_else(|| 2);
}
"#,
        )
    }

    // A zero-argument call `x()` is passed as the function `x` itself rather than `|| x()`.
    #[test]
    fn replace_or_with_or_else_call() {
        check_assist(
            replace_with_lazy_method,
            r#"
//- minicore: option, fn
fn foo() {
    let foo = Some(1);
    return foo.unwrap_$0or(x());
}
"#,
            r#"
fn foo() {
    let foo = Some(1);
    return foo.unwrap_or_else(x);
}
"#,
        )
    }

    // A block argument becomes the closure body unchanged.
    #[test]
    fn replace_or_with_or_else_block() {
        check_assist(
            replace_with_lazy_method,
            r#"
//- minicore: option, fn
fn foo() {
    let foo = Some(1);
    return foo.unwrap_$0or({
        let mut x = bar();
        for i in 0..10 {
            x += i;
        }
        x
    });
}
"#,
            r#"
fn foo() {
    let foo = Some(1);
    return foo.unwrap_or_else(|| {
        let mut x = bar();
        for i in 0..10 {
            x += i;
        }
        x
    });
}
"#,
        )
    }

    // Eager direction: the body of a zero-parameter closure is inlined.
    #[test]
    fn replace_or_else_with_or_simple() {
        check_assist(
            replace_with_eager_method,
            r#"
//- minicore: option, fn
fn foo() {
    let foo = Some(1);
    return foo.unwrap_$0or_else(|| 2);
}
"#,
            r#"
fn foo() {
    let foo = Some(1);
    return foo.unwrap_or(2);
}
"#,
        )
    }

    // A non-closure callable argument is invoked in place: `x` becomes `x()`.
    #[test]
    fn replace_or_else_with_or_call() {
        check_assist(
            replace_with_eager_method,
            r#"
//- minicore: option, fn
fn foo() {
    let foo = Some(1);
    return foo.unwrap_$0or_else(x);
}

fn x() -> i32 { 0 }
"#,
            r#"
fn foo() {
    let foo = Some(1);
    return foo.unwrap_or(x());
}

fn x() -> i32 { 0 }
"#,
        )
    }

    // Only the first argument is rewritten; the mapping closure of `map_or_else` is kept as is.
    #[test]
    fn replace_or_else_with_or_map() {
        check_assist(
            replace_with_eager_method,
            r#"
//- minicore: option, fn
fn foo() {
    let foo = Some("foo");
    return foo.map$0_or_else(|| 42, |v| v.len());
}
"#,
            r#"
fn foo() {
    let foo = Some("foo");
    return foo.map_or(42, |v| v.len());
}
"#,
        )
    }

    // `and_then` passes a value to its closure, so the generated closure takes an `it` parameter.
    #[test]
    fn replace_and_with_and_then() {
        check_assist(
            replace_with_lazy_method,
            r#"
//- minicore: option, fn
fn foo() {
    let foo = Some("foo");
    return foo.and$0(Some("bar"));
}
"#,
            r#"
fn foo() {
    let foo = Some("foo");
    return foo.and_then(|it| Some("bar"));
}
"#,
        )
    }

    // An unused closure parameter is dropped and the closure body is inlined.
    #[test]
    fn replace_and_then_with_and() {
        check_assist(
            replace_with_eager_method,
            r#"
//- minicore: option, fn
fn foo() {
    let foo = Some("foo");
    return foo.and_then$0(|it| Some("bar"));
}
"#,
            r#"
fn foo() {
    let foo = Some("foo");
    return foo.and(Some("bar"));
}
"#,
        )
    }

    // When the closure parameter is used, the closure is called in place instead of inlined.
    // NOTE(review): this pins the current output, which calls a one-parameter closure with no
    // arguments and therefore does not type-check.
    #[test]
    fn replace_and_then_with_and_used_param() {
        check_assist(
            replace_with_eager_method,
            r#"
//- minicore: option, fn
fn foo() {
    let foo = Some("foo");
    return foo.and_then$0(|it| Some(it.strip_suffix("bar")));
}
"#,
            r#"
fn foo() {
    let foo = Some("foo");
    return foo.and((|it| Some(it.strip_suffix("bar")))());
}
"#,
        )
    }

    // `bool::then_some` maps to `bool::then`.
    #[test]
    fn replace_then_some_with_then() {
        check_assist(
            replace_with_lazy_method,
            r#"
//- minicore: option, fn, bool_impl
fn foo() {
    let foo = true;
    let x = foo.then_some$0(2);
}
"#,
            r#"
fn foo() {
    let foo = true;
    let x = foo.then(|| 2);
}
"#,
        )
    }

    // `bool::then` maps back to `bool::then_some`.
    #[test]
    fn replace_then_with_then_some() {
        check_assist(
            replace_with_eager_method,
            r#"
//- minicore: option, fn, bool_impl
fn foo() {
    let foo = true;
    let x = foo.then$0(|| 2);
}
"#,
            r#"
fn foo() {
    let foo = true;
    let x = foo.then_some(2);
}
"#,
        )
    }

    // A field-access callee must be parenthesized, `(func.f)()`, because `func.f()` would parse
    // as a method call.
    #[test]
    fn replace_then_with_then_some_needs_parens() {
        check_assist(
            replace_with_eager_method,
            r#"
//- minicore: option, fn, bool_impl
struct Func { f: fn() -> i32 }
fn foo() {
    let foo = true;
    let func = Func { f: || 2 };
    let x = foo.then$0(func.f);
}
"#,
            r#"
struct Func { f: fn() -> i32 }
fn foo() {
    let foo = true;
    let func = Func { f: || 2 };
    let x = foo.then_some((func.f)());
}
"#,
        )
    }
}