ide_assists/handlers/
convert_bool_to_enum.rs

1use either::Either;
2use hir::ModuleDef;
3use ide_db::text_edit::TextRange;
4use ide_db::{
5    FxHashSet,
6    assists::AssistId,
7    defs::Definition,
8    helpers::mod_path_to_ast,
9    imports::insert_use::{ImportScope, insert_use},
10    search::{FileReference, UsageSearchResult},
11    source_change::SourceChangeBuilder,
12};
13use itertools::Itertools;
14use syntax::{
15    AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T,
16    ast::{self, HasName, edit::IndentLevel, edit_in_place::Indent, make},
17};
18
19use crate::{
20    assist_context::{AssistContext, Assists},
21    utils,
22};
23
24// Assist: convert_bool_to_enum
25//
26// This converts boolean local variables, fields, constants, and statics into a new
27// enum with two variants `Bool::True` and `Bool::False`, as well as replacing
28// all assignments with the variants and replacing all usages with `== Bool::True` or
29// `== Bool::False`.
30//
31// ```
32// fn main() {
33//     let $0bool = true;
34//
35//     if bool {
36//         println!("foo");
37//     }
38// }
39// ```
40// ->
41// ```
42// #[derive(PartialEq, Eq)]
43// enum Bool { True, False }
44//
45// fn main() {
46//     let bool = Bool::True;
47//
48//     if bool == Bool::True {
49//         println!("foo");
50//     }
51// }
52// ```
53pub(crate) fn convert_bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
54    let BoolNodeData { target_node, name, ty_annotation, initializer, definition } =
55        find_bool_node(ctx)?;
56    let target_module = ctx.sema.scope(&target_node)?.module().nearest_non_block_module(ctx.db());
57
58    let target = name.syntax().text_range();
59    acc.add(
60        AssistId::refactor_rewrite("convert_bool_to_enum"),
61        "Convert boolean to enum",
62        target,
63        |edit| {
64            if let Some(ty) = &ty_annotation {
65                cov_mark::hit!(replaces_ty_annotation);
66                edit.replace(ty.syntax().text_range(), "Bool");
67            }
68
69            if let Some(initializer) = initializer {
70                replace_bool_expr(edit, initializer);
71            }
72
73            let usages = definition.usages(&ctx.sema).all();
74            add_enum_def(edit, ctx, &usages, target_node, &target_module);
75            let mut delayed_mutations = Vec::new();
76            replace_usages(edit, ctx, usages, definition, &target_module, &mut delayed_mutations);
77            for (scope, path) in delayed_mutations {
78                insert_use(&scope, path, &ctx.config.insert_use);
79            }
80        },
81    )
82}
83
84struct BoolNodeData {
85    target_node: SyntaxNode,
86    name: ast::Name,
87    ty_annotation: Option<ast::Type>,
88    initializer: Option<ast::Expr>,
89    definition: Definition,
90}
91
92/// Attempts to find an appropriate node to apply the action to.
93fn find_bool_node(ctx: &AssistContext<'_>) -> Option<BoolNodeData> {
94    let name = ctx.find_node_at_offset::<ast::Name>()?;
95
96    if let Some(ident_pat) = name.syntax().parent().and_then(ast::IdentPat::cast) {
97        let def = ctx.sema.to_def(&ident_pat)?;
98        if !def.ty(ctx.db()).is_bool() {
99            cov_mark::hit!(not_applicable_non_bool_local);
100            return None;
101        }
102
103        let local_definition = Definition::Local(def);
104        match ident_pat.syntax().parent().and_then(Either::<ast::Param, ast::LetStmt>::cast)? {
105            Either::Left(param) => Some(BoolNodeData {
106                target_node: param.syntax().clone(),
107                name,
108                ty_annotation: param.ty(),
109                initializer: None,
110                definition: local_definition,
111            }),
112            Either::Right(let_stmt) => Some(BoolNodeData {
113                target_node: let_stmt.syntax().clone(),
114                name,
115                ty_annotation: let_stmt.ty(),
116                initializer: let_stmt.initializer(),
117                definition: local_definition,
118            }),
119        }
120    } else if let Some(const_) = name.syntax().parent().and_then(ast::Const::cast) {
121        let def = ctx.sema.to_def(&const_)?;
122        if !def.ty(ctx.db()).is_bool() {
123            cov_mark::hit!(not_applicable_non_bool_const);
124            return None;
125        }
126
127        Some(BoolNodeData {
128            target_node: const_.syntax().clone(),
129            name,
130            ty_annotation: const_.ty(),
131            initializer: const_.body(),
132            definition: Definition::Const(def),
133        })
134    } else if let Some(static_) = name.syntax().parent().and_then(ast::Static::cast) {
135        let def = ctx.sema.to_def(&static_)?;
136        if !def.ty(ctx.db()).is_bool() {
137            cov_mark::hit!(not_applicable_non_bool_static);
138            return None;
139        }
140
141        Some(BoolNodeData {
142            target_node: static_.syntax().clone(),
143            name,
144            ty_annotation: static_.ty(),
145            initializer: static_.body(),
146            definition: Definition::Static(def),
147        })
148    } else {
149        let field = name.syntax().parent().and_then(ast::RecordField::cast)?;
150        if field.name()? != name {
151            return None;
152        }
153
154        let adt = field.syntax().ancestors().find_map(ast::Adt::cast)?;
155        let def = ctx.sema.to_def(&field)?;
156        if !def.ty(ctx.db()).is_bool() {
157            cov_mark::hit!(not_applicable_non_bool_field);
158            return None;
159        }
160        Some(BoolNodeData {
161            target_node: adt.syntax().clone(),
162            name,
163            ty_annotation: field.ty(),
164            initializer: None,
165            definition: Definition::Field(def),
166        })
167    }
168}
169
170fn replace_bool_expr(edit: &mut SourceChangeBuilder, expr: ast::Expr) {
171    let expr_range = expr.syntax().text_range();
172    let enum_expr = bool_expr_to_enum_expr(expr);
173    edit.replace(expr_range, enum_expr.syntax().text())
174}
175
176/// Converts an expression of type `bool` to one of the new enum type.
177fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
178    let true_expr = make::expr_path(make::path_from_text("Bool::True"));
179    let false_expr = make::expr_path(make::path_from_text("Bool::False"));
180
181    if let ast::Expr::Literal(literal) = &expr {
182        match literal.kind() {
183            ast::LiteralKind::Bool(true) => true_expr,
184            ast::LiteralKind::Bool(false) => false_expr,
185            _ => expr,
186        }
187    } else {
188        make::expr_if(
189            expr,
190            make::tail_only_block_expr(true_expr),
191            Some(ast::ElseBranch::Block(make::tail_only_block_expr(false_expr))),
192        )
193        .into()
194    }
195}
196
197/// Replaces all usages of the target identifier, both when read and written to.
198fn replace_usages(
199    edit: &mut SourceChangeBuilder,
200    ctx: &AssistContext<'_>,
201    usages: UsageSearchResult,
202    target_definition: Definition,
203    target_module: &hir::Module,
204    delayed_mutations: &mut Vec<(ImportScope, ast::Path)>,
205) {
206    for (file_id, references) in usages {
207        edit.edit_file(file_id.file_id(ctx.db()));
208
209        let refs_with_imports = augment_references_with_imports(ctx, references, target_module);
210
211        refs_with_imports.into_iter().rev().for_each(
212            |FileReferenceWithImport { range, name, import_data }| {
213                // replace the usages in patterns and expressions
214                if let Some(ident_pat) = name.syntax().ancestors().find_map(ast::IdentPat::cast) {
215                    cov_mark::hit!(replaces_record_pat_shorthand);
216
217                    let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local);
218                    if let Some(def) = definition {
219                        replace_usages(
220                            edit,
221                            ctx,
222                            def.usages(&ctx.sema).all(),
223                            target_definition,
224                            target_module,
225                            delayed_mutations,
226                        )
227                    }
228                } else if let Some(initializer) = find_assignment_usage(&name) {
229                    cov_mark::hit!(replaces_assignment);
230
231                    replace_bool_expr(edit, initializer);
232                } else if let Some((prefix_expr, inner_expr)) = find_negated_usage(&name) {
233                    cov_mark::hit!(replaces_negation);
234
235                    edit.replace(
236                        prefix_expr.syntax().text_range(),
237                        format!("{inner_expr} == Bool::False"),
238                    );
239                } else if let Some((record_field, initializer)) = name
240                    .as_name_ref()
241                    .and_then(ast::RecordExprField::for_field_name)
242                    .and_then(|record_field| ctx.sema.resolve_record_field(&record_field))
243                    .and_then(|(got_field, _, _)| {
244                        find_record_expr_usage(&name, got_field, target_definition)
245                    })
246                {
247                    cov_mark::hit!(replaces_record_expr);
248
249                    let enum_expr = bool_expr_to_enum_expr(initializer);
250                    utils::replace_record_field_expr(ctx, edit, record_field, enum_expr);
251                } else if let Some(pat) = find_record_pat_field_usage(&name) {
252                    match pat {
253                        ast::Pat::IdentPat(ident_pat) => {
254                            cov_mark::hit!(replaces_record_pat);
255
256                            let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local);
257                            if let Some(def) = definition {
258                                replace_usages(
259                                    edit,
260                                    ctx,
261                                    def.usages(&ctx.sema).all(),
262                                    target_definition,
263                                    target_module,
264                                    delayed_mutations,
265                                )
266                            }
267                        }
268                        ast::Pat::LiteralPat(literal_pat) => {
269                            cov_mark::hit!(replaces_literal_pat);
270
271                            if let Some(expr) = literal_pat.literal().and_then(|literal| {
272                                literal.syntax().ancestors().find_map(ast::Expr::cast)
273                            }) {
274                                replace_bool_expr(edit, expr);
275                            }
276                        }
277                        _ => (),
278                    }
279                } else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&name) {
280                    edit.replace(ty_annotation.syntax().text_range(), "Bool");
281                    replace_bool_expr(edit, initializer);
282                } else if let Some(receiver) = find_method_call_expr_usage(&name) {
283                    edit.replace(
284                        receiver.syntax().text_range(),
285                        format!("({receiver} == Bool::True)"),
286                    );
287                } else if name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
288                    // for any other usage in an expression, replace it with a check that it is the true variant
289                    if let Some((record_field, expr)) =
290                        name.as_name_ref().and_then(ast::RecordExprField::for_field_name).and_then(
291                            |record_field| record_field.expr().map(|expr| (record_field, expr)),
292                        )
293                    {
294                        utils::replace_record_field_expr(
295                            ctx,
296                            edit,
297                            record_field,
298                            make::expr_bin_op(
299                                expr,
300                                ast::BinaryOp::CmpOp(ast::CmpOp::Eq { negated: false }),
301                                make::expr_path(make::path_from_text("Bool::True")),
302                            ),
303                        );
304                    } else {
305                        edit.replace(range, format!("{} == Bool::True", name.text()));
306                    }
307                }
308
309                // add imports across modules where needed
310                if let Some((scope, path)) = import_data {
311                    let scope = edit.make_import_scope_mut(scope);
312                    delayed_mutations.push((scope, path));
313                }
314            },
315        )
316    }
317}
318
319struct FileReferenceWithImport {
320    range: TextRange,
321    name: ast::NameLike,
322    import_data: Option<(ImportScope, ast::Path)>,
323}
324
325fn augment_references_with_imports(
326    ctx: &AssistContext<'_>,
327    references: Vec<FileReference>,
328    target_module: &hir::Module,
329) -> Vec<FileReferenceWithImport> {
330    let mut visited_modules = FxHashSet::default();
331
332    let edition = target_module.krate(ctx.db()).edition(ctx.db());
333    references
334        .into_iter()
335        .filter_map(|FileReference { range, name, .. }| {
336            let name = name.into_name_like()?;
337            ctx.sema.scope(name.syntax()).map(|scope| (range, name, scope.module()))
338        })
339        .map(|(range, name, ref_module)| {
340            // if the referenced module is not the same as the target one and has not been seen before, add an import
341            let import_data = if ref_module.nearest_non_block_module(ctx.db()) != *target_module
342                && !visited_modules.contains(&ref_module)
343            {
344                visited_modules.insert(ref_module);
345
346                ImportScope::find_insert_use_container(name.syntax(), &ctx.sema).and_then(
347                    |import_scope| {
348                        let cfg = ctx.config.find_path_config(
349                            ctx.sema.is_nightly(target_module.krate(ctx.sema.db)),
350                        );
351                        let path = ref_module
352                            .find_use_path(
353                                ctx.sema.db,
354                                ModuleDef::Module(*target_module),
355                                ctx.config.insert_use.prefix_kind,
356                                cfg,
357                            )
358                            .map(|mod_path| {
359                                make::path_concat(
360                                    mod_path_to_ast(&mod_path, edition),
361                                    make::path_from_text("Bool"),
362                                )
363                            })?;
364
365                        Some((import_scope, path))
366                    },
367                )
368            } else {
369                None
370            };
371
372            FileReferenceWithImport { range, name, import_data }
373        })
374        .collect()
375}
376
377fn find_assignment_usage(name: &ast::NameLike) -> Option<ast::Expr> {
378    let bin_expr = name.syntax().ancestors().find_map(ast::BinExpr::cast)?;
379
380    if !bin_expr.lhs()?.syntax().descendants().contains(name.syntax()) {
381        cov_mark::hit!(dont_assign_incorrect_ref);
382        return None;
383    }
384
385    if let Some(ast::BinaryOp::Assignment { op: None }) = bin_expr.op_kind() {
386        bin_expr.rhs()
387    } else {
388        None
389    }
390}
391
392fn find_negated_usage(name: &ast::NameLike) -> Option<(ast::PrefixExpr, ast::Expr)> {
393    let prefix_expr = name.syntax().ancestors().find_map(ast::PrefixExpr::cast)?;
394
395    if !matches!(prefix_expr.expr()?, ast::Expr::PathExpr(_) | ast::Expr::FieldExpr(_)) {
396        cov_mark::hit!(dont_overwrite_expression_inside_negation);
397        return None;
398    }
399
400    if let Some(ast::UnaryOp::Not) = prefix_expr.op_kind() {
401        let inner_expr = prefix_expr.expr()?;
402        Some((prefix_expr, inner_expr))
403    } else {
404        None
405    }
406}
407
408fn find_record_expr_usage(
409    name: &ast::NameLike,
410    got_field: hir::Field,
411    target_definition: Definition,
412) -> Option<(ast::RecordExprField, ast::Expr)> {
413    let name_ref = name.as_name_ref()?;
414    let record_field = ast::RecordExprField::for_field_name(name_ref)?;
415    let initializer = record_field.expr()?;
416
417    match target_definition {
418        Definition::Field(expected_field) if got_field == expected_field => {
419            Some((record_field, initializer))
420        }
421        _ => None,
422    }
423}
424
425fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
426    let record_pat_field = name.syntax().parent().and_then(ast::RecordPatField::cast)?;
427    let pat = record_pat_field.pat()?;
428
429    match pat {
430        ast::Pat::IdentPat(_) | ast::Pat::LiteralPat(_) | ast::Pat::WildcardPat(_) => Some(pat),
431        _ => None,
432    }
433}
434
435fn find_assoc_const_usage(name: &ast::NameLike) -> Option<(ast::Type, ast::Expr)> {
436    let const_ = name.syntax().parent().and_then(ast::Const::cast)?;
437    const_.syntax().parent().and_then(ast::AssocItemList::cast)?;
438
439    Some((const_.ty()?, const_.body()?))
440}
441
442fn find_method_call_expr_usage(name: &ast::NameLike) -> Option<ast::Expr> {
443    let method_call = name.syntax().ancestors().find_map(ast::MethodCallExpr::cast)?;
444    let receiver = method_call.receiver()?;
445
446    if !receiver.syntax().descendants().contains(name.syntax()) {
447        return None;
448    }
449
450    Some(receiver)
451}
452
453/// Adds the definition of the new enum before the target node.
454fn add_enum_def(
455    edit: &mut SourceChangeBuilder,
456    ctx: &AssistContext<'_>,
457    usages: &UsageSearchResult,
458    target_node: SyntaxNode,
459    target_module: &hir::Module,
460) -> Option<()> {
461    let insert_before = node_to_insert_before(target_node);
462
463    if ctx
464        .sema
465        .scope(&insert_before)?
466        .module()
467        .scope(ctx.db(), Some(*target_module))
468        .iter()
469        .any(|(name, _)| name.as_str() == "Bool")
470    {
471        return None;
472    }
473
474    let make_enum_pub = usages
475        .iter()
476        .flat_map(|(_, refs)| refs)
477        .filter_map(|FileReference { name, .. }| {
478            let name = name.clone().into_name_like()?;
479            ctx.sema.scope(name.syntax()).map(|scope| scope.module())
480        })
481        .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module);
482    let enum_def = make_bool_enum(make_enum_pub);
483
484    let indent = IndentLevel::from_node(&insert_before);
485    enum_def.reindent_to(indent);
486
487    edit.insert(
488        insert_before.text_range().start(),
489        format!("{}\n\n{indent}", enum_def.syntax().text()),
490    );
491
492    Some(())
493}
494
495/// Finds where to put the new enum definition.
496/// Tries to find the ast node at the nearest module or at top-level, otherwise just
497/// returns the input node.
498fn node_to_insert_before(target_node: SyntaxNode) -> SyntaxNode {
499    target_node
500        .ancestors()
501        .take_while(|it| !matches!(it.kind(), SyntaxKind::MODULE | SyntaxKind::SOURCE_FILE))
502        .filter(|it| ast::Item::can_cast(it.kind()))
503        .last()
504        .unwrap_or(target_node)
505}
506
507fn make_bool_enum(make_pub: bool) -> ast::Enum {
508    let derive_eq = make::attr_outer(make::meta_token_tree(
509        make::ext::ident_path("derive"),
510        make::token_tree(
511            T!['('],
512            vec![
513                NodeOrToken::Token(make::tokens::ident("PartialEq")),
514                NodeOrToken::Token(make::token(T![,])),
515                NodeOrToken::Token(make::tokens::single_space()),
516                NodeOrToken::Token(make::tokens::ident("Eq")),
517            ],
518        ),
519    ));
520    make::enum_(
521        [derive_eq],
522        if make_pub { Some(make::visibility_pub()) } else { None },
523        make::name("Bool"),
524        None,
525        None,
526        make::variant_list(vec![
527            make::variant(None, make::name("True"), None, None),
528            make::variant(None, make::name("False"), None, None),
529        ]),
530    )
531    .clone_for_update()
532}
533
534#[cfg(test)]
535mod tests {
536    use super::*;
537
538    use crate::tests::{check_assist, check_assist_not_applicable};
539
540    #[test]
541    fn parameter_with_first_param_usage() {
542        check_assist(
543            convert_bool_to_enum,
544            r#"
545fn function($0foo: bool, bar: bool) {
546    if foo {
547        println!("foo");
548    }
549}
550"#,
551            r#"
552#[derive(PartialEq, Eq)]
553enum Bool { True, False }
554
555fn function(foo: Bool, bar: bool) {
556    if foo == Bool::True {
557        println!("foo");
558    }
559}
560"#,
561        )
562    }
563
564    #[test]
565    fn no_duplicate_enums() {
566        check_assist(
567            convert_bool_to_enum,
568            r#"
569#[derive(PartialEq, Eq)]
570enum Bool { True, False }
571
572fn function(foo: bool, $0bar: bool) {
573    if bar {
574        println!("bar");
575    }
576}
577"#,
578            r#"
579#[derive(PartialEq, Eq)]
580enum Bool { True, False }
581
582fn function(foo: bool, bar: Bool) {
583    if bar == Bool::True {
584        println!("bar");
585    }
586}
587"#,
588        )
589    }
590
591    #[test]
592    fn parameter_with_last_param_usage() {
593        check_assist(
594            convert_bool_to_enum,
595            r#"
596fn function(foo: bool, $0bar: bool) {
597    if bar {
598        println!("bar");
599    }
600}
601"#,
602            r#"
603#[derive(PartialEq, Eq)]
604enum Bool { True, False }
605
606fn function(foo: bool, bar: Bool) {
607    if bar == Bool::True {
608        println!("bar");
609    }
610}
611"#,
612        )
613    }
614
615    #[test]
616    fn parameter_with_middle_param_usage() {
617        check_assist(
618            convert_bool_to_enum,
619            r#"
620fn function(foo: bool, $0bar: bool, baz: bool) {
621    if bar {
622        println!("bar");
623    }
624}
625"#,
626            r#"
627#[derive(PartialEq, Eq)]
628enum Bool { True, False }
629
630fn function(foo: bool, bar: Bool, baz: bool) {
631    if bar == Bool::True {
632        println!("bar");
633    }
634}
635"#,
636        )
637    }
638
639    #[test]
640    fn parameter_with_closure_usage() {
641        check_assist(
642            convert_bool_to_enum,
643            r#"
644fn main() {
645    let foo = |$0bar: bool| bar;
646}
647"#,
648            r#"
649#[derive(PartialEq, Eq)]
650enum Bool { True, False }
651
652fn main() {
653    let foo = |bar: Bool| bar == Bool::True;
654}
655"#,
656        )
657    }
658
659    #[test]
660    fn local_variable_with_usage() {
661        check_assist(
662            convert_bool_to_enum,
663            r#"
664fn main() {
665    let $0foo = true;
666
667    if foo {
668        println!("foo");
669    }
670}
671"#,
672            r#"
673#[derive(PartialEq, Eq)]
674enum Bool { True, False }
675
676fn main() {
677    let foo = Bool::True;
678
679    if foo == Bool::True {
680        println!("foo");
681    }
682}
683"#,
684        )
685    }
686
687    #[test]
688    fn local_variable_with_usage_negated() {
689        cov_mark::check!(replaces_negation);
690        check_assist(
691            convert_bool_to_enum,
692            r#"
693fn main() {
694    let $0foo = true;
695
696    if !foo {
697        println!("foo");
698    }
699}
700"#,
701            r#"
702#[derive(PartialEq, Eq)]
703enum Bool { True, False }
704
705fn main() {
706    let foo = Bool::True;
707
708    if foo == Bool::False {
709        println!("foo");
710    }
711}
712"#,
713        )
714    }
715
716    #[test]
717    fn local_variable_with_type_annotation() {
718        cov_mark::check!(replaces_ty_annotation);
719        check_assist(
720            convert_bool_to_enum,
721            r#"
722fn main() {
723    let $0foo: bool = false;
724}
725"#,
726            r#"
727#[derive(PartialEq, Eq)]
728enum Bool { True, False }
729
730fn main() {
731    let foo: Bool = Bool::False;
732}
733"#,
734        )
735    }
736
737    #[test]
738    fn local_variable_with_non_literal_initializer() {
739        check_assist(
740            convert_bool_to_enum,
741            r#"
742fn main() {
743    let $0foo = 1 == 2;
744}
745"#,
746            r#"
747#[derive(PartialEq, Eq)]
748enum Bool { True, False }
749
750fn main() {
751    let foo = if 1 == 2 { Bool::True } else { Bool::False };
752}
753"#,
754        )
755    }
756
757    #[test]
758    fn local_variable_binexpr_usage() {
759        check_assist(
760            convert_bool_to_enum,
761            r#"
762fn main() {
763    let $0foo = false;
764    let bar = true;
765
766    if !foo && bar {
767        println!("foobar");
768    }
769}
770"#,
771            r#"
772#[derive(PartialEq, Eq)]
773enum Bool { True, False }
774
775fn main() {
776    let foo = Bool::False;
777    let bar = true;
778
779    if foo == Bool::False && bar {
780        println!("foobar");
781    }
782}
783"#,
784        )
785    }
786
787    #[test]
788    fn local_variable_unop_usage() {
789        check_assist(
790            convert_bool_to_enum,
791            r#"
792fn main() {
793    let $0foo = true;
794
795    if *&foo {
796        println!("foobar");
797    }
798}
799"#,
800            r#"
801#[derive(PartialEq, Eq)]
802enum Bool { True, False }
803
804fn main() {
805    let foo = Bool::True;
806
807    if *&foo == Bool::True {
808        println!("foobar");
809    }
810}
811"#,
812        )
813    }
814
815    #[test]
816    fn local_variable_assigned_later() {
817        cov_mark::check!(replaces_assignment);
818        check_assist(
819            convert_bool_to_enum,
820            r#"
821fn main() {
822    let $0foo: bool;
823    foo = true;
824}
825"#,
826            r#"
827#[derive(PartialEq, Eq)]
828enum Bool { True, False }
829
830fn main() {
831    let foo: Bool;
832    foo = Bool::True;
833}
834"#,
835        )
836    }
837
838    #[test]
839    fn local_variable_does_not_apply_recursively() {
840        check_assist(
841            convert_bool_to_enum,
842            r#"
843fn main() {
844    let $0foo = true;
845    let bar = !foo;
846
847    if bar {
848        println!("bar");
849    }
850}
851"#,
852            r#"
853#[derive(PartialEq, Eq)]
854enum Bool { True, False }
855
856fn main() {
857    let foo = Bool::True;
858    let bar = foo == Bool::False;
859
860    if bar {
861        println!("bar");
862    }
863}
864"#,
865        )
866    }
867
868    #[test]
869    fn local_variable_nested_in_negation() {
870        cov_mark::check!(dont_overwrite_expression_inside_negation);
871        check_assist(
872            convert_bool_to_enum,
873            r#"
874fn main() {
875    if !"foo".chars().any(|c| {
876        let $0foo = true;
877        foo
878    }) {
879        println!("foo");
880    }
881}
882"#,
883            r#"
884#[derive(PartialEq, Eq)]
885enum Bool { True, False }
886
887fn main() {
888    if !"foo".chars().any(|c| {
889        let foo = Bool::True;
890        foo == Bool::True
891    }) {
892        println!("foo");
893    }
894}
895"#,
896        )
897    }
898
899    #[test]
900    fn local_variable_non_bool() {
901        cov_mark::check!(not_applicable_non_bool_local);
902        check_assist_not_applicable(
903            convert_bool_to_enum,
904            r#"
905fn main() {
906    let $0foo = 1;
907}
908"#,
909        )
910    }
911
912    #[test]
913    fn local_variable_cursor_not_on_ident() {
914        check_assist_not_applicable(
915            convert_bool_to_enum,
916            r#"
917fn main() {
918    let foo = $0true;
919}
920"#,
921        )
922    }
923
924    #[test]
925    fn local_variable_non_ident_pat() {
926        check_assist_not_applicable(
927            convert_bool_to_enum,
928            r#"
929fn main() {
930    let ($0foo, bar) = (true, false);
931}
932"#,
933        )
934    }
935
936    #[test]
937    fn local_var_init_struct_usage() {
938        check_assist(
939            convert_bool_to_enum,
940            r#"
941struct Foo {
942    foo: bool,
943}
944
945fn main() {
946    let $0foo = true;
947    let s = Foo { foo };
948}
949"#,
950            r#"
951struct Foo {
952    foo: bool,
953}
954
955#[derive(PartialEq, Eq)]
956enum Bool { True, False }
957
958fn main() {
959    let foo = Bool::True;
960    let s = Foo { foo: foo == Bool::True };
961}
962"#,
963        )
964    }
965
966    #[test]
967    fn local_var_init_struct_usage_in_macro() {
968        check_assist(
969            convert_bool_to_enum,
970            r#"
971struct Struct {
972    boolean: bool,
973}
974
975macro_rules! identity {
976    ($body:expr) => {
977        $body
978    }
979}
980
981fn new() -> Struct {
982    let $0boolean = true;
983    identity![Struct { boolean }]
984}
985"#,
986            r#"
987struct Struct {
988    boolean: bool,
989}
990
991macro_rules! identity {
992    ($body:expr) => {
993        $body
994    }
995}
996
997#[derive(PartialEq, Eq)]
998enum Bool { True, False }
999
1000fn new() -> Struct {
1001    let boolean = Bool::True;
1002    identity![Struct { boolean: boolean == Bool::True }]
1003}
1004"#,
1005        )
1006    }
1007
1008    #[test]
1009    fn field_struct_basic() {
1010        cov_mark::check!(replaces_record_expr);
1011        check_assist(
1012            convert_bool_to_enum,
1013            r#"
1014struct Foo {
1015    $0bar: bool,
1016    baz: bool,
1017}
1018
1019fn main() {
1020    let foo = Foo { bar: true, baz: false };
1021
1022    if foo.bar {
1023        println!("foo");
1024    }
1025}
1026"#,
1027            r#"
1028#[derive(PartialEq, Eq)]
1029enum Bool { True, False }
1030
1031struct Foo {
1032    bar: Bool,
1033    baz: bool,
1034}
1035
1036fn main() {
1037    let foo = Foo { bar: Bool::True, baz: false };
1038
1039    if foo.bar == Bool::True {
1040        println!("foo");
1041    }
1042}
1043"#,
1044        )
1045    }
1046
1047    #[test]
1048    fn field_enum_basic() {
1049        cov_mark::check!(replaces_record_pat);
1050        check_assist(
1051            convert_bool_to_enum,
1052            r#"
1053enum Foo {
1054    Foo,
1055    Bar { $0bar: bool },
1056}
1057
1058fn main() {
1059    let foo = Foo::Bar { bar: true };
1060
1061    if let Foo::Bar { bar: baz } = foo {
1062        if baz {
1063            println!("foo");
1064        }
1065    }
1066}
1067"#,
1068            r#"
1069#[derive(PartialEq, Eq)]
1070enum Bool { True, False }
1071
1072enum Foo {
1073    Foo,
1074    Bar { bar: Bool },
1075}
1076
1077fn main() {
1078    let foo = Foo::Bar { bar: Bool::True };
1079
1080    if let Foo::Bar { bar: baz } = foo {
1081        if baz == Bool::True {
1082            println!("foo");
1083        }
1084    }
1085}
1086"#,
1087        )
1088    }
1089
1090    #[test]
1091    fn field_enum_cross_file() {
1092        // FIXME: The import is missing
1093        check_assist(
1094            convert_bool_to_enum,
1095            r#"
1096//- /foo.rs
1097pub enum Foo {
1098    Foo,
1099    Bar { $0bar: bool },
1100}
1101
1102fn foo() {
1103    let foo = Foo::Bar { bar: true };
1104}
1105
1106//- /main.rs
1107use foo::Foo;
1108
1109mod foo;
1110
1111fn main() {
1112    let foo = Foo::Bar { bar: false };
1113}
1114"#,
1115            r#"
1116//- /foo.rs
1117#[derive(PartialEq, Eq)]
1118pub enum Bool { True, False }
1119
1120pub enum Foo {
1121    Foo,
1122    Bar { bar: Bool },
1123}
1124
1125fn foo() {
1126    let foo = Foo::Bar { bar: Bool::True };
1127}
1128
1129//- /main.rs
1130use foo::{Bool, Foo};
1131
1132mod foo;
1133
1134fn main() {
1135    let foo = Foo::Bar { bar: Bool::False };
1136}
1137"#,
1138        )
1139    }
1140
    #[test]
    fn field_enum_shorthand() {
        // A shorthand record pattern (`Foo::Bar { bar }`) in a `match` arm is kept, and the
        // binding it introduces is rewritten to `bar == Bool::True` where used as a condition.
        // Requires the `replaces_record_pat_shorthand` coverage mark to be hit.
        cov_mark::check!(replaces_record_pat_shorthand);
        check_assist(
            convert_bool_to_enum,
            r#"
enum Foo {
    Foo,
    Bar { $0bar: bool },
}

fn main() {
    let foo = Foo::Bar { bar: true };

    match foo {
        Foo::Bar { bar } => {
            if bar {
                println!("foo");
            }
        }
        _ => (),
    }
}
"#,
            r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }

enum Foo {
    Foo,
    Bar { bar: Bool },
}

fn main() {
    let foo = Foo::Bar { bar: Bool::True };

    match foo {
        Foo::Bar { bar } => {
            if bar == Bool::True {
                println!("foo");
            }
        }
        _ => (),
    }
}
"#,
        )
    }
