Skip to main content

ide_assists/handlers/
replace_method_eager_lazy.rs

1use hir::Semantics;
2use ide_db::{RootDatabase, assists::AssistId, defs::Definition, famous_defs::FamousDefs};
3use syntax::{
4    AstNode,
5    ast::{self, Expr, HasArgList, syntax_factory::SyntaxFactory},
6};
7
8use crate::{AssistContext, Assists, utils::wrap_paren_in_call};
9
10// Assist: replace_with_lazy_method
11//
12// Replace `unwrap_or` with `unwrap_or_else` and `ok_or` with `ok_or_else`.
13//
14// ```
15// # //- minicore:option, fn
16// fn foo() {
17//     let a = Some(1);
18//     a.unwra$0p_or(2);
19// }
20// ```
21// ->
22// ```
23// fn foo() {
24//     let a = Some(1);
25//     a.unwrap_or_else(|| 2);
26// }
27// ```
28pub(crate) fn replace_with_lazy_method(
29    acc: &mut Assists,
30    ctx: &AssistContext<'_, '_>,
31) -> Option<()> {
32    let call: ast::MethodCallExpr = ctx.find_node_at_offset()?;
33    let scope = ctx.sema.scope(call.syntax())?;
34
35    let last_arg = call.arg_list()?.args().next()?;
36    let method_name = call.name_ref()?;
37
38    let callable = ctx.sema.resolve_method_call_as_callable(&call)?;
39    let (_, receiver_ty) = callable.receiver_param(ctx.sema.db)?;
40    let n_params = callable.n_params() + 1;
41
42    let method_name_lazy = lazy_method_name(&method_name.text());
43
44    receiver_ty.iterate_method_candidates_with_traits(
45        ctx.sema.db,
46        &scope,
47        &scope.visible_traits().0,
48        None,
49        |func| {
50            let valid = func.name(ctx.sema.db).as_str() == &*method_name_lazy
51                && func.num_params(ctx.sema.db) == n_params
52                && {
53                    let params = func.params_without_self(ctx.sema.db);
54                    let last_p = params.first()?;
55                    // FIXME: Check that this has the form of `() -> T` where T is the current type of the argument
56                    last_p.ty().impls_fnonce(ctx.sema.db)
57                };
58            valid.then_some(func)
59        },
60    )?;
61
62    acc.add(
63        AssistId::refactor_rewrite("replace_with_lazy_method"),
64        format!("Replace {method_name} with {method_name_lazy}"),
65        call.syntax().text_range(),
66        |builder| {
67            let editor = builder.make_editor(call.syntax());
68            let add_param = match &*method_name_lazy {
69                "and_then" => true,
70                "or_else" | "unwrap_or_else" => {
71                    FamousDefs(&ctx.sema, scope.krate()).core_result_Result().is_some_and(
72                        |result| result.ty(ctx.db()).could_unify_with(ctx.db(), &receiver_ty),
73                    )
74                }
75                _ => false,
76            };
77            let closured = into_closure(&last_arg, add_param, editor.make());
78            editor.replace(method_name.syntax(), editor.make().name(&method_name_lazy).syntax());
79            editor.replace(last_arg.syntax(), closured.syntax());
80            if let Some(cap) = ctx.config.snippet_cap
81                && let ast::Expr::ClosureExpr(closured) = closured
82                && let Some(param) = closured.param_list().and_then(|it| it.params().next())
83            {
84                editor.add_annotation(param.syntax(), builder.make_placeholder_snippet(cap));
85            }
86            builder.add_file_edits(ctx.vfs_file_id(), editor);
87        },
88    )
89}
90
/// Maps an eager method name to the name of its lazy counterpart.
///
/// A trailing `or` segment becomes `or_else`, a trailing `and` segment becomes `and_then`,
/// a trailing `then_some` becomes `then`, and every other name gets a `_with` suffix.
/// A suffix only counts when it is the whole name or is preceded by `_`, so `xor` maps to
/// `xor_with`, not `xor_else`.
fn lazy_method_name(name: &str) -> String {
    let has_suffix = |suffix: &str| match name.strip_suffix(suffix) {
        Some(rest) => rest.is_empty() || rest.ends_with('_'),
        None => false,
    };

    if has_suffix("or") {
        return format!("{name}_else");
    }
    if has_suffix("and") {
        return format!("{name}_then");
    }
    match name.strip_suffix("_some") {
        Some(stem) if has_suffix("then_some") => stem.to_owned(),
        _ => format!("{name}_with"),
    }
}
102
103fn into_closure(param: &Expr, add_param: bool, make: &SyntaxFactory) -> Expr {
104    (|| {
105        if let ast::Expr::CallExpr(call) = param {
106            if call.arg_list()?.args().count() == 0 { Some(call.expr()?) } else { None }
107        } else {
108            None
109        }
110    })()
111    .unwrap_or_else(|| {
112        let pats = add_param.then(|| make.untyped_param(make.wildcard_pat().into()));
113        make.expr_closure(pats, param.clone()).into()
114    })
115}
116
117// Assist: replace_with_eager_method
118//
119// Replace `unwrap_or_else` with `unwrap_or` and `ok_or_else` with `ok_or`.
120//
121// ```
122// # //- minicore:option, fn
123// fn foo() {
124//     let a = Some(1);
125//     a.unwra$0p_or_else(|| 2);
126// }
127// ```
128// ->
129// ```
130// fn foo() {
131//     let a = Some(1);
132//     a.unwrap_or(2);
133// }
134// ```
135pub(crate) fn replace_with_eager_method(
136    acc: &mut Assists,
137    ctx: &AssistContext<'_, '_>,
138) -> Option<()> {
139    let call: ast::MethodCallExpr = ctx.find_node_at_offset()?;
140    let scope = ctx.sema.scope(call.syntax())?;
141
142    let last_arg = call.arg_list()?.args().next()?;
143    let method_name = call.name_ref()?;
144
145    let callable = ctx.sema.resolve_method_call_as_callable(&call)?;
146    let (_, receiver_ty) = callable.receiver_param(ctx.sema.db)?;
147    let n_params = callable.n_params() + 1;
148    let params = callable.params();
149
150    // FIXME: Check that the arg is of the form `() -> T`
151    if !params.first()?.ty().impls_fnonce(ctx.sema.db) {
152        return None;
153    }
154
155    let method_name_text = method_name.text();
156    let method_name_eager = eager_method_name(&method_name_text)?;
157
158    receiver_ty.iterate_method_candidates_with_traits(
159        ctx.sema.db,
160        &scope,
161        &scope.visible_traits().0,
162        None,
163        |func| {
164            let valid = func.name(ctx.sema.db).as_str() == method_name_eager
165                && func.num_params(ctx.sema.db) == n_params;
166            valid.then_some(func)
167        },
168    )?;
169
170    acc.add(
171        AssistId::refactor_rewrite("replace_with_eager_method"),
172        format!("Replace {method_name} with {method_name_eager}"),
173        call.syntax().text_range(),
174        |builder| {
175            let editor = builder.make_editor(call.syntax());
176            let called = into_call(&last_arg, &ctx.sema, editor.make());
177            editor.replace(method_name.syntax(), editor.make().name(method_name_eager).syntax());
178            editor.replace(last_arg.syntax(), called.syntax());
179            builder.add_file_edits(ctx.vfs_file_id(), editor);
180        },
181    )
182}
183
184fn into_call(param: &Expr, sema: &Semantics<'_, RootDatabase>, make: &SyntaxFactory) -> Expr {
185    (|| {
186        if let ast::Expr::ClosureExpr(closure) = param {
187            let mut params = closure.param_list()?.params();
188            match params.next() {
189                Some(_) if params.next().is_none() => {
190                    let params = sema.resolve_expr_as_callable(param)?.params();
191                    let used_param = Definition::Local(params.first()?.as_local(sema.db)?)
192                        .usages(sema)
193                        .at_least_one();
194                    if used_param { None } else { Some(closure.body()?) }
195                }
196                None => Some(closure.body()?),
197                Some(_) => None,
198            }
199        } else {
200            None
201        }
202    })()
203    .unwrap_or_else(|| {
204        let callable = wrap_paren_in_call(param.clone(), make);
205        make.expr_call(callable, make.arg_list(Vec::new())).into()
206    })
207}
208
/// Maps a lazy method name to the name of its eager counterpart, if it has one.
///
/// `then` maps to `then_some`. Otherwise the first matching suffix out of `_else`, `_then`
/// and `_with` (tried in that order) is removed, e.g. `unwrap_or_else` -> `unwrap_or` and
/// `and_then` -> `and`. Any other name yields `None`.
fn eager_method_name(name: &str) -> Option<&str> {
    match name {
        "then" => Some("then_some"),
        _ => ["_else", "_then", "_with"].iter().find_map(|&suffix| name.strip_suffix(suffix)),
    }
}
218
/// Returns `true` when `name` ends with `end` as a whole `_`-separated segment.
///
/// `end` must either be all of `name` or be preceded by an underscore, so
/// `ends_is("unwrap_or", "or")` holds while `ends_is("xor", "or")` does not.
fn ends_is(name: &str, end: &str) -> bool {
    match name.strip_suffix(end) {
        Some("") => true,
        Some(prefix) => prefix.ends_with('_'),
        None => false,
    }
}
222
#[cfg(test)]
mod tests {
    use crate::tests::check_assist;

