Skip to main content

ide_assists/handlers/
convert_bool_then.rs

1use hir::{AsAssocItem, Semantics, sym};
2use ide_db::{
3    RootDatabase,
4    famous_defs::FamousDefs,
5    syntax_helpers::node_ext::{
6        block_as_lone_tail, for_each_tail_expr, is_pattern_cond, preorder_expr,
7    },
8};
9use itertools::Itertools;
10use syntax::{
11    AstNode, SyntaxNode,
12    ast::{self, HasArgList, edit::AstNodeEdit, syntax_factory::SyntaxFactory},
13    syntax_editor::SyntaxEditor,
14};
15
16use crate::{
17    AssistContext, AssistId, Assists,
18    utils::{invert_boolean_expression, unwrap_trivial_block},
19};
20
21// Assist: convert_if_to_bool_then
22//
23// Converts an if expression into a corresponding `bool::then` call.
24//
25// ```
26// # //- minicore: option
27// fn main() {
28//     if$0 cond {
29//         Some(val)
30//     } else {
31//         None
32//     }
33// }
34// ```
35// ->
36// ```
37// fn main() {
38//     cond.then(|| val)
39// }
40// ```
41pub(crate) fn convert_if_to_bool_then(
42    acc: &mut Assists,
43    ctx: &AssistContext<'_, '_>,
44) -> Option<()> {
45    // FIXME applies to match as well
46    let expr = ctx.find_node_at_offset::<ast::IfExpr>()?;
47    if !expr.if_token()?.text_range().contains_inclusive(ctx.offset()) {
48        return None;
49    }
50
51    let cond = expr.condition().filter(|cond| !is_pattern_cond(cond.clone()))?;
52    let then = expr.then_branch()?;
53    let else_ = match expr.else_branch()? {
54        ast::ElseBranch::Block(b) => b,
55        ast::ElseBranch::IfExpr(_) => {
56            cov_mark::hit!(convert_if_to_bool_then_chain);
57            return None;
58        }
59    };
60
61    let (none_variant, some_variant) = option_variants(&ctx.sema, expr.syntax())?;
62
63    let (invert_cond, closure_body) = match (
64        block_is_none_variant(&ctx.sema, &then, none_variant),
65        block_is_none_variant(&ctx.sema, &else_, none_variant),
66    ) {
67        (invert @ true, false) => (invert, ast::Expr::BlockExpr(else_)),
68        (invert @ false, true) => (invert, ast::Expr::BlockExpr(then)),
69        _ => return None,
70    };
71
72    if is_invalid_body(&ctx.sema, some_variant, &closure_body) {
73        cov_mark::hit!(convert_if_to_bool_then_pattern_invalid_body);
74        return None;
75    }
76
77    let target = expr.syntax().text_range();
78    acc.add(
79        AssistId::refactor_rewrite("convert_if_to_bool_then"),
80        "Convert `if` expression to `bool::then` call",
81        target,
82        |builder| {
83            let (editor, closure_body) = SyntaxEditor::with_ast_node(&closure_body);
84            // Rewrite all `Some(e)` in tail position to `e`
85            for_each_tail_expr(&closure_body, &mut |e| {
86                let e = match e {
87                    ast::Expr::BreakExpr(e) => e.expr(),
88                    e @ ast::Expr::CallExpr(_) => Some(e.clone()),
89                    _ => None,
90                };
91                if let Some(ast::Expr::CallExpr(call)) = e
92                    && let Some(arg_list) = call.arg_list()
93                    && let Some(arg) = arg_list.args().next()
94                {
95                    editor.replace(call.syntax(), arg.syntax());
96                }
97            });
98            let edit = editor.finish();
99            let closure_body = ast::Expr::cast(edit.new_root().clone()).unwrap();
100
101            let editor = builder.make_editor(expr.syntax());
102            let make = editor.make();
103            let closure_body = match closure_body {
104                ast::Expr::BlockExpr(block) => unwrap_trivial_block(block),
105                e => e,
106            };
107            let cond = if invert_cond { invert_boolean_expression(make, cond) } else { cond };
108
109            let parenthesize = matches!(
110                cond,
111                ast::Expr::BinExpr(_)
112                    | ast::Expr::BlockExpr(_)
113                    | ast::Expr::BreakExpr(_)
114                    | ast::Expr::CastExpr(_)
115                    | ast::Expr::ClosureExpr(_)
116                    | ast::Expr::ContinueExpr(_)
117                    | ast::Expr::ForExpr(_)
118                    | ast::Expr::IfExpr(_)
119                    | ast::Expr::LoopExpr(_)
120                    | ast::Expr::MacroExpr(_)
121                    | ast::Expr::MatchExpr(_)
122                    | ast::Expr::PrefixExpr(_)
123                    | ast::Expr::RangeExpr(_)
124                    | ast::Expr::RefExpr(_)
125                    | ast::Expr::ReturnExpr(_)
126                    | ast::Expr::WhileExpr(_)
127                    | ast::Expr::YieldExpr(_)
128            );
129
130            let cond = if parenthesize { make.expr_paren(cond).into() } else { cond };
131            let arg_list = make.arg_list(Some(make.expr_closure(None, closure_body).into()));
132            let mcall = make.expr_method_call(cond, make.name_ref("then"), arg_list);
133            editor.replace(expr.syntax(), mcall.syntax());
134            builder.add_file_edits(ctx.vfs_file_id(), editor);
135        },
136    )
137}
138
139// Assist: convert_bool_then_to_if
140//
141// Converts a `bool::then` method call to an equivalent if expression.
142//
143// ```
144// # //- minicore: bool_impl
145// fn main() {
146//     (0 == 0).then$0(|| val)
147// }
148// ```
149// ->
150// ```
151// fn main() {
152//     if 0 == 0 {
153//         Some(val)
154//     } else {
155//         None
156//     }
157// }
158// ```
159pub(crate) fn convert_bool_then_to_if(
160    acc: &mut Assists,
161    ctx: &AssistContext<'_, '_>,
162) -> Option<()> {
163    let name_ref = ctx.find_node_at_offset::<ast::NameRef>()?;
164    let mcall = name_ref.syntax().parent().and_then(ast::MethodCallExpr::cast)?;
165    let receiver = mcall.receiver()?;
166    // FIXME: rewrite in terms of `#![feature(exact_length_collection)]`. See: #149266
167    let closure_body = Itertools::exactly_one(mcall.arg_list()?.args()).ok()?;
168    let closure_body = match closure_body {
169        ast::Expr::ClosureExpr(expr) => expr.body()?,
170        _ => return None,
171    };
172    // Verify this is `bool::then` that is being called.
173    let func = ctx.sema.resolve_method_call(&mcall)?;
174    if func.name(ctx.sema.db) != sym::then {
175        return None;
176    }
177    let assoc = func.as_assoc_item(ctx.sema.db)?;
178    if !assoc.implementing_ty(ctx.sema.db)?.is_bool() {
179        return None;
180    }
181
182    let target = mcall.syntax().text_range();
183    acc.add(
184        AssistId::refactor_rewrite("convert_bool_then_to_if"),
185        "Convert `bool::then` call to `if`",
186        target,
187        |builder| {
188            let mapless_make = SyntaxFactory::without_mappings();
189            let closure_body = match closure_body.reset_indent() {
190                ast::Expr::BlockExpr(block) => block,
191                e => mapless_make.block_expr(None, Some(e)),
192            };
193
194            let (editor, closure_body) = SyntaxEditor::with_ast_node(&closure_body);
195            // Wrap all tails in `Some(...)`
196            let none_path = mapless_make.expr_path(mapless_make.ident_path("None"));
197            let some_path = mapless_make.expr_path(mapless_make.ident_path("Some"));
198            for_each_tail_expr(&ast::Expr::BlockExpr(closure_body), &mut |e| {
199                let e = match e {
200                    ast::Expr::BreakExpr(e) => e.expr(),
201                    ast::Expr::ReturnExpr(e) => e.expr(),
202                    _ => Some(e.clone()),
203                };
204                if let Some(expr) = e {
205                    editor.replace(
206                        expr.syntax().clone(),
207                        mapless_make
208                            .expr_call(some_path.clone(), mapless_make.arg_list(Some(expr)))
209                            .syntax()
210                            .clone(),
211                    );
212                }
213            });
214            let edit = editor.finish();
215            let closure_body = ast::BlockExpr::cast(edit.new_root().clone()).unwrap();
216
217            let editor = builder.make_editor(mcall.syntax());
218            let make = editor.make();
219
220            let cond = match &receiver {
221                ast::Expr::ParenExpr(expr) => expr.expr().unwrap_or(receiver),
222                _ => receiver,
223            };
224            let if_expr = make
225                .expr_if(
226                    cond,
227                    closure_body,
228                    Some(ast::ElseBranch::Block(make.block_expr(None, Some(none_path)))),
229                )
230                .indent(mcall.indent_level());
231            editor.replace(mcall.syntax().clone(), if_expr.syntax().clone());
232            builder.add_file_edits(ctx.vfs_file_id(), editor);
233        },
234    )
235}
236
237fn option_variants(
238    sema: &Semantics<'_, RootDatabase>,
239    expr: &SyntaxNode,
240) -> Option<(hir::EnumVariant, hir::EnumVariant)> {
241    let fam = FamousDefs(sema, sema.scope(expr)?.krate());
242    let option_variants = fam.core_option_Option()?.variants(sema.db);
243    match &*option_variants {
244        &[variant0, variant1] => Some(if variant0.name(sema.db) == sym::None {
245            (variant0, variant1)
246        } else {
247            (variant1, variant0)
248        }),
249        _ => None,
250    }
251}
252
253/// Traverses the expression checking if it contains `return` or `?` expressions or if any tail is not a `Some(expr)` expression.
254/// If any of these conditions are met it is impossible to rewrite this as a `bool::then` call.
255fn is_invalid_body(
256    sema: &Semantics<'_, RootDatabase>,
257    some_variant: hir::EnumVariant,
258    expr: &ast::Expr,
259) -> bool {
260    let mut invalid = false;
261    preorder_expr(expr, &mut |e| {
262        invalid |=
263            matches!(e, syntax::WalkEvent::Enter(ast::Expr::TryExpr(_) | ast::Expr::ReturnExpr(_)));
264        invalid
265    });
266    if !invalid {
267        for_each_tail_expr(expr, &mut |e| {
268            if invalid {
269                return;
270            }
271            let e = match e {
272                ast::Expr::BreakExpr(e) => e.expr(),
273                e @ ast::Expr::CallExpr(_) => Some(e.clone()),
274                _ => None,
275            };
276            if let Some(ast::Expr::CallExpr(call)) = e
277                && let Some(ast::Expr::PathExpr(p)) = call.expr()
278            {
279                let res = p.path().and_then(|p| sema.resolve_path(&p));
280                if let Some(hir::PathResolution::Def(hir::ModuleDef::EnumVariant(v))) = res {
281                    return invalid |= v != some_variant;
282                }
283            }
284            invalid = true
285        });
286    }
287    invalid
288}
289
290fn block_is_none_variant(
291    sema: &Semantics<'_, RootDatabase>,
292    block: &ast::BlockExpr,
293    none_variant: hir::EnumVariant,
294) -> bool {
295    block_as_lone_tail(block).and_then(|e| match e {
296        ast::Expr::PathExpr(pat) => match sema.resolve_path(&pat.path()?)? {
297            hir::PathResolution::Def(hir::ModuleDef::EnumVariant(v)) => Some(v),
298            _ => None,
299        },
300        _ => None,
301    }) == Some(none_variant)
302}
303
#[cfg(test)]
mod tests {
    // Fixture-based tests for both assists. `$0` in a fixture marks the cursor
    // position at which the assist is invoked.
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    // `Some` in the `then` branch, `None` in the `else` branch: the condition is kept as is.
    #[test]
    fn convert_if_to_bool_then_simple() {
        check_assist(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else {
        None
    }
}
",
            r"
fn main() {
    true.then(|| 15)
}
",
        );
    }

    // `None` in the `then` branch: the condition is inverted (`true` becomes `false`).
    #[test]
    fn convert_if_to_bool_then_invert() {
        check_assist(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        None
    } else {
        Some(15)
    }
}
",
            r"