1189
    #[test]
    fn field_enum_replaces_literal_patterns() {
        // A `true` literal pattern in the converted field's position (`bar: true`) is
        // rewritten to the `Bool::True` path pattern.
        // Requires the `replaces_literal_pat` coverage mark to be hit.
        cov_mark::check!(replaces_literal_pat);
        check_assist(
            convert_bool_to_enum,
            r#"
enum Foo {
    Foo,
    Bar { $0bar: bool },
}

fn main() {
    let foo = Foo::Bar { bar: true };

    if let Foo::Bar { bar: true } = foo {
        println!("foo");
    }
}
"#,
            r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }

enum Foo {
    Foo,
    Bar { bar: Bool },
}

fn main() {
    let foo = Foo::Bar { bar: Bool::True };

    if let Foo::Bar { bar: Bool::True } = foo {
        println!("foo");
    }
}
"#,
        )
    }
1228
    #[test]
    fn field_enum_keeps_wildcard_patterns() {
        // A `_` pattern in the converted field's position is left unchanged; only the
        // field type and the record-expression initializer are rewritten.
        check_assist(
            convert_bool_to_enum,
            r#"
enum Foo {
    Foo,
    Bar { $0bar: bool },
}

fn main() {
    let foo = Foo::Bar { bar: true };

    if let Foo::Bar { bar: _ } = foo {
        println!("foo");
    }
}
"#,
            r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }

enum Foo {
    Foo,
    Bar { bar: Bool },
}

fn main() {
    let foo = Foo::Bar { bar: Bool::True };

    if let Foo::Bar { bar: _ } = foo {
        println!("foo");
    }
}
"#,
        )
    }