    use super::*;

    #[test]
    fn replace_or_with_or_else_simple() {
        check_assist(
            replace_with_lazy_method,
            r#"
//- minicore: option, result, fn
fn foo() {
    let foo = Some(1);
    return foo.unwrap_$0or(2);
}
"#,
            r#"
fn foo() {
    let foo = Some(1);
    return foo.unwrap_or_else(|| 2);
}
"#,
        )
    }

    #[test]
    fn replace_or_with_or_else_with_parameter() {
        check_assist(
            replace_with_lazy_method,
            r#"
//- minicore: option, result, fn
fn foo() {
    let foo = Ok(1);
    return foo.unwrap_$0or(2);
}
"#,
            r#"
fn foo() {
    let foo = Ok(1);
    return foo.unwrap_or_else(|${0:_}| 2);
}
"#,
        )
    }

    #[test]
    fn replace_or_with_or_else_call() {
        check_assist(
            replace_with_lazy_method,
            r#"
//- minicore: option, fn
fn foo() {
    let foo = Some(1);
    return foo.unwrap_$0or(x());
}
"#,
            r#"
fn foo() {
    let foo = Some(1);
    return foo.unwrap_or_else(x);
}
"#,
        )
    }

    #[test]
    fn replace_or_with_or_else_block() {
        check_assist(
            replace_with_lazy_method,
            r#"
//- minicore: option, fn
fn foo() {
    let foo = Some(1);
    return foo.unwrap_$0or({
        let mut x = bar();
        for i in 0..10 {
            x += i;
        }
        x
    });
}
"#,
            r#"
fn foo() {
    let foo = Some(1);
    return foo.unwrap_or_else(|| {
        let mut x = bar();
        for i in 0..10 {
            x += i;
        }
        x
    });
}
"#,
        )
    }

