Skip to main content

ide_assists/handlers/
convert_bool_to_enum.rs

1use either::Either;
2use hir::ModuleDef;
3use ide_db::imports::insert_use::insert_use_with_editor;
4use ide_db::text_edit::TextRange;
5use ide_db::{
6    FileId, FxHashSet,
7    assists::AssistId,
8    defs::Definition,
9    helpers::mod_path_to_ast_with_factory,
10    imports::insert_use::ImportScope,
11    search::{FileReference, UsageSearchResult},
12    source_change::SourceChangeBuilder,
13};
14use itertools::Itertools;
15use syntax::ast::edit::AstNodeEdit;
16use syntax::ast::syntax_factory::SyntaxFactory;
17use syntax::{
18    AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T,
19    ast::{self, HasName, edit::IndentLevel},
20};
21
22use crate::{
23    assist_context::{AssistContext, Assists},
24    utils,
25};
26
27// Assist: convert_bool_to_enum
28//
29// This converts boolean local variables, fields, constants, and statics into a new
30// enum with two variants `Bool::True` and `Bool::False`, as well as replacing
31// all assignments with the variants and replacing all usages with `== Bool::True` or
32// `== Bool::False`.
33//
34// ```
35// fn main() {
36//     let $0bool = true;
37//
38//     if bool {
39//         println!("foo");
40//     }
41// }
42// ```
43// ->
44// ```
45// #[derive(PartialEq, Eq)]
46// enum Bool { True, False }
47//
48// fn main() {
49//     let bool = Bool::True;
50//
51//     if bool == Bool::True {
52//         println!("foo");
53//     }
54// }
55// ```
56pub(crate) fn convert_bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_, '_>) -> Option<()> {
57    let BoolNodeData { target_node, name, ty_annotation, initializer, definition } =
58        find_bool_node(ctx)?;
59    let target_module = ctx.sema.scope(&target_node)?.module().nearest_non_block_module(ctx.db());
60
61    let target = name.syntax().text_range();
62    acc.add(
63        AssistId::refactor_rewrite("convert_bool_to_enum"),
64        "Convert boolean to enum",
65        target,
66        |edit| {
67            let make = SyntaxFactory::without_mappings();
68            if let Some(ty) = &ty_annotation {
69                cov_mark::hit!(replaces_ty_annotation);
70                edit.replace(ty.syntax().text_range(), "Bool");
71            }
72
73            if let Some(initializer) = initializer {
74                replace_bool_expr(edit, initializer, &make);
75            }
76
77            let usages = definition.usages(&ctx.sema).all();
78            add_enum_def(edit, ctx, &usages, target_node, &target_module, &make);
79            let mut delayed_mutations = Vec::new();
80            replace_usages(
81                edit,
82                ctx,
83                usages,
84                definition,
85                &target_module,
86                &mut delayed_mutations,
87                &make,
88            );
89            for (file_id, scope, path) in delayed_mutations {
90                let editor = edit.make_editor(scope.as_syntax_node());
91                insert_use_with_editor(&scope, path, &ctx.config.insert_use, &editor);
92                edit.add_file_edits(file_id, editor);
93            }
94        },
95    )
96}
97
98struct BoolNodeData {
99    target_node: SyntaxNode,
100    name: ast::Name,
101    ty_annotation: Option<ast::Type>,
102    initializer: Option<ast::Expr>,
103    definition: Definition,
104}
105
106/// Attempts to find an appropriate node to apply the action to.
107fn find_bool_node(ctx: &AssistContext<'_, '_>) -> Option<BoolNodeData> {
108    let name = ctx.find_node_at_offset::<ast::Name>()?;
109
110    if let Some(ident_pat) = name.syntax().parent().and_then(ast::IdentPat::cast) {
111        let def = ctx.sema.to_def(&ident_pat)?;
112        if !def.ty(ctx.db()).is_bool() {
113            cov_mark::hit!(not_applicable_non_bool_local);
114            return None;
115        }
116
117        let local_definition = Definition::Local(def);
118        match ident_pat.syntax().parent().and_then(Either::<ast::Param, ast::LetStmt>::cast)? {
119            Either::Left(param) => Some(BoolNodeData {
120                target_node: param.syntax().clone(),
121                name,
122                ty_annotation: param.ty(),
123                initializer: None,
124                definition: local_definition,
125            }),
126            Either::Right(let_stmt) => Some(BoolNodeData {
127                target_node: let_stmt.syntax().clone(),
128                name,
129                ty_annotation: let_stmt.ty(),
130                initializer: let_stmt.initializer(),
131                definition: local_definition,
132            }),
133        }
134    } else if let Some(const_) = name.syntax().parent().and_then(ast::Const::cast) {
135        let def = ctx.sema.to_def(&const_)?;
136        if !def.ty(ctx.db()).is_bool() {
137            cov_mark::hit!(not_applicable_non_bool_const);
138            return None;
139        }
140
141        Some(BoolNodeData {
142            target_node: const_.syntax().clone(),
143            name,
144            ty_annotation: const_.ty(),
145            initializer: const_.body(),
146            definition: Definition::Const(def),
147        })
148    } else if let Some(static_) = name.syntax().parent().and_then(ast::Static::cast) {
149        let def = ctx.sema.to_def(&static_)?;
150        if !def.ty(ctx.db()).is_bool() {
151            cov_mark::hit!(not_applicable_non_bool_static);
152            return None;
153        }
154
155        Some(BoolNodeData {
156            target_node: static_.syntax().clone(),
157            name,
158            ty_annotation: static_.ty(),
159            initializer: static_.body(),
160            definition: Definition::Static(def),
161        })
162    } else {
163        let field = name.syntax().parent().and_then(ast::RecordField::cast)?;
164        if field.name()? != name {
165            return None;
166        }
167
168        let adt = field.syntax().ancestors().find_map(ast::Adt::cast)?;
169        let def = ctx.sema.to_def(&field)?;
170        if !def.ty(ctx.db()).is_bool() {
171            cov_mark::hit!(not_applicable_non_bool_field);
172            return None;
173        }
174        Some(BoolNodeData {
175            target_node: adt.syntax().clone(),
176            name,
177            ty_annotation: field.ty(),
178            initializer: None,
179            definition: Definition::Field(def),
180        })
181    }
182}
183
184fn replace_bool_expr(edit: &mut SourceChangeBuilder, expr: ast::Expr, make: &SyntaxFactory) {
185    let expr_range = expr.syntax().text_range();
186    let enum_expr = bool_expr_to_enum_expr(expr, make);
187    edit.replace(expr_range, enum_expr.syntax().text())
188}
189
190/// Converts an expression of type `bool` to one of the new enum type.
191fn bool_expr_to_enum_expr(expr: ast::Expr, make: &SyntaxFactory) -> ast::Expr {
192    let true_expr = make.expr_path(make.path_from_text("Bool::True"));
193    let false_expr = make.expr_path(make.path_from_text("Bool::False"));
194
195    if let ast::Expr::Literal(literal) = &expr {
196        match literal.kind() {
197            ast::LiteralKind::Bool(true) => true_expr,
198            ast::LiteralKind::Bool(false) => false_expr,
199            _ => expr,
200        }
201    } else {
202        make.expr_if(
203            expr,
204            make.tail_only_block_expr(true_expr),
205            Some(ast::ElseBranch::Block(make.tail_only_block_expr(false_expr))),
206        )
207        .into()
208    }
209}
210
211/// Replaces all usages of the target identifier, both when read and written to.
212fn replace_usages(
213    edit: &mut SourceChangeBuilder,
214    ctx: &AssistContext<'_, '_>,
215    usages: UsageSearchResult,
216    target_definition: Definition,
217    target_module: &hir::Module,
218    delayed_mutations: &mut Vec<(FileId, ImportScope, ast::Path)>,
219    make: &SyntaxFactory,
220) {
221    for (file_id, references) in usages {
222        let vfs_file_id = file_id.file_id(ctx.db());
223        edit.edit_file(vfs_file_id);
224
225        let refs_with_imports =
226            augment_references_with_imports(ctx, references, target_module, make);
227
228        refs_with_imports.into_iter().rev().for_each(
229            |FileReferenceWithImport { range, name, import_data }| {
230                // replace the usages in patterns and expressions
231                if let Some(ident_pat) = name.syntax().ancestors().find_map(ast::IdentPat::cast) {
232                    cov_mark::hit!(replaces_record_pat_shorthand);
233
234                    let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local);
235                    if let Some(def) = definition {
236                        replace_usages(
237                            edit,
238                            ctx,
239                            def.usages(&ctx.sema).all(),
240                            target_definition,
241                            target_module,
242                            delayed_mutations,
243                            make,
244                        )
245                    }
246                } else if let Some(initializer) = find_assignment_usage(&name) {
247                    cov_mark::hit!(replaces_assignment);
248
249                    replace_bool_expr(edit, initializer, make);
250                } else if let Some((prefix_expr, inner_expr)) = find_negated_usage(&name) {
251                    cov_mark::hit!(replaces_negation);
252
253                    edit.replace(
254                        prefix_expr.syntax().text_range(),
255                        format!("{inner_expr} == Bool::False"),
256                    );
257                } else if let Some((record_field, initializer)) = name
258                    .as_name_ref()
259                    .and_then(ast::RecordExprField::for_field_name)
260                    .and_then(|record_field| ctx.sema.resolve_record_field(&record_field))
261                    .and_then(|(got_field, _, _)| {
262                        find_record_expr_usage(&name, got_field, target_definition)
263                    })
264                {
265                    cov_mark::hit!(replaces_record_expr);
266
267                    let enum_expr = bool_expr_to_enum_expr(initializer, make);
268                    utils::replace_record_field_expr(ctx, edit, record_field, enum_expr);
269                } else if let Some(pat) = find_record_pat_field_usage(&name) {
270                    match pat {
271                        ast::Pat::IdentPat(ident_pat) => {
272                            cov_mark::hit!(replaces_record_pat);
273
274                            let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local);
275                            if let Some(def) = definition {
276                                replace_usages(
277                                    edit,
278                                    ctx,
279                                    def.usages(&ctx.sema).all(),
280                                    target_definition,
281                                    target_module,
282                                    delayed_mutations,
283                                    make,
284                                )
285                            }
286                        }
287                        ast::Pat::LiteralPat(literal_pat) => {
288                            cov_mark::hit!(replaces_literal_pat);
289
290                            if let Some(expr) = literal_pat.literal().and_then(|literal| {
291                                literal.syntax().ancestors().find_map(ast::Expr::cast)
292                            }) {
293                                replace_bool_expr(edit, expr, make);
294                            }
295                        }
296                        _ => (),
297                    }
298                } else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&name) {
299                    edit.replace(ty_annotation.syntax().text_range(), "Bool");
300                    replace_bool_expr(edit, initializer, make);
301                } else if let Some(receiver) = find_method_call_expr_usage(&name) {
302                    edit.replace(
303                        receiver.syntax().text_range(),
304                        format!("({receiver} == Bool::True)"),
305                    );
306                } else if name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
307                    // for any other usage in an expression, replace it with a check that it is the true variant
308                    if let Some((record_field, expr)) =
309                        name.as_name_ref().and_then(ast::RecordExprField::for_field_name).and_then(
310                            |record_field| record_field.expr().map(|expr| (record_field, expr)),
311                        )
312                    {
313                        utils::replace_record_field_expr(
314                            ctx,
315                            edit,
316                            record_field,
317                            make.expr_bin_op(
318                                expr,
319                                ast::BinaryOp::CmpOp(ast::CmpOp::Eq { negated: false }),
320                                make.expr_path(make.path_from_text("Bool::True")),
321                            ),
322                        );
323                    } else {
324                        edit.replace(range, format!("{} == Bool::True", name.text()));
325                    }
326                }
327
328                // add imports across modules where needed
329                if let Some((scope, path)) = import_data {
330                    delayed_mutations.push((vfs_file_id, scope, path));
331                }
332            },
333        )
334    }
335}
336
337struct FileReferenceWithImport {
338    range: TextRange,
339    name: ast::NameLike,
340    import_data: Option<(ImportScope, ast::Path)>,
341}
342
343fn augment_references_with_imports(
344    ctx: &AssistContext<'_, '_>,
345    references: Vec<FileReference>,
346    target_module: &hir::Module,
347    make: &SyntaxFactory,
348) -> Vec<FileReferenceWithImport> {
349    let mut visited_modules = FxHashSet::default();
350
351    let edition = target_module.krate(ctx.db()).edition(ctx.db());
352    references
353        .into_iter()
354        .filter_map(|FileReference { range, name, .. }| {
355            let name = name.into_name_like()?;
356            ctx.sema.scope(name.syntax()).map(|scope| (range, name, scope.module()))
357        })
358        .map(|(range, name, ref_module)| {
359            // if the referenced module is not the same as the target one and has not been seen before, add an import
360            let import_data = if ref_module.nearest_non_block_module(ctx.db()) != *target_module
361                && !visited_modules.contains(&ref_module)
362            {
363                visited_modules.insert(ref_module);
364
365                ImportScope::find_insert_use_container(name.syntax(), &ctx.sema).and_then(
366                    |import_scope| {
367                        let cfg = ctx.config.find_path_config(
368                            ctx.sema.is_nightly(target_module.krate(ctx.sema.db)),
369                        );
370                        let path = ref_module
371                            .find_use_path(
372                                ctx.sema.db,
373                                ModuleDef::Module(*target_module),
374                                ctx.config.insert_use.prefix_kind,
375                                cfg,
376                            )
377                            .map(|mod_path| {
378                                make.path_concat(
379                                    mod_path_to_ast_with_factory(make, &mod_path, edition),
380                                    make.path_from_text("Bool"),
381                                )
382                            })?;
383
384                        Some((import_scope, path))
385                    },
386                )
387            } else {
388                None
389            };
390
391            FileReferenceWithImport { range, name, import_data }
392        })
393        .collect()
394}
395
396fn find_assignment_usage(name: &ast::NameLike) -> Option<ast::Expr> {
397    let bin_expr = name.syntax().ancestors().find_map(ast::BinExpr::cast)?;
398
399    if !bin_expr.lhs()?.syntax().descendants().contains(name.syntax()) {
400        cov_mark::hit!(dont_assign_incorrect_ref);
401        return None;
402    }
403
404    if let Some(ast::BinaryOp::Assignment { op: None }) = bin_expr.op_kind() {
405        bin_expr.rhs()
406    } else {
407        None
408    }
409}
410
411fn find_negated_usage(name: &ast::NameLike) -> Option<(ast::PrefixExpr, ast::Expr)> {
412    let prefix_expr = name.syntax().ancestors().find_map(ast::PrefixExpr::cast)?;
413
414    if !matches!(prefix_expr.expr()?, ast::Expr::PathExpr(_) | ast::Expr::FieldExpr(_)) {
415        cov_mark::hit!(dont_overwrite_expression_inside_negation);
416        return None;
417    }
418
419    if let Some(ast::UnaryOp::Not) = prefix_expr.op_kind() {
420        let inner_expr = prefix_expr.expr()?;
421        Some((prefix_expr, inner_expr))
422    } else {
423        None
424    }
425}
426
427fn find_record_expr_usage(
428    name: &ast::NameLike,
429    got_field: hir::Field,
430    target_definition: Definition,
431) -> Option<(ast::RecordExprField, ast::Expr)> {
432    let name_ref = name.as_name_ref()?;
433    let record_field = ast::RecordExprField::for_field_name(name_ref)?;
434    let initializer = record_field.expr()?;
435
436    match target_definition {
437        Definition::Field(expected_field) if got_field == expected_field => {
438            Some((record_field, initializer))
439        }
440        _ => None,
441    }
442}
443
444fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
445    let record_pat_field = name.syntax().parent().and_then(ast::RecordPatField::cast)?;
446    let pat = record_pat_field.pat()?;
447
448    match pat {
449        ast::Pat::IdentPat(_) | ast::Pat::LiteralPat(_) | ast::Pat::WildcardPat(_) => Some(pat),
450        _ => None,
451    }
452}
453
454fn find_assoc_const_usage(name: &ast::NameLike) -> Option<(ast::Type, ast::Expr)> {
455    let const_ = name.syntax().parent().and_then(ast::Const::cast)?;
456    const_.syntax().parent().and_then(ast::AssocItemList::cast)?;
457
458    Some((const_.ty()?, const_.body()?))
459}
460
461fn find_method_call_expr_usage(name: &ast::NameLike) -> Option<ast::Expr> {
462    let method_call = name.syntax().ancestors().find_map(ast::MethodCallExpr::cast)?;
463    let receiver = method_call.receiver()?;
464
465    if !receiver.syntax().descendants().contains(name.syntax()) {
466        return None;
467    }
468
469    Some(receiver)
470}
471
472/// Adds the definition of the new enum before the target node.
473fn add_enum_def(
474    edit: &mut SourceChangeBuilder,
475    ctx: &AssistContext<'_, '_>,
476    usages: &UsageSearchResult,
477    target_node: SyntaxNode,
478    target_module: &hir::Module,
479    make: &SyntaxFactory,
480) -> Option<()> {
481    let insert_before = node_to_insert_before(target_node);
482
483    if ctx
484        .sema
485        .scope(&insert_before)?
486        .module()
487        .scope(ctx.db(), Some(*target_module))
488        .iter()
489        .any(|(name, _)| name.as_str() == "Bool")
490    {
491        return None;
492    }
493
494    let make_enum_pub = usages
495        .iter()
496        .flat_map(|(_, refs)| refs)
497        .filter_map(|FileReference { name, .. }| {
498            let name = name.clone().into_name_like()?;
499            ctx.sema.scope(name.syntax()).map(|scope| scope.module())
500        })
501        .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module);
502
503    let indent = IndentLevel::from_node(&insert_before);
504    let enum_def = make_bool_enum(make_enum_pub, make).reset_indent().indent(indent);
505
506    edit.insert(
507        insert_before.text_range().start(),
508        format!("{}\n\n{indent}", enum_def.syntax().text()),
509    );
510
511    Some(())
512}
513
514/// Finds where to put the new enum definition.
515/// Tries to find the ast node at the nearest module or at top-level, otherwise just
516/// returns the input node.
517fn node_to_insert_before(target_node: SyntaxNode) -> SyntaxNode {
518    target_node
519        .ancestors()
520        .take_while(|it| !matches!(it.kind(), SyntaxKind::MODULE | SyntaxKind::SOURCE_FILE))
521        .filter(|it| ast::Item::can_cast(it.kind()))
522        .last()
523        .unwrap_or(target_node)
524}
525
526fn make_bool_enum(make_pub: bool, make: &SyntaxFactory) -> ast::Enum {
527    let derive_eq = make.attr_outer(make.meta_token_tree(
528        make.ident_path("derive"),
529        make.token_tree(
530            T!['('],
531            vec![
532                NodeOrToken::Token(make.ident("PartialEq")),
533                NodeOrToken::Token(make.token(T![,])),
534                NodeOrToken::Token(make.whitespace(" ")),
535                NodeOrToken::Token(make.ident("Eq")),
536            ],
537        ),
538    ));
539    make.item_enum(
540        [derive_eq],
541        if make_pub { Some(make.visibility_pub()) } else { None },
542        make.name("Bool"),
543        None,
544        None,
545        make.variant_list(vec![
546            make.variant(None, make.name("True"), None, None),
547            make.variant(None, make.name("False"), None, None),
548        ]),
549    )
550}
551
552#[cfg(test)]
553mod tests {
554    use super::*;
555
556    use crate::tests::{check_assist, check_assist_not_applicable};
557
558    #[test]
559    fn parameter_with_first_param_usage() {
560        check_assist(
561            convert_bool_to_enum,
562            r#"
563fn function($0foo: bool, bar: bool) {
564    if foo {
565        println!("foo");
566    }
567}
568"#,
569            r#"
570#[derive(PartialEq, Eq)]
571enum Bool { True, False }
572
573fn function(foo: Bool, bar: bool) {
574    if foo == Bool::True {
575        println!("foo");
576    }
577}
578"#,
579        )
580    }
581
582    #[test]
583    fn no_duplicate_enums() {
584        check_assist(
585            convert_bool_to_enum,
586            r#"
587#[derive(PartialEq, Eq)]
588enum Bool { True, False }
589
590fn function(foo: bool, $0bar: bool) {
591    if bar {
592        println!("bar");
593    }
594}
595"#,
596            r#"
597#[derive(PartialEq, Eq)]
598enum Bool { True, False }
599
600fn function(foo: bool, bar: Bool) {
601    if bar == Bool::True {
602        println!("bar");
603    }
604}
605"#,
606        )
607    }
608
609    #[test]
610    fn parameter_with_last_param_usage() {
611        check_assist(
612            convert_bool_to_enum,
613            r#"
614fn function(foo: bool, $0bar: bool) {
615    if bar {
616        println!("bar");
617    }
618}
619"#,
620            r#"
621#[derive(PartialEq, Eq)]
622enum Bool { True, False }
623
624fn function(foo: bool, bar: Bool) {
625    if bar == Bool::True {
626        println!("bar");
627    }
628}
629"#,
630        )
631    }
632
633    #[test]
634    fn parameter_with_middle_param_usage() {
635        check_assist(
636            convert_bool_to_enum,
637            r#"
638fn function(foo: bool, $0bar: bool, baz: bool) {
639    if bar {
640        println!("bar");
641    }
642}
643"#,
644            r#"
645#[derive(PartialEq, Eq)]
646enum Bool { True, False }
647
648fn function(foo: bool, bar: Bool, baz: bool) {
649    if bar == Bool::True {
650        println!("bar");
651    }
652}
653"#,
654        )
655    }
656
657    #[test]
658    fn parameter_with_closure_usage() {
659        check_assist(
660            convert_bool_to_enum,
661            r#"
662fn main() {
663    let foo = |$0bar: bool| bar;
664}
665"#,
666            r#"
667#[derive(PartialEq, Eq)]
668enum Bool { True, False }
669
670fn main() {
671    let foo = |bar: Bool| bar == Bool::True;
672}
673"#,
674        )
675    }
676
677    #[test]
678    fn local_variable_with_usage() {
679        check_assist(
680            convert_bool_to_enum,
681            r#"
682fn main() {
683    let $0foo = true;
684
685    if foo {
686        println!("foo");
687    }
688}
689"#,
690            r#"
691#[derive(PartialEq, Eq)]
692enum Bool { True, False }
693
694fn main() {
695    let foo = Bool::True;
696
697    if foo == Bool::True {
698        println!("foo");
699    }
700}
701"#,
702        )
703    }
704
705    #[test]
706    fn local_variable_with_usage_negated() {
707        cov_mark::check!(replaces_negation);
708        check_assist(
709            convert_bool_to_enum,
710            r#"
711fn main() {
712    let $0foo = true;
713
714    if !foo {
715        println!("foo");
716    }
717}
718"#,
719            r#"
720#[derive(PartialEq, Eq)]
721enum Bool { True, False }
722
723fn main() {
724    let foo = Bool::True;
725
726    if foo == Bool::False {
727        println!("foo");
728    }
729}
730"#,
731        )
732    }
733
734    #[test]
735    fn local_variable_with_type_annotation() {
736        cov_mark::check!(replaces_ty_annotation);
737        check_assist(
738            convert_bool_to_enum,
739            r#"
740fn main() {
741    let $0foo: bool = false;
742}
743"#,
744            r#"
745#[derive(PartialEq, Eq)]
746enum Bool { True, False }
747
748fn main() {
749    let foo: Bool = Bool::False;
750}
751"#,
752        )
753    }
754
755    #[test]
756    fn local_variable_with_non_literal_initializer() {
757        check_assist(
758            convert_bool_to_enum,
759            r#"
760fn main() {
761    let $0foo = 1 == 2;
762}
763"#,
764            r#"
765#[derive(PartialEq, Eq)]
766enum Bool { True, False }
767
768fn main() {
769    let foo = if 1 == 2 { Bool::True } else { Bool::False };
770}
771"#,
772        )
773    }
774
775    #[test]
776    fn local_variable_binexpr_usage() {
777        check_assist(
778            convert_bool_to_enum,
779            r#"
780fn main() {
781    let $0foo = false;
782    let bar = true;
783
784    if !foo && bar {
785        println!("foobar");
786    }
787}
788"#,
789            r#"
790#[derive(PartialEq, Eq)]
791enum Bool { True, False }
792
793fn main() {
794    let foo = Bool::False;
795    let bar = true;
796
797    if foo == Bool::False && bar {
798        println!("foobar");
799    }
800}
801"#,
802        )
803    }
804
805    #[test]
806    fn local_variable_unop_usage() {
807        check_assist(
808            convert_bool_to_enum,
809            r#"
810fn main() {
811    let $0foo = true;
812
813    if *&foo {
814        println!("foobar");
815    }
816}
817"#,
818            r#"
819#[derive(PartialEq, Eq)]
820enum Bool { True, False }
821
822fn main() {
823    let foo = Bool::True;
824
825    if *&foo == Bool::True {
826        println!("foobar");
827    }
828}
829"#,
830        )
831    }
832
833    #[test]
834    fn local_variable_assigned_later() {
835        cov_mark::check!(replaces_assignment);
836        check_assist(
837            convert_bool_to_enum,
838            r#"
839fn main() {
840    let $0foo: bool;
841    foo = true;
842}
843"#,
844            r#"
845#[derive(PartialEq, Eq)]
846enum Bool { True, False }
847
848fn main() {
849    let foo: Bool;
850    foo = Bool::True;
851}
852"#,
853        )
854    }
855
856    #[test]
857    fn local_variable_does_not_apply_recursively() {
858        check_assist(
859            convert_bool_to_enum,
860            r#"
861fn main() {
862    let $0foo = true;
863    let bar = !foo;
864
865    if bar {
866        println!("bar");
867    }
868}
869"#,
870            r#"
871#[derive(PartialEq, Eq)]
872enum Bool { True, False }
873
874fn main() {
875    let foo = Bool::True;
876    let bar = foo == Bool::False;
877
878    if bar {
879        println!("bar");
880    }
881}
882"#,
883        )
884    }
885
886    #[test]
887    fn local_variable_nested_in_negation() {
888        cov_mark::check!(dont_overwrite_expression_inside_negation);
889        check_assist(
890            convert_bool_to_enum,
891            r#"
892fn main() {
893    if !"foo".chars().any(|c| {
894        let $0foo = true;
895        foo
896    }) {
897        println!("foo");
898    }
899}
900"#,
901            r#"
902#[derive(PartialEq, Eq)]
903enum Bool { True, False }
904
905fn main() {
906    if !"foo".chars().any(|c| {
907        let foo = Bool::True;
908        foo == Bool::True
909    }) {
910        println!("foo");
911    }
912}
913"#,
914        )
915    }
916
917    #[test]
918    fn local_variable_non_bool() {
919        cov_mark::check!(not_applicable_non_bool_local);
920        check_assist_not_applicable(
921            convert_bool_to_enum,
922            r#"
923fn main() {
924    let $0foo = 1;
925}
926"#,
927        )
928    }
929
930    #[test]
931    fn local_variable_cursor_not_on_ident() {
932        check_assist_not_applicable(
933            convert_bool_to_enum,
934            r#"
935fn main() {
936    let foo = $0true;
937}
938"#,
939        )
940    }
941
942    #[test]
943    fn local_variable_non_ident_pat() {
944        check_assist_not_applicable(
945            convert_bool_to_enum,
946            r#"
947fn main() {
948    let ($0foo, bar) = (true, false);
949}
950"#,
951        )
952    }
953
954    #[test]
955    fn local_var_init_struct_usage() {
956        check_assist(
957            convert_bool_to_enum,
958            r#"
959struct Foo {
960    foo: bool,
961}
962
963fn main() {
964    let $0foo = true;
965    let s = Foo { foo };
966}
967"#,
968            r#"
969struct Foo {
970    foo: bool,
971}
972
973#[derive(PartialEq, Eq)]
974enum Bool { True, False }
975
976fn main() {
977    let foo = Bool::True;
978    let s = Foo { foo: foo == Bool::True };
979}
980"#,
981        )
982    }
983
984    #[test]
985    fn local_var_init_struct_usage_in_macro() {
986        check_assist(
987            convert_bool_to_enum,
988            r#"
989struct Struct {
990    boolean: bool,
991}
992
993macro_rules! identity {
994    ($body:expr) => {
995        $body
996    }
997}
998
999fn new() -> Struct {
1000    let $0boolean = true;
1001    identity![Struct { boolean }]
1002}
1003"#,
1004            r#"
1005struct Struct {
1006    boolean: bool,
1007}
1008
1009macro_rules! identity {
1010    ($body:expr) => {
1011        $body
1012    }
1013}
1014
1015#[derive(PartialEq, Eq)]
1016enum Bool { True, False }
1017
1018fn new() -> Struct {
1019    let boolean = Bool::True;
1020    identity![Struct { boolean: boolean == Bool::True }]
1021}
1022"#,
1023        )
1024    }
1025
1026    #[test]
1027    fn field_struct_basic() {
1028        cov_mark::check!(replaces_record_expr);
1029        check_assist(
1030            convert_bool_to_enum,
1031            r#"
1032struct Foo {
1033    $0bar: bool,
1034    baz: bool,
1035}
1036
1037fn main() {
1038    let foo = Foo { bar: true, baz: false };
1039
1040    if foo.bar {
1041        println!("foo");
1042    }
1043}
1044"#,
1045            r#"
1046#[derive(PartialEq, Eq)]
1047enum Bool { True, False }
1048
1049struct Foo {
1050    bar: Bool,
1051    baz: bool,
1052}
1053
1054fn main() {
1055    let foo = Foo { bar: Bool::True, baz: false };
1056
1057    if foo.bar == Bool::True {
1058        println!("foo");
1059    }
1060}
1061"#,
1062        )
1063    }
1064
1065    #[test]
1066    fn field_enum_basic() {
1067        cov_mark::check!(replaces_record_pat);
1068        check_assist(
1069            convert_bool_to_enum,
1070            r#"
1071enum Foo {
1072    Foo,
1073    Bar { $0bar: bool },
1074}
1075
1076fn main() {
1077    let foo = Foo::Bar { bar: true };
1078
1079    if let Foo::Bar { bar: baz } = foo {
1080        if baz {
1081            println!("foo");
1082        }
1083    }
1084}
1085"#,
1086            r#"
1087#[derive(PartialEq, Eq)]
1088enum Bool { True, False }
1089
1090enum Foo {
1091    Foo,
1092    Bar { bar: Bool },
1093}
1094
1095fn main() {
1096    let foo = Foo::Bar { bar: Bool::True };
1097
1098    if let Foo::Bar { bar: baz } = foo {
1099        if baz == Bool::True {
1100            println!("foo");
1101        }
1102    }
1103}
1104"#,
1105        )
1106    }
1107
1108    #[test]
1109    fn field_enum_cross_file() {
1110        // FIXME: The import is missing
1111        check_assist(
1112            convert_bool_to_enum,
1113            r#"
1114//- /foo.rs
1115pub enum Foo {
1116    Foo,
1117    Bar { $0bar: bool },
1118}
1119
1120fn foo() {
1121    let foo = Foo::Bar { bar: true };
1122}
1123
1124//- /main.rs
1125use foo::Foo;
1126
1127mod foo;
1128
1129fn main() {
1130    let foo = Foo::Bar { bar: false };
1131}
1132"#,
1133            r#"
1134//- /foo.rs
1135#[derive(PartialEq, Eq)]
1136pub enum Bool { True, False }
1137
1138pub enum Foo {
1139    Foo,
1140    Bar { bar: Bool },
1141}
1142
1143fn foo() {
1144    let foo = Foo::Bar { bar: Bool::True };
1145}
1146
1147//- /main.rs
1148use foo::{Bool, Foo};
1149
1150mod foo;
1151
1152fn main() {
1153    let foo = Foo::Bar { bar: Bool::False };
1154}
1155"#,
1156        )
1157    }
1158
1159    #[test]
1160    fn field_enum_shorthand() {
1161        cov_mark::check!(replaces_record_pat_shorthand);
1162        check_assist(
1163            convert_bool_to_enum,
1164            r#"
1165enum Foo {
1166    Foo,
1167    Bar { $0bar: bool },
1168}
1169
1170fn main() {
1171    let foo = Foo::Bar { bar: true };
1172
1173    match foo {
1174        Foo::Bar { bar } => {
1175            if bar {
1176                println!("foo");
1177            }
1178        }
1179        _ => (),
1180    }
1181}
1182"#,
1183            r#"
1184#[derive(PartialEq, Eq)]
1185enum Bool { True, False }
1186
1187enum Foo {
1188    Foo,
1189    Bar { bar: Bool },
1190}
1191
1192fn main() {
1193    let foo = Foo::Bar { bar: Bool::True };
1194
1195    match foo {
1196        Foo::Bar { bar } => {
1197            if bar == Bool::True {
1198                println!("foo");
1199            }
1200        }
1201        _ => (),
1202    }
1203}
1204"#,
1205        )
1206    }
1207
1208    #[test]
1209    fn field_enum_replaces_literal_patterns() {
1210        cov_mark::check!(replaces_literal_pat);
1211        check_assist(
1212            convert_bool_to_enum,
1213            r#"
1214enum Foo {
1215    Foo,
1216    Bar { $0bar: bool },
1217}
1218
1219fn main() {
1220    let foo = Foo::Bar { bar: true };
1221
1222    if let Foo::Bar { bar: true } = foo {
1223        println!("foo");
1224    }
1225}
1226"#,
1227            r#"
1228#[derive(PartialEq, Eq)]
1229enum Bool { True, False }
1230
1231enum Foo {
1232    Foo,
1233    Bar { bar: Bool },
1234}
1235
1236fn main() {
1237    let foo = Foo::Bar { bar: Bool::True };
1238
1239    if let Foo::Bar { bar: Bool::True } = foo {
1240        println!("foo");
1241    }
1242}
1243"#,
1244        )
1245    }
1246
1247    #[test]
1248    fn field_enum_keeps_wildcard_patterns() {
1249        check_assist(
1250            convert_bool_to_enum,
1251            r#"
1252enum Foo {
1253    Foo,
1254    Bar { $0bar: bool },
1255}
1256
1257fn main() {
1258    let foo = Foo::Bar { bar: true };
1259
1260    if let Foo::Bar { bar: _ } = foo {
1261        println!("foo");
1262    }
1263}
1264"#,
1265            r#"
1266#[derive(PartialEq, Eq)]
1267enum Bool { True, False }
1268
1269enum Foo {
1270    Foo,
1271    Bar { bar: Bool },
1272}
1273
1274fn main() {
1275    let foo = Foo::Bar { bar: Bool::True };
1276
1277    if let Foo::Bar { bar: _ } = foo {
1278        println!("foo");
1279    }
1280}
1281"#,
1282        )
1283    }
1284
1285    #[test]
1286    fn field_union_basic() {
1287        check_assist(
1288            convert_bool_to_enum,
1289            r#"
1290union Foo {
1291    $0foo: bool,
1292    bar: usize,
1293}
1294
1295fn main() {
1296    let foo = Foo { foo: true };
1297
1298    if unsafe { foo.foo } {
1299        println!("foo");
1300    }
1301}
1302"#,
1303            r#"
1304#[derive(PartialEq, Eq)]
1305enum Bool { True, False }
1306
1307union Foo {
1308    foo: Bool,
1309    bar: usize,
1310}
1311
1312fn main() {
1313    let foo = Foo { foo: Bool::True };
1314
1315    if unsafe { foo.foo == Bool::True } {
1316        println!("foo");
1317    }
1318}
1319"#,
1320        )
1321    }
1322
1323    #[test]
1324    fn field_negated() {
1325        check_assist(
1326            convert_bool_to_enum,
1327            r#"
1328struct Foo {
1329    $0bar: bool,
1330}
1331
1332fn main() {
1333    let foo = Foo { bar: false };
1334
1335    if !foo.bar {
1336        println!("foo");
1337    }
1338}
1339"#,
1340            r#"
1341#[derive(PartialEq, Eq)]
1342enum Bool { True, False }
1343
1344struct Foo {
1345    bar: Bool,
1346}
1347
1348fn main() {
1349    let foo = Foo { bar: Bool::False };
1350
1351    if foo.bar == Bool::False {
1352        println!("foo");
1353    }
1354}
1355"#,
1356        )
1357    }
1358
1359    #[test]
1360    fn field_in_mod_properly_indented() {
1361        check_assist(
1362            convert_bool_to_enum,
1363            r#"
1364mod foo {
1365    struct Bar {
1366        $0baz: bool,
1367    }
1368
1369    impl Bar {
1370        fn new(baz: bool) -> Self {
1371            Self { baz }
1372        }
1373    }
1374}
1375"#,
1376            r#"
1377mod foo {
1378    #[derive(PartialEq, Eq)]
1379    enum Bool { True, False }
1380
1381    struct Bar {
1382        baz: Bool,
1383    }
1384
1385    impl Bar {
1386        fn new(baz: bool) -> Self {
1387            Self { baz: if baz { Bool::True } else { Bool::False } }
1388        }
1389    }
1390}
1391"#,
1392        )
1393    }
1394
1395    #[test]
1396    fn field_multiple_initializations() {
1397        check_assist(
1398            convert_bool_to_enum,
1399            r#"
1400struct Foo {
1401    $0bar: bool,
1402    baz: bool,
1403}
1404
1405fn main() {
1406    let foo1 = Foo { bar: true, baz: false };
1407    let foo2 = Foo { bar: false, baz: false };
1408
1409    if foo1.bar && foo2.bar {
1410        println!("foo");
1411    }
1412}
1413"#,
1414            r#"
1415#[derive(PartialEq, Eq)]
1416enum Bool { True, False }
1417
1418struct Foo {
1419    bar: Bool,
1420    baz: bool,
1421}
1422
1423fn main() {
1424    let foo1 = Foo { bar: Bool::True, baz: false };
1425    let foo2 = Foo { bar: Bool::False, baz: false };
1426
1427    if foo1.bar == Bool::True && foo2.bar == Bool::True {
1428        println!("foo");
1429    }
1430}
1431"#,
1432        )
1433    }
1434
1435    #[test]
1436    fn field_assigned_to_another() {
1437        cov_mark::check!(dont_assign_incorrect_ref);
1438        check_assist(
1439            convert_bool_to_enum,
1440            r#"
1441struct Foo {
1442    $0foo: bool,
1443}
1444
1445struct Bar {
1446    bar: bool,
1447}
1448
1449fn main() {
1450    let foo = Foo { foo: true };
1451    let mut bar = Bar { bar: true };
1452
1453    bar.bar = foo.foo;
1454}
1455"#,
1456            r#"
1457#[derive(PartialEq, Eq)]
1458enum Bool { True, False }
1459
1460struct Foo {
1461    foo: Bool,
1462}
1463
1464struct Bar {
1465    bar: bool,
1466}
1467
1468fn main() {
1469    let foo = Foo { foo: Bool::True };
1470    let mut bar = Bar { bar: true };
1471
1472    bar.bar = foo.foo == Bool::True;
1473}
1474"#,
1475        )
1476    }
1477
1478    #[test]
1479    fn field_initialized_with_other() {
1480        check_assist(
1481            convert_bool_to_enum,
1482            r#"
1483struct Foo {
1484    $0foo: bool,
1485}
1486
1487struct Bar {
1488    bar: bool,
1489}
1490
1491fn main() {
1492    let foo = Foo { foo: true };
1493    let bar = Bar { bar: foo.foo };
1494}
1495"#,
1496            r#"
1497#[derive(PartialEq, Eq)]
1498enum Bool { True, False }
1499
1500struct Foo {
1501    foo: Bool,
1502}
1503
1504struct Bar {
1505    bar: bool,
1506}
1507
1508fn main() {
1509    let foo = Foo { foo: Bool::True };
1510    let bar = Bar { bar: foo.foo == Bool::True };
1511}
1512"#,
1513        )
1514    }
1515
1516    #[test]
1517    fn field_method_chain_usage() {
1518        check_assist(
1519            convert_bool_to_enum,
1520            r#"
1521struct Foo {
1522    $0bool: bool,
1523}
1524
1525fn main() {
1526    let foo = Foo { bool: true };
1527
1528    foo.bool.then(|| 2);
1529}
1530"#,
1531            r#"
1532#[derive(PartialEq, Eq)]
1533enum Bool { True, False }
1534
1535struct Foo {
1536    bool: Bool,
1537}
1538
1539fn main() {
1540    let foo = Foo { bool: Bool::True };
1541
1542    (foo.bool == Bool::True).then(|| 2);
1543}
1544"#,
1545        )
1546    }
1547
1548    #[test]
1549    fn field_in_macro() {
1550        check_assist(
1551            convert_bool_to_enum,
1552            r#"
1553struct Struct {
1554    $0boolean: bool,
1555}
1556
1557fn boolean(x: Struct) {
1558    let Struct { boolean } = x;
1559}
1560
1561macro_rules! identity { ($body:expr) => { $body } }
1562
1563fn new() -> Struct {
1564    identity!(Struct { boolean: true })
1565}
1566"#,
1567            r#"
1568#[derive(PartialEq, Eq)]
1569enum Bool { True, False }
1570
1571struct Struct {
1572    boolean: Bool,
1573}
1574
1575fn boolean(x: Struct) {
1576    let Struct { boolean } = x;
1577}
1578
1579macro_rules! identity { ($body:expr) => { $body } }
1580
1581fn new() -> Struct {
1582    identity!(Struct { boolean: Bool::True })
1583}
1584"#,
1585        )
1586    }
1587
1588    #[test]
1589    fn field_non_bool() {
1590        cov_mark::check!(not_applicable_non_bool_field);
1591        check_assist_not_applicable(
1592            convert_bool_to_enum,
1593            r#"
1594struct Foo {
1595    $0bar: usize,
1596}
1597
1598fn main() {
1599    let foo = Foo { bar: 1 };
1600}
1601"#,
1602        )
1603    }
1604
1605    #[test]
1606    fn const_basic() {
1607        check_assist(
1608            convert_bool_to_enum,
1609            r#"
1610const $0FOO: bool = false;
1611
1612fn main() {
1613    if FOO {
1614        println!("foo");
1615    }
1616}
1617"#,
1618            r#"
1619#[derive(PartialEq, Eq)]
1620enum Bool { True, False }
1621
1622const FOO: Bool = Bool::False;
1623
1624fn main() {
1625    if FOO == Bool::True {
1626        println!("foo");
1627    }
1628}
1629"#,
1630        )
1631    }
1632
1633    #[test]
1634    fn const_in_module() {
1635        check_assist(
1636            convert_bool_to_enum,
1637            r#"
1638fn main() {
1639    if foo::FOO {
1640        println!("foo");
1641    }
1642}
1643
1644mod foo {
1645    pub const $0FOO: bool = true;
1646}
1647"#,
1648            r#"
1649use foo::Bool;
1650
1651fn main() {
1652    if foo::FOO == Bool::True {
1653        println!("foo");
1654    }
1655}
1656
1657mod foo {
1658    #[derive(PartialEq, Eq)]
1659    pub enum Bool { True, False }
1660
1661    pub const FOO: Bool = Bool::True;
1662}
1663"#,
1664        )
1665    }
1666
1667    #[test]
1668    fn const_in_module_with_import() {
1669        check_assist(
1670            convert_bool_to_enum,
1671            r#"
1672fn main() {
1673    use foo::FOO;
1674
1675    if FOO {
1676        println!("foo");
1677    }
1678}
1679
1680mod foo {
1681    pub const $0FOO: bool = true;
1682}
1683"#,
1684            r#"
1685use foo::Bool;
1686
1687fn main() {
1688    use foo::FOO;
1689
1690    if FOO == Bool::True {
1691        println!("foo");
1692    }
1693}
1694
1695mod foo {
1696    #[derive(PartialEq, Eq)]
1697    pub enum Bool { True, False }
1698
1699    pub const FOO: Bool = Bool::True;
1700}
1701"#,
1702        )
1703    }
1704
1705    #[test]
1706    fn const_cross_file() {
1707        check_assist(
1708            convert_bool_to_enum,
1709            r#"
1710//- /main.rs
1711mod foo;
1712
1713fn main() {
1714    if foo::FOO {
1715        println!("foo");
1716    }
1717}
1718
1719//- /foo.rs
1720pub const $0FOO: bool = true;
1721"#,
1722            r#"
1723//- /main.rs
1724use foo::Bool;
1725
1726mod foo;
1727
1728fn main() {
1729    if foo::FOO == Bool::True {
1730        println!("foo");
1731    }
1732}
1733
1734//- /foo.rs
1735#[derive(PartialEq, Eq)]
1736pub enum Bool { True, False }
1737
1738pub const FOO: Bool = Bool::True;
1739"#,
1740        )
1741    }
1742
1743    #[test]
1744    fn const_cross_file_and_module() {
1745        check_assist(
1746            convert_bool_to_enum,
1747            r#"
1748//- /main.rs
1749mod foo;
1750
1751fn main() {
1752    use foo::bar;
1753
1754    if bar::BAR {
1755        println!("foo");
1756    }
1757}
1758
1759//- /foo.rs
1760pub mod bar {
1761    pub const $0BAR: bool = false;
1762}
1763"#,
1764            r#"
1765//- /main.rs
1766use foo::bar::Bool;
1767
1768mod foo;
1769
1770fn main() {
1771    use foo::bar;
1772
1773    if bar::BAR == Bool::True {
1774        println!("foo");
1775    }
1776}
1777
1778//- /foo.rs
1779pub mod bar {
1780    #[derive(PartialEq, Eq)]
1781    pub enum Bool { True, False }
1782
1783    pub const BAR: Bool = Bool::False;
1784}
1785"#,
1786        )
1787    }
1788
1789    #[test]
1790    fn const_in_impl_cross_file() {
1791        check_assist(
1792            convert_bool_to_enum,
1793            r#"
1794//- /main.rs
1795mod foo;
1796
1797struct Foo;
1798
1799impl Foo {
1800    pub const $0BOOL: bool = true;
1801}
1802
1803//- /foo.rs
1804use crate::Foo;
1805
1806fn foo() -> bool {
1807    Foo::BOOL
1808}
1809"#,
1810            r#"
1811//- /main.rs
1812mod foo;
1813
1814struct Foo;
1815
1816#[derive(PartialEq, Eq)]
1817pub enum Bool { True, False }
1818
1819impl Foo {
1820    pub const BOOL: Bool = Bool::True;
1821}
1822
1823//- /foo.rs
1824use crate::{Bool, Foo};
1825
1826fn foo() -> bool {
1827    Foo::BOOL == Bool::True
1828}
1829"#,
1830        )
1831    }
1832
1833    #[test]
1834    fn const_in_trait() {
1835        check_assist(
1836            convert_bool_to_enum,
1837            r#"
1838trait Foo {
1839    const $0BOOL: bool;
1840}
1841
1842impl Foo for usize {
1843    const BOOL: bool = true;
1844}
1845
1846fn main() {
1847    if <usize as Foo>::BOOL {
1848        println!("foo");
1849    }
1850}
1851"#,
1852            r#"
1853#[derive(PartialEq, Eq)]
1854enum Bool { True, False }
1855
1856trait Foo {
1857    const BOOL: Bool;
1858}
1859
1860impl Foo for usize {
1861    const BOOL: Bool = Bool::True;
1862}
1863
1864fn main() {
1865    if <usize as Foo>::BOOL == Bool::True {
1866        println!("foo");
1867    }
1868}
1869"#,
1870        )
1871    }
1872
1873    #[test]
1874    fn const_non_bool() {
1875        cov_mark::check!(not_applicable_non_bool_const);
1876        check_assist_not_applicable(
1877            convert_bool_to_enum,
1878            r#"
1879const $0FOO: &str = "foo";
1880
1881fn main() {
1882    println!("{FOO}");
1883}
1884"#,
1885        )
1886    }
1887
1888    #[test]
1889    fn static_basic() {
1890        check_assist(
1891            convert_bool_to_enum,
1892            r#"
1893static mut $0BOOL: bool = true;
1894
1895fn main() {
1896    unsafe { BOOL = false };
1897    if unsafe { BOOL } {
1898        println!("foo");
1899    }
1900}
1901"#,
1902            r#"
1903#[derive(PartialEq, Eq)]
1904enum Bool { True, False }
1905
1906static mut BOOL: Bool = Bool::True;
1907
1908fn main() {
1909    unsafe { BOOL = Bool::False };
1910    if unsafe { BOOL == Bool::True } {
1911        println!("foo");
1912    }
1913}
1914"#,
1915        )
1916    }
1917
1918    #[test]
1919    fn static_non_bool() {
1920        cov_mark::check!(not_applicable_non_bool_static);
1921        check_assist_not_applicable(
1922            convert_bool_to_enum,
1923            r#"
1924static mut $0FOO: usize = 0;
1925
1926fn main() {
1927    if unsafe { FOO } == 0 {
1928        println!("foo");
1929    }
1930}
1931"#,
1932        )
1933    }
1934
1935    #[test]
1936    fn not_applicable_to_other_names() {
1937        check_assist_not_applicable(convert_bool_to_enum, "fn $0main() {}")
1938    }
1939}