1266
    #[test]
    fn field_union_basic() {
        // Union fields are converted as well. A read inside an `unsafe` block gets the
        // comparison placed inside the block: `unsafe { foo.foo == Bool::True }`.
        check_assist(
            convert_bool_to_enum,
            r#"
union Foo {
    $0foo: bool,
    bar: usize,
}

fn main() {
    let foo = Foo { foo: true };

    if unsafe { foo.foo } {
        println!("foo");
    }
}
"#,
            r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }

union Foo {
    foo: Bool,
    bar: usize,
}

fn main() {
    let foo = Foo { foo: Bool::True };

    if unsafe { foo.foo == Bool::True } {
        println!("foo");
    }
}
"#,
        )
    }
1304
    #[test]
    fn field_negated() {
        // A negated use (`!foo.bar`) is rewritten to `foo.bar == Bool::False` instead of
        // keeping the `!`; the `false` initializer becomes `Bool::False`.
        check_assist(
            convert_bool_to_enum,
            r#"
struct Foo {
    $0bar: bool,
}

fn main() {
    let foo = Foo { bar: false };

    if !foo.bar {
        println!("foo");
    }
}
"#,
            r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }

struct Foo {
    bar: Bool,
}

fn main() {
    let foo = Foo { bar: Bool::False };

    if foo.bar == Bool::False {
        println!("foo");
    }
}
"#,
        )
    }