    #[test]
    fn replace_or_else_with_or_simple() {
        check_assist(
            replace_with_eager_method,
            r#"
//- minicore: option, fn
fn foo() {
    let foo = Some(1);
    return foo.unwrap_$0or_else(|| 2);
}
"#,
            r#"
fn foo() {
    let foo = Some(1);
    return foo.unwrap_or(2);
}
"#,
        )
    }

    #[test]
    fn replace_or_else_with_or_call() {
        check_assist(
            replace_with_eager_method,
            r#"
//- minicore: option, fn
fn foo() {
    let foo = Some(1);
    return foo.unwrap_$0or_else(x);
}

fn x() -> i32 { 0 }
"#,
            r#"
fn foo() {
    let foo = Some(1);
    return foo.unwrap_or(x());
}

fn x() -> i32 { 0 }
"#,
        )
    }

    #[test]
    fn replace_or_else_with_or_map() {
        check_assist(
            replace_with_eager_method,
            r#"
//- minicore: option, fn
fn foo() {
    let foo = Some("foo");
    return foo.map$0_or_else(|| 42, |v| v.len());
}
"#,
            r#"
fn foo() {
    let foo = Some("foo");
    return foo.map_or(42, |v| v.len());
}
"#,
        )
    }

    #[test]
    fn replace_and_with_and_then() {
        check_assist(
            replace_with_lazy_method,
            r#"
//- minicore: option, fn
fn foo() {
    let foo = Some("foo");
    return foo.and$0(Some("bar"));
}
"#,
            r#"
fn foo() {
    let foo = Some("foo");
    return foo.and_then(|${0:_}| Some("bar"));
}
"#,
        )
    }

