ide_assists/handlers/
convert_bool_then.rs

1use hir::{AsAssocItem, Semantics, sym};
2use ide_db::{
3    RootDatabase,
4    famous_defs::FamousDefs,
5    syntax_helpers::node_ext::{
6        block_as_lone_tail, for_each_tail_expr, is_pattern_cond, preorder_expr,
7    },
8};
9use itertools::Itertools;
10use syntax::{
11    AstNode, SyntaxNode,
12    ast::{self, HasArgList, edit::AstNodeEdit, syntax_factory::SyntaxFactory},
13    syntax_editor::SyntaxEditor,
14};
15
16use crate::{
17    AssistContext, AssistId, Assists,
18    utils::{invert_boolean_expression, unwrap_trivial_block},
19};
20
21// Assist: convert_if_to_bool_then
22//
23// Converts an if expression into a corresponding `bool::then` call.
24//
25// ```
26// # //- minicore: option
27// fn main() {
28//     if$0 cond {
29//         Some(val)
30//     } else {
31//         None
32//     }
33// }
34// ```
35// ->
36// ```
37// fn main() {
38//     cond.then(|| val)
39// }
40// ```
41pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
42    // FIXME applies to match as well
43    let expr = ctx.find_node_at_offset::<ast::IfExpr>()?;
44    if !expr.if_token()?.text_range().contains_inclusive(ctx.offset()) {
45        return None;
46    }
47
48    let cond = expr.condition().filter(|cond| !is_pattern_cond(cond.clone()))?;
49    let then = expr.then_branch()?;
50    let else_ = match expr.else_branch()? {
51        ast::ElseBranch::Block(b) => b,
52        ast::ElseBranch::IfExpr(_) => {
53            cov_mark::hit!(convert_if_to_bool_then_chain);
54            return None;
55        }
56    };
57
58    let (none_variant, some_variant) = option_variants(&ctx.sema, expr.syntax())?;
59
60    let (invert_cond, closure_body) = match (
61        block_is_none_variant(&ctx.sema, &then, none_variant),
62        block_is_none_variant(&ctx.sema, &else_, none_variant),
63    ) {
64        (invert @ true, false) => (invert, ast::Expr::BlockExpr(else_)),
65        (invert @ false, true) => (invert, ast::Expr::BlockExpr(then)),
66        _ => return None,
67    };
68
69    if is_invalid_body(&ctx.sema, some_variant, &closure_body) {
70        cov_mark::hit!(convert_if_to_bool_then_pattern_invalid_body);
71        return None;
72    }
73
74    let target = expr.syntax().text_range();
75    acc.add(
76        AssistId::refactor_rewrite("convert_if_to_bool_then"),
77        "Convert `if` expression to `bool::then` call",
78        target,
79        |builder| {
80            let closure_body = closure_body.clone_subtree();
81            let mut editor = SyntaxEditor::new(closure_body.syntax().clone());
82            // Rewrite all `Some(e)` in tail position to `e`
83            for_each_tail_expr(&closure_body, &mut |e| {
84                let e = match e {
85                    ast::Expr::BreakExpr(e) => e.expr(),
86                    e @ ast::Expr::CallExpr(_) => Some(e.clone()),
87                    _ => None,
88                };
89                if let Some(ast::Expr::CallExpr(call)) = e
90                    && let Some(arg_list) = call.arg_list()
91                    && let Some(arg) = arg_list.args().next()
92                {
93                    editor.replace(call.syntax(), arg.syntax());
94                }
95            });
96            let edit = editor.finish();
97            let closure_body = ast::Expr::cast(edit.new_root().clone()).unwrap();
98
99            let mut editor = builder.make_editor(expr.syntax());
100            let make = SyntaxFactory::with_mappings();
101            let closure_body = match closure_body {
102                ast::Expr::BlockExpr(block) => unwrap_trivial_block(block),
103                e => e,
104            };
105
106            let parenthesize = matches!(
107                cond,
108                ast::Expr::BinExpr(_)
109                    | ast::Expr::BlockExpr(_)
110                    | ast::Expr::BreakExpr(_)
111                    | ast::Expr::CastExpr(_)
112                    | ast::Expr::ClosureExpr(_)
113                    | ast::Expr::ContinueExpr(_)
114                    | ast::Expr::ForExpr(_)
115                    | ast::Expr::IfExpr(_)
116                    | ast::Expr::LoopExpr(_)
117                    | ast::Expr::MacroExpr(_)
118                    | ast::Expr::MatchExpr(_)
119                    | ast::Expr::PrefixExpr(_)
120                    | ast::Expr::RangeExpr(_)
121                    | ast::Expr::RefExpr(_)
122                    | ast::Expr::ReturnExpr(_)
123                    | ast::Expr::WhileExpr(_)
124                    | ast::Expr::YieldExpr(_)
125            );
126            let cond = if invert_cond {
127                invert_boolean_expression(&make, cond)
128            } else {
129                cond.clone_for_update()
130            };
131            let cond = if parenthesize { make.expr_paren(cond).into() } else { cond };
132            let arg_list = make.arg_list(Some(make.expr_closure(None, closure_body).into()));
133            let mcall = make.expr_method_call(cond, make.name_ref("then"), arg_list);
134            editor.replace(expr.syntax(), mcall.syntax());
135
136            editor.add_mappings(make.finish_with_mappings());
137            builder.add_file_edits(ctx.vfs_file_id(), editor);
138        },
139    )
140}
141
142// Assist: convert_bool_then_to_if
143//
144// Converts a `bool::then` method call to an equivalent if expression.
145//
146// ```
147// # //- minicore: bool_impl
148// fn main() {
149//     (0 == 0).then$0(|| val)
150// }
151// ```
152// ->
153// ```
154// fn main() {
155//     if 0 == 0 {
156//         Some(val)
157//     } else {
158//         None
159//     }
160// }
161// ```
162pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
163    let name_ref = ctx.find_node_at_offset::<ast::NameRef>()?;
164    let mcall = name_ref.syntax().parent().and_then(ast::MethodCallExpr::cast)?;
165    let receiver = mcall.receiver()?;
166    // FIXME: rewrite in terms of `#![feature(exact_length_collection)]`. See: #149266
167    let closure_body = Itertools::exactly_one(mcall.arg_list()?.args()).ok()?;
168    let closure_body = match closure_body {
169        ast::Expr::ClosureExpr(expr) => expr.body()?,
170        _ => return None,
171    };
172    // Verify this is `bool::then` that is being called.
173    let func = ctx.sema.resolve_method_call(&mcall)?;
174    if func.name(ctx.sema.db) != sym::then {
175        return None;
176    }
177    let assoc = func.as_assoc_item(ctx.sema.db)?;
178    if !assoc.implementing_ty(ctx.sema.db)?.is_bool() {
179        return None;
180    }
181
182    let target = mcall.syntax().text_range();
183    acc.add(
184        AssistId::refactor_rewrite("convert_bool_then_to_if"),
185        "Convert `bool::then` call to `if`",
186        target,
187        |builder| {
188            let mapless_make = SyntaxFactory::without_mappings();
189            let closure_body = match closure_body.reset_indent() {
190                ast::Expr::BlockExpr(block) => block,
191                e => mapless_make.block_expr(None, Some(e)),
192            };
193
194            let closure_body = closure_body.clone_subtree();
195            let mut editor = SyntaxEditor::new(closure_body.syntax().clone());
196            // Wrap all tails in `Some(...)`
197            let none_path = mapless_make.expr_path(mapless_make.ident_path("None"));
198            let some_path = mapless_make.expr_path(mapless_make.ident_path("Some"));
199            for_each_tail_expr(&ast::Expr::BlockExpr(closure_body), &mut |e| {
200                let e = match e {
201                    ast::Expr::BreakExpr(e) => e.expr(),
202                    ast::Expr::ReturnExpr(e) => e.expr(),
203                    _ => Some(e.clone()),
204                };
205                if let Some(expr) = e {
206                    editor.replace(
207                        expr.syntax().clone(),
208                        mapless_make
209                            .expr_call(some_path.clone(), mapless_make.arg_list(Some(expr)))
210                            .syntax()
211                            .clone(),
212                    );
213                }
214            });
215            let edit = editor.finish();
216            let closure_body = ast::BlockExpr::cast(edit.new_root().clone()).unwrap();
217
218            let mut editor = builder.make_editor(mcall.syntax());
219            let make = SyntaxFactory::with_mappings();
220
221            let cond = match &receiver {
222                ast::Expr::ParenExpr(expr) => expr.expr().unwrap_or(receiver),
223                _ => receiver,
224            };
225            let if_expr = make
226                .expr_if(
227                    cond,
228                    closure_body,
229                    Some(ast::ElseBranch::Block(make.block_expr(None, Some(none_path)))),
230                )
231                .indent(mcall.indent_level());
232            editor.replace(mcall.syntax().clone(), if_expr.syntax().clone());
233
234            editor.add_mappings(make.finish_with_mappings());
235            builder.add_file_edits(ctx.vfs_file_id(), editor);
236        },
237    )
238}
239
240fn option_variants(
241    sema: &Semantics<'_, RootDatabase>,
242    expr: &SyntaxNode,
243) -> Option<(hir::Variant, hir::Variant)> {
244    let fam = FamousDefs(sema, sema.scope(expr)?.krate());
245    let option_variants = fam.core_option_Option()?.variants(sema.db);
246    match &*option_variants {
247        &[variant0, variant1] => Some(if variant0.name(sema.db) == sym::None {
248            (variant0, variant1)
249        } else {
250            (variant1, variant0)
251        }),
252        _ => None,
253    }
254}
255
256/// Traverses the expression checking if it contains `return` or `?` expressions or if any tail is not a `Some(expr)` expression.
257/// If any of these conditions are met it is impossible to rewrite this as a `bool::then` call.
258fn is_invalid_body(
259    sema: &Semantics<'_, RootDatabase>,
260    some_variant: hir::Variant,
261    expr: &ast::Expr,
262) -> bool {
263    let mut invalid = false;
264    preorder_expr(expr, &mut |e| {
265        invalid |=
266            matches!(e, syntax::WalkEvent::Enter(ast::Expr::TryExpr(_) | ast::Expr::ReturnExpr(_)));
267        invalid
268    });
269    if !invalid {
270        for_each_tail_expr(expr, &mut |e| {
271            if invalid {
272                return;
273            }
274            let e = match e {
275                ast::Expr::BreakExpr(e) => e.expr(),
276                e @ ast::Expr::CallExpr(_) => Some(e.clone()),
277                _ => None,
278            };
279            if let Some(ast::Expr::CallExpr(call)) = e
280                && let Some(ast::Expr::PathExpr(p)) = call.expr()
281            {
282                let res = p.path().and_then(|p| sema.resolve_path(&p));
283                if let Some(hir::PathResolution::Def(hir::ModuleDef::Variant(v))) = res {
284                    return invalid |= v != some_variant;
285                }
286            }
287            invalid = true
288        });
289    }
290    invalid
291}
292
293fn block_is_none_variant(
294    sema: &Semantics<'_, RootDatabase>,
295    block: &ast::BlockExpr,
296    none_variant: hir::Variant,
297) -> bool {
298    block_as_lone_tail(block).and_then(|e| match e {
299        ast::Expr::PathExpr(pat) => match sema.resolve_path(&pat.path()?)? {
300            hir::PathResolution::Def(hir::ModuleDef::Variant(v)) => Some(v),
301            _ => None,
302        },
303        _ => None,
304    }) == Some(none_variant)
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    use crate::tests::{check_assist, check_assist_not_applicable};

    // Fixtures are passed inline (not via locals) so rust-analyzer keeps highlighting them.

    // `Some` in the `then` branch, `None` in the `else` branch: condition kept as-is.
    #[test]
    fn convert_if_to_bool_then_simple() {
        check_assist(
            convert_if_to_bool_then,
            r#"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else {
        None
    }
}
"#,
            r#"
fn main() {
    true.then(|| 15)
}
"#,
        );
    }

    // `None` first: the condition must be inverted.
    #[test]
    fn convert_if_to_bool_then_invert() {
        check_assist(
            convert_if_to_bool_then,
            r#"
//- minicore:option
fn main() {
    if$0 true {
        None
    } else {
        Some(15)
    }
}
"#,
            r#"
fn main() {
    false.then(|| 15)
}
"#,
        );
    }

    // Both branches `None`: nothing to put in the closure.
    #[test]
    fn convert_if_to_bool_then_none_none() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r#"