1340
    #[test]
    fn field_in_mod_properly_indented() {
        // Inside an inline `mod`, the generated enum is indented to match the module's items.
        // A shorthand record-expression field fed from a plain `bool` parameter
        // (`Self { baz }`) is expanded to `if baz { Bool::True } else { Bool::False }`;
        // the parameter's own `bool` type is not changed.
        check_assist(
            convert_bool_to_enum,
            r#"
mod foo {
    struct Bar {
        $0baz: bool,
    }

    impl Bar {
        fn new(baz: bool) -> Self {
            Self { baz }
        }
    }
}
"#,
            r#"
mod foo {
    #[derive(PartialEq, Eq)]
    enum Bool { True, False }

    struct Bar {
        baz: Bool,
    }

    impl Bar {
        fn new(baz: bool) -> Self {
            Self { baz: if baz { Bool::True } else { Bool::False } }
        }
    }
}
"#,
        )
    }
1376
    #[test]
    fn field_multiple_initializations() {
        // Every record-expression initializer of the field is rewritten, and each access
        // in an `&&` chain gets its own `== Bool::True`.
        check_assist(
            convert_bool_to_enum,
            r#"
struct Foo {
    $0bar: bool,
    baz: bool,
}

fn main() {
    let foo1 = Foo { bar: true, baz: false };
    let foo2 = Foo { bar: false, baz: false };

    if foo1.bar && foo2.bar {
        println!("foo");
    }
}
"#,
            r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }

struct Foo {
    bar: Bool,
    baz: bool,
}

fn main() {
    let foo1 = Foo { bar: Bool::True, baz: false };
    let foo2 = Foo { bar: Bool::False, baz: false };

    if foo1.bar == Bool::True && foo2.bar == Bool::True {
        println!("foo");
    }
}
"#,
        )
    }