    #[test]
    fn replace_and_then_with_and() {
        check_assist(
            replace_with_eager_method,
            r#"
//- minicore: option, fn
fn foo() {
    let foo = Some("foo");
    return foo.and_then$0(|it| Some("bar"));
}
"#,
            r#"
fn foo() {
    let foo = Some("foo");
    return foo.and(Some("bar"));
}
"#,
        )
    }

    #[test]
    fn replace_and_then_with_and_used_param() {
        check_assist(
            replace_with_eager_method,
            r#"
//- minicore: option, fn
fn foo() {
    let foo = Some("foo");
    return foo.and_then$0(|it| Some(it.strip_suffix("bar")));
}
"#,
            r#"
fn foo() {
    let foo = Some("foo");
    return foo.and((|it| Some(it.strip_suffix("bar")))());
}
"#,
        )
    }

    #[test]
    fn replace_then_some_with_then() {
        check_assist(
            replace_with_lazy_method,
            r#"
//- minicore: option, fn, bool_impl
fn foo() {
    let foo = true;
    let x = foo.then_some$0(2);
}
"#,
            r#"
fn foo() {
    let foo = true;
    let x = foo.then(|| 2);
}
"#,
        )
    }

    #[test]
    fn replace_then_with_then_some() {
        check_assist(
            replace_with_eager_method,
            r#"
//- minicore: option, fn, bool_impl
fn foo() {
    let foo = true;
    let x = foo.then$0(|| 2);
}
"#,
            r#"
fn foo() {
    let foo = true;
    let x = foo.then_some(2);
}
"#,
        )
    }

    #[test]
    fn replace_then_with_then_some_needs_parens() {
        check_assist(
            replace_with_eager_method,
            r#"
//- minicore: option, fn, bool_impl
struct Func { f: fn() -> i32 }
fn foo() {
    let foo = true;
    let func = Func { f: || 2 };
    let x = foo.then$0(func.f);
}
"#,
            r#"
struct Func { f: fn() -> i32 }
fn foo() {
    let foo = true;
    let func = Func { f: || 2 };
    let x = foo.then_some((func.f)());
}
"#,
        )
    }

    // Direct coverage of the name-mapping helpers, including names that merely *contain*
    // a known suffix without it being a separate `_` segment.
    #[test]
    fn lazy_method_names() {
        assert_eq!(lazy_method_name("unwrap_or"), "unwrap_or_else");
        assert_eq!(lazy_method_name("or"), "or_else");
        assert_eq!(lazy_method_name("map_or"), "map_or_else");
        assert_eq!(lazy_method_name("ok_or"), "ok_or_else");
        assert_eq!(lazy_method_name("and"), "and_then");
        assert_eq!(lazy_method_name("then_some"), "then");
        assert_eq!(lazy_method_name("get_or_insert"), "get_or_insert_with");
        // `xor` does not end in a standalone `or` segment.
        assert_eq!(lazy_method_name("xor"), "xor_with");
    }

    #[test]
    fn eager_method_names() {
        assert_eq!(eager_method_name("unwrap_or_else"), Some("unwrap_or"));
        assert_eq!(eager_method_name("map_or_else"), Some("map_or"));
        assert_eq!(eager_method_name("and_then"), Some("and"));
        assert_eq!(eager_method_name("then"), Some("then_some"));
        assert_eq!(eager_method_name("get_or_insert_with"), Some("get_or_insert"));
        assert_eq!(eager_method_name("unwrap"), None);
    }

    #[test]
    fn ends_is_requires_segment_boundary() {
        assert!(ends_is("unwrap_or", "or"));
        assert!(ends_is("or", "or"));
        assert!(!ends_is("xor", "or"));
    }
}