//- minicore:option
fn main() {
    if$0 true {
        None
    } else {
        None
    }
}
"#,
        );
    }

    // Both branches `Some`: there is no `None` branch to drop.
    #[test]
    fn convert_if_to_bool_then_some_some() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r#"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else {
        Some(15)
    }
}
"#,
        );
    }

    // The `Some` branch itself has a `None` tail.
    #[test]
    fn convert_if_to_bool_then_mixed() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r#"
//- minicore:option
fn main() {
    if$0 true {
        if true {
            Some(15)
        } else {
            None
        }
    } else {
        None
    }
}
"#,
        );
    }

    // `else if` chains are not supported.
    #[test]
    fn convert_if_to_bool_then_chain() {
        cov_mark::check!(convert_if_to_bool_then_chain);
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r#"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else if true {
        None
    } else {
        None
    }
}
"#,
        );
    }

    // `if let` conditions are not `bool` receivers.
    #[test]
    fn convert_if_to_bool_then_pattern_cond() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r#"
//- minicore:option
fn main() {
    if$0 let true = true {
        Some(15)
    } else {
        None
    }
}
"#,
        );
    }

    // Bodies with a non-`Some` tail or an early `return` are rejected.
    #[test]
    fn convert_if_to_bool_then_pattern_invalid_body() {
        cov_mark::check_count!(convert_if_to_bool_then_pattern_invalid_body, 2);
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r#"
//- minicore:option
fn make_me_an_option() -> Option<i32> { None }
fn main() {
    if$0 true {
        if true {
            make_me_an_option()
        } else {
            Some(15)
        }
    } else {
        None
    }
}
"#,
        );
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r#"
//- minicore:option
fn main() {
    if$0 true {
        if true {
            return;
        }
        Some(15)
    } else {
        None
    }
}
"#,
        );
    }

    // Non-bool receiver, non-closure argument, and wrong argument count.
    #[test]
    fn convert_bool_then_to_if_inapplicable() {
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r#"
//- minicore:bool_impl
fn main() {
    0.t$0hen(|| 15);
}
"#,
        );
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r#"
//- minicore:bool_impl
fn main() {
    true.t$0hen(15);
}
"#,
        );
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r#"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| 15, 15);
}
"#,
        );
    }

    // Expression-bodied and block-bodied closures produce the same `if`.
    #[test]
    fn convert_bool_then_to_if_simple() {
        check_assist(
            convert_bool_then_to_if,
            r#"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| 15)
}
"#,
            r#"
fn main() {
    if true {
        Some(15)
    } else {
        None
    }
}
"#,
        );
        check_assist(
            convert_bool_then_to_if,
            r#"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| {
        15
    })
}
"#,
            r#"
fn main() {
    if true {
        Some(15)
    } else {
        None
    }
}
"#,
        );
    }

    // Every `break value` tail of a loop body gets wrapped in `Some`.
    #[test]
    fn convert_bool_then_to_if_tails() {
        check_assist(
            convert_bool_then_to_if,
            r#"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| {
        loop {
            if false {
                break 0;
            }
            break 15;
        }
    })
}
"#,
            r#"
fn main() {
    if true {
        loop {
            if false {
                break Some(0);
            }
            break Some(15);
        }
    } else {
        None
    }
}
"#,
        );
    }
}