1416
    #[test]
    fn field_assigned_to_another() {
        // Assigning the converted field into an unrelated `bool` field (`bar.bar = foo.foo`)
        // converts the right-hand side back with `== Bool::True`. The other struct's
        // `bar: true` initializer stays `true`.
        // Requires the `dont_assign_incorrect_ref` coverage mark to be hit.
        cov_mark::check!(dont_assign_incorrect_ref);
        check_assist(
            convert_bool_to_enum,
            r#"
struct Foo {
    $0foo: bool,
}

struct Bar {
    bar: bool,
}

fn main() {
    let foo = Foo { foo: true };
    let mut bar = Bar { bar: true };

    bar.bar = foo.foo;
}
"#,
            r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }

struct Foo {
    foo: Bool,
}

struct Bar {
    bar: bool,
}

fn main() {
    let foo = Foo { foo: Bool::True };
    let mut bar = Bar { bar: true };

    bar.bar = foo.foo == Bool::True;
}
"#,
        )
    }
1459
    #[test]
    fn field_initialized_with_other() {
        // Using the converted field to initialize another struct's `bool` field converts
        // it back with `== Bool::True`; the other struct's field type is unchanged.
        check_assist(
            convert_bool_to_enum,
            r#"
struct Foo {
    $0foo: bool,
}

struct Bar {
    bar: bool,
}

fn main() {
    let foo = Foo { foo: true };
    let bar = Bar { bar: foo.foo };
}
"#,
            r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }

struct Foo {
    foo: Bool,
}

struct Bar {
    bar: bool,
}

fn main() {
    let foo = Foo { foo: Bool::True };
    let bar = Bar { bar: foo.foo == Bool::True };
}
"#,
        )
    }