fn main() {
    false.then(|| 15)
}
",
        );
    }

    // Both branches are `None`: not applicable.
    #[test]
    fn convert_if_to_bool_then_none_none() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        None
    } else {
        None
    }
}
",
        );
    }

    // Both branches are `Some(..)`: not applicable.
    #[test]
    fn convert_if_to_bool_then_some_some() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else {
        Some(15)
    }
}
",
        );
    }

    // The `then` branch has a nested `if` with both `Some` and `None` tails: not applicable.
    #[test]
    fn convert_if_to_bool_then_mixed() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        if true {
            Some(15)
        } else {
            None
        }
    } else {
        None
    }
}
",
        );
    }

    // `else if` chains are rejected; the cov mark checks the dedicated bail-out is hit.
    #[test]
    fn convert_if_to_bool_then_chain() {
        cov_mark::check!(convert_if_to_bool_then_chain);
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else if true {
        None
    } else {
        None
    }
}
",
        );
    }

    // `if let` conditions are rejected.
    #[test]
    fn convert_if_to_bool_then_pattern_cond() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 let true = true {
        Some(15)
    } else {
        None
    }
}
",
        );
    }

    // Two invalid bodies, each expected to hit the `is_invalid_body` bail-out once:
    // a tail that is not a `Some(..)` call, and a body containing `return`.
    #[test]
    fn convert_if_to_bool_then_pattern_invalid_body() {
        cov_mark::check_count!(convert_if_to_bool_then_pattern_invalid_body, 2);
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn make_me_an_option() -> Option<i32> { None }
fn main() {
    if$0 true {
        if true {
            make_me_an_option()
        } else {
            Some(15)
        }
    } else {
        None
    }
}
",
        );
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        if true {
            return;
        }
        Some(15)
    } else {
        None
    }
}
",
        );
    }

    // Not applicable for a non-`bool` receiver, a non-closure argument, or more than one argument.
    #[test]
    fn convert_bool_then_to_if_inapplicable() {
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    0.t$0hen(|| 15);
}
",
        );
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(15);
}
",
        );
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| 15, 15);
}
",
        );
    }

    // An expression body and a block body both produce the same `if`/`else` output.
    #[test]
    fn convert_bool_then_to_if_simple() {
        check_assist(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| 15)
}
",
            r"
fn main() {
    if true {
        Some(15)
    } else {
        None
    }
}
",
        );
        check_assist(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| {
        15
    })
}
",
            r"
fn main() {
    if true {
        Some(15)
    } else {
        None
    }
}
",
        );
    }

    // Every `break` operand of a tail `loop` gets wrapped in `Some(..)`.
    #[test]
    fn convert_bool_then_to_if_tails() {
        check_assist(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| {
        loop {
            if false {
                break 0;
            }
            break 15;
        }
    })
}
",
            r"
fn main() {
    if true {
        loop {
            if false {
                break Some(0);
            }
            break Some(15);
        }
    } else {
        None
    }
}
",
        );
    }

    // An inverted method-call condition is parenthesized before `.then` is applied.
    #[test]
    fn convert_if_to_bool_then_invert_method_call() {
        check_assist(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    let test = &[()];
    let value = if$0 test.is_empty() { None } else { Some(()) };
}
",
            r"
fn main() {
    let test = &[()];
    let value = (!test.is_empty()).then(|| ());
}
",
        );
    }
}