1497
    #[test]
    fn field_method_chain_usage() {
        // When the converted field is the receiver of a method call, the comparison is
        // parenthesized so the method still applies to the `bool`:
        // `(foo.bool == Bool::True).then(..)`.
        check_assist(
            convert_bool_to_enum,
            r#"
struct Foo {
    $0bool: bool,
}

fn main() {
    let foo = Foo { bool: true };

    foo.bool.then(|| 2);
}
"#,
            r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }

struct Foo {
    bool: Bool,
}

fn main() {
    let foo = Foo { bool: Bool::True };

    (foo.bool == Bool::True).then(|| 2);
}
"#,
        )
    }
1529
    #[test]
    fn field_in_macro() {
        // A record expression passed through a `macro_rules!` call is rewritten in the
        // macro input. The shorthand destructuring `let Struct { boolean } = x;` is left
        // unchanged, even though it sits in a function that shares the field's name.
        check_assist(
            convert_bool_to_enum,
            r#"
struct Struct {
    $0boolean: bool,
}

fn boolean(x: Struct) {
    let Struct { boolean } = x;
}

macro_rules! identity { ($body:expr) => { $body } }

fn new() -> Struct {
    identity!(Struct { boolean: true })
}
"#,
            r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }

struct Struct {
    boolean: Bool,
}

fn boolean(x: Struct) {
    let Struct { boolean } = x;
}

macro_rules! identity { ($body:expr) => { $body } }

fn new() -> Struct {
    identity!(Struct { boolean: Bool::True })
}
"#,
        )
    }
1569
    #[test]
    fn field_non_bool() {
        // The assist is not offered on a field whose type is not `bool`.
        // Requires the `not_applicable_non_bool_field` coverage mark to be hit.
        cov_mark::check!(not_applicable_non_bool_field);
        check_assist_not_applicable(
            convert_bool_to_enum,
            r#"
struct Foo {
    $0bar: usize,
}

fn main() {
    let foo = Foo { bar: 1 };
}
"#,
        )
    }
1586
    #[test]
    fn const_basic() {
        // A `const` has its type and initializer converted, and its use as a condition
        // gets `== Bool::True`.
        check_assist(
            convert_bool_to_enum,
            r#"
const $0FOO: bool = false;

fn main() {
    if FOO {
        println!("foo");
    }
}
"#,
            r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }

const FOO: Bool = Bool::False;

fn main() {
    if FOO == Bool::True {
        println!("foo");
    }
}
"#,
        )
    }
1614
    #[test]
    fn const_in_module() {
        // For a const in a child module, the enum is generated `pub` inside that module,
        // and the file using `foo::FOO` gets a top-level `use foo::Bool;` import.
        check_assist(
            convert_bool_to_enum,
            r#"
fn main() {
    if foo::FOO {
        println!("foo");
    }
}

mod foo {
    pub const $0FOO: bool = true;
}
"#,
            r#"
use foo::Bool;

fn main() {
    if foo::FOO == Bool::True {
        println!("foo");
    }
}

mod foo {
    #[derive(PartialEq, Eq)]
    pub enum Bool { True, False }

    pub const FOO: Bool = Bool::True;
}
"#,
        )
    }
1648
    #[test]
    fn const_in_module_with_import() {
        // A function-local `use foo::FOO;` is kept as is. The `Bool` import is added at
        // the top of the file rather than next to that local `use`.
        check_assist(
            convert_bool_to_enum,
            r#"
fn main() {
    use foo::FOO;

    if FOO {
        println!("foo");
    }
}

mod foo {
    pub const $0FOO: bool = true;
}
"#,
            r#"
use foo::Bool;

fn main() {
    use foo::FOO;

    if FOO == Bool::True {
        println!("foo");
    }
}

mod foo {
    #[derive(PartialEq, Eq)]
    pub enum Bool { True, False }

    pub const FOO: Bool = Bool::True;
}
"#,
        )
    }
1686
    #[test]
    fn const_cross_file() {
        // For a const defined in another file, the enum is generated in `foo.rs`, and
        // `main.rs` gets `use foo::Bool;` above its `mod foo;` declaration.
        check_assist(
            convert_bool_to_enum,
            r#"
//- /main.rs
mod foo;

fn main() {
    if foo::FOO {
        println!("foo");
    }
}

//- /foo.rs
pub const $0FOO: bool = true;
"#,
            r#"
//- /main.rs
use foo::Bool;

mod foo;

fn main() {
    if foo::FOO == Bool::True {
        println!("foo");
    }
}

//- /foo.rs
#[derive(PartialEq, Eq)]
pub enum Bool { True, False }

pub const FOO: Bool = Bool::True;
"#,
        )
    }
1724
    #[test]
    fn const_cross_file_and_module() {
        // For a const in the nested module `foo::bar` in another file, the enum goes into
        // `bar`, and the top-level import in `main.rs` uses the full path `foo::bar::Bool`.
        // The usage keeps its `bar::BAR` form.
        check_assist(
            convert_bool_to_enum,
            r#"
//- /main.rs
mod foo;

fn main() {
    use foo::bar;

    if bar::BAR {
        println!("foo");
    }
}

//- /foo.rs
pub mod bar {
    pub const $0BAR: bool = false;
}
"#,
            r#"
//- /main.rs
use foo::bar::Bool;

mod foo;

fn main() {
    use foo::bar;

    if bar::BAR == Bool::True {
        println!("foo");
    }
}

//- /foo.rs
pub mod bar {
    #[derive(PartialEq, Eq)]
    pub enum Bool { True, False }

    pub const BAR: Bool = Bool::False;
}
"#,
        )
    }
1770
    #[test]
    fn const_in_impl_cross_file() {
        // For an associated const, the enum is inserted before the `impl` block, not
        // inside it. The other file's `use crate::Foo;` is merged into
        // `use crate::{Bool, Foo};`. The `-> bool` function keeps its return type by
        // returning `Foo::BOOL == Bool::True`.
        check_assist(
            convert_bool_to_enum,
            r#"
//- /main.rs
mod foo;

struct Foo;

impl Foo {
    pub const $0BOOL: bool = true;
}

//- /foo.rs
use crate::Foo;

fn foo() -> bool {
    Foo::BOOL
}
"#,
            r#"
//- /main.rs
mod foo;

struct Foo;

#[derive(PartialEq, Eq)]
pub enum Bool { True, False }

impl Foo {
    pub const BOOL: Bool = Bool::True;
}

//- /foo.rs
use crate::{Bool, Foo};

fn foo() -> bool {
    Foo::BOOL == Bool::True
}
"#,
        )
    }
1814
    #[test]
    fn const_in_trait() {
        // Converting a trait's associated const also converts the type and value in the
        // implementing `impl`. Uses through a qualified path (`<usize as Foo>::BOOL`) get
        // `== Bool::True`.
        check_assist(
            convert_bool_to_enum,
            r#"
trait Foo {
    const $0BOOL: bool;
}

impl Foo for usize {
    const BOOL: bool = true;
}

fn main() {
    if <usize as Foo>::BOOL {
        println!("foo");
    }
}
"#,
            r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }

trait Foo {
    const BOOL: Bool;
}

impl Foo for usize {
    const BOOL: Bool = Bool::True;
}

fn main() {
    if <usize as Foo>::BOOL == Bool::True {
        println!("foo");
    }
}
"#,
        )
    }
1854
    #[test]
    fn const_non_bool() {
        // The assist is not offered on a const whose type is not `bool`.
        // Requires the `not_applicable_non_bool_const` coverage mark to be hit.
        cov_mark::check!(not_applicable_non_bool_const);
        check_assist_not_applicable(
            convert_bool_to_enum,
            r#"
const $0FOO: &str = "foo";

fn main() {
    println!("{FOO}");
}
"#,
        )
    }
1869
    #[test]
    fn static_basic() {
        // For a `static mut`, the type, the initializer and an assignment inside an `unsafe`
        // block (`BOOL = false`) are converted. A read used as a condition gets
        // `== Bool::True` inside its `unsafe` block.
        check_assist(
            convert_bool_to_enum,
            r#"
static mut $0BOOL: bool = true;

fn main() {
    unsafe { BOOL = false };
    if unsafe { BOOL } {
        println!("foo");
    }
}
"#,
            r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }

static mut BOOL: Bool = Bool::True;

fn main() {
    unsafe { BOOL = Bool::False };
    if unsafe { BOOL == Bool::True } {
        println!("foo");
    }
}
"#,
        )
    }
1899
    #[test]
    fn static_non_bool() {
        // The assist is not offered on a static whose type is not `bool`.
        // Requires the `not_applicable_non_bool_static` coverage mark to be hit.
        cov_mark::check!(not_applicable_non_bool_static);
        check_assist_not_applicable(
            convert_bool_to_enum,
            r#"
static mut $0FOO: usize = 0;

fn main() {
    if unsafe { FOO } == 0 {
        println!("foo");
    }
}
"#,
        )
    }
1916
    #[test]
    fn not_applicable_to_other_names() {
        // The assist is not offered when the cursor is on the name of an item it does not
        // handle, here a function name.
        check_assist_not_applicable(convert_bool_to_enum, "fn $0main() {}")
    }
1921}