1use either::Either;
2use hir::ModuleDef;
3use ide_db::imports::insert_use::insert_use_with_editor;
4use ide_db::text_edit::TextRange;
5use ide_db::{
6 FileId, FxHashSet,
7 assists::AssistId,
8 defs::Definition,
9 helpers::mod_path_to_ast_with_factory,
10 imports::insert_use::ImportScope,
11 search::{FileReference, UsageSearchResult},
12 source_change::SourceChangeBuilder,
13};
14use itertools::Itertools;
15use syntax::ast::edit::AstNodeEdit;
16use syntax::ast::syntax_factory::SyntaxFactory;
17use syntax::{
18 AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T,
19 ast::{self, HasName, edit::IndentLevel},
20};
21
22use crate::{
23 assist_context::{AssistContext, Assists},
24 utils,
25};
26
27pub(crate) fn convert_bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_, '_>) -> Option<()> {
57 let BoolNodeData { target_node, name, ty_annotation, initializer, definition } =
58 find_bool_node(ctx)?;
59 let target_module = ctx.sema.scope(&target_node)?.module().nearest_non_block_module(ctx.db());
60
61 let target = name.syntax().text_range();
62 acc.add(
63 AssistId::refactor_rewrite("convert_bool_to_enum"),
64 "Convert boolean to enum",
65 target,
66 |edit| {
67 let make = SyntaxFactory::without_mappings();
68 if let Some(ty) = &ty_annotation {
69 cov_mark::hit!(replaces_ty_annotation);
70 edit.replace(ty.syntax().text_range(), "Bool");
71 }
72
73 if let Some(initializer) = initializer {
74 replace_bool_expr(edit, initializer, &make);
75 }
76
77 let usages = definition.usages(&ctx.sema).all();
78 add_enum_def(edit, ctx, &usages, target_node, &target_module, &make);
79 let mut delayed_mutations = Vec::new();
80 replace_usages(
81 edit,
82 ctx,
83 usages,
84 definition,
85 &target_module,
86 &mut delayed_mutations,
87 &make,
88 );
89 for (file_id, scope, path) in delayed_mutations {
90 let editor = edit.make_editor(scope.as_syntax_node());
91 insert_use_with_editor(&scope, path, &ctx.config.insert_use, &editor);
92 edit.add_file_edits(file_id, editor);
93 }
94 },
95 )
96}
97
/// Description of the boolean definition found under the cursor by `find_bool_node`.
struct BoolNodeData {
    /// Node used to decide where the generated enum is inserted: the parameter, `let`
    /// statement, `const`/`static` item, or the ADT that contains the field.
    target_node: SyntaxNode,
    /// Name of the definition; its text range is the assist target.
    name: ast::Name,
    /// Explicit type annotation, if present; it gets replaced with `Bool`.
    ty_annotation: Option<ast::Type>,
    /// Initializer or body expression, if present; it gets converted to an enum expression.
    initializer: Option<ast::Expr>,
    /// Definition whose usages are searched and rewritten.
    definition: Definition,
}
105
106fn find_bool_node(ctx: &AssistContext<'_, '_>) -> Option<BoolNodeData> {
108 let name = ctx.find_node_at_offset::<ast::Name>()?;
109
110 if let Some(ident_pat) = name.syntax().parent().and_then(ast::IdentPat::cast) {
111 let def = ctx.sema.to_def(&ident_pat)?;
112 if !def.ty(ctx.db()).is_bool() {
113 cov_mark::hit!(not_applicable_non_bool_local);
114 return None;
115 }
116
117 let local_definition = Definition::Local(def);
118 match ident_pat.syntax().parent().and_then(Either::<ast::Param, ast::LetStmt>::cast)? {
119 Either::Left(param) => Some(BoolNodeData {
120 target_node: param.syntax().clone(),
121 name,
122 ty_annotation: param.ty(),
123 initializer: None,
124 definition: local_definition,
125 }),
126 Either::Right(let_stmt) => Some(BoolNodeData {
127 target_node: let_stmt.syntax().clone(),
128 name,
129 ty_annotation: let_stmt.ty(),
130 initializer: let_stmt.initializer(),
131 definition: local_definition,
132 }),
133 }
134 } else if let Some(const_) = name.syntax().parent().and_then(ast::Const::cast) {
135 let def = ctx.sema.to_def(&const_)?;
136 if !def.ty(ctx.db()).is_bool() {
137 cov_mark::hit!(not_applicable_non_bool_const);
138 return None;
139 }
140
141 Some(BoolNodeData {
142 target_node: const_.syntax().clone(),
143 name,
144 ty_annotation: const_.ty(),
145 initializer: const_.body(),
146 definition: Definition::Const(def),
147 })
148 } else if let Some(static_) = name.syntax().parent().and_then(ast::Static::cast) {
149 let def = ctx.sema.to_def(&static_)?;
150 if !def.ty(ctx.db()).is_bool() {
151 cov_mark::hit!(not_applicable_non_bool_static);
152 return None;
153 }
154
155 Some(BoolNodeData {
156 target_node: static_.syntax().clone(),
157 name,
158 ty_annotation: static_.ty(),
159 initializer: static_.body(),
160 definition: Definition::Static(def),
161 })
162 } else {
163 let field = name.syntax().parent().and_then(ast::RecordField::cast)?;
164 if field.name()? != name {
165 return None;
166 }
167
168 let adt = field.syntax().ancestors().find_map(ast::Adt::cast)?;
169 let def = ctx.sema.to_def(&field)?;
170 if !def.ty(ctx.db()).is_bool() {
171 cov_mark::hit!(not_applicable_non_bool_field);
172 return None;
173 }
174 Some(BoolNodeData {
175 target_node: adt.syntax().clone(),
176 name,
177 ty_annotation: field.ty(),
178 initializer: None,
179 definition: Definition::Field(def),
180 })
181 }
182}
183
184fn replace_bool_expr(edit: &mut SourceChangeBuilder, expr: ast::Expr, make: &SyntaxFactory) {
185 let expr_range = expr.syntax().text_range();
186 let enum_expr = bool_expr_to_enum_expr(expr, make);
187 edit.replace(expr_range, enum_expr.syntax().text())
188}
189
190fn bool_expr_to_enum_expr(expr: ast::Expr, make: &SyntaxFactory) -> ast::Expr {
192 let true_expr = make.expr_path(make.path_from_text("Bool::True"));
193 let false_expr = make.expr_path(make.path_from_text("Bool::False"));
194
195 if let ast::Expr::Literal(literal) = &expr {
196 match literal.kind() {
197 ast::LiteralKind::Bool(true) => true_expr,
198 ast::LiteralKind::Bool(false) => false_expr,
199 _ => expr,
200 }
201 } else {
202 make.expr_if(
203 expr,
204 make.tail_only_block_expr(true_expr),
205 Some(ast::ElseBranch::Block(make.tail_only_block_expr(false_expr))),
206 )
207 .into()
208 }
209}
210
211fn replace_usages(
213 edit: &mut SourceChangeBuilder,
214 ctx: &AssistContext<'_, '_>,
215 usages: UsageSearchResult,
216 target_definition: Definition,
217 target_module: &hir::Module,
218 delayed_mutations: &mut Vec<(FileId, ImportScope, ast::Path)>,
219 make: &SyntaxFactory,
220) {
221 for (file_id, references) in usages {
222 let vfs_file_id = file_id.file_id(ctx.db());
223 edit.edit_file(vfs_file_id);
224
225 let refs_with_imports =
226 augment_references_with_imports(ctx, references, target_module, make);
227
228 refs_with_imports.into_iter().rev().for_each(
229 |FileReferenceWithImport { range, name, import_data }| {
230 if let Some(ident_pat) = name.syntax().ancestors().find_map(ast::IdentPat::cast) {
232 cov_mark::hit!(replaces_record_pat_shorthand);
233
234 let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local);
235 if let Some(def) = definition {
236 replace_usages(
237 edit,
238 ctx,
239 def.usages(&ctx.sema).all(),
240 target_definition,
241 target_module,
242 delayed_mutations,
243 make,
244 )
245 }
246 } else if let Some(initializer) = find_assignment_usage(&name) {
247 cov_mark::hit!(replaces_assignment);
248
249 replace_bool_expr(edit, initializer, make);
250 } else if let Some((prefix_expr, inner_expr)) = find_negated_usage(&name) {
251 cov_mark::hit!(replaces_negation);
252
253 edit.replace(
254 prefix_expr.syntax().text_range(),
255 format!("{inner_expr} == Bool::False"),
256 );
257 } else if let Some((record_field, initializer)) = name
258 .as_name_ref()
259 .and_then(ast::RecordExprField::for_field_name)
260 .and_then(|record_field| ctx.sema.resolve_record_field(&record_field))
261 .and_then(|(got_field, _, _)| {
262 find_record_expr_usage(&name, got_field, target_definition)
263 })
264 {
265 cov_mark::hit!(replaces_record_expr);
266
267 let enum_expr = bool_expr_to_enum_expr(initializer, make);
268 utils::replace_record_field_expr(ctx, edit, record_field, enum_expr);
269 } else if let Some(pat) = find_record_pat_field_usage(&name) {
270 match pat {
271 ast::Pat::IdentPat(ident_pat) => {
272 cov_mark::hit!(replaces_record_pat);
273
274 let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local);
275 if let Some(def) = definition {
276 replace_usages(
277 edit,
278 ctx,
279 def.usages(&ctx.sema).all(),
280 target_definition,
281 target_module,
282 delayed_mutations,
283 make,
284 )
285 }
286 }
287 ast::Pat::LiteralPat(literal_pat) => {
288 cov_mark::hit!(replaces_literal_pat);
289
290 if let Some(expr) = literal_pat.literal().and_then(|literal| {
291 literal.syntax().ancestors().find_map(ast::Expr::cast)
292 }) {
293 replace_bool_expr(edit, expr, make);
294 }
295 }
296 _ => (),
297 }
298 } else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&name) {
299 edit.replace(ty_annotation.syntax().text_range(), "Bool");
300 replace_bool_expr(edit, initializer, make);
301 } else if let Some(receiver) = find_method_call_expr_usage(&name) {
302 edit.replace(
303 receiver.syntax().text_range(),
304 format!("({receiver} == Bool::True)"),
305 );
306 } else if name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
307 if let Some((record_field, expr)) =
309 name.as_name_ref().and_then(ast::RecordExprField::for_field_name).and_then(
310 |record_field| record_field.expr().map(|expr| (record_field, expr)),
311 )
312 {
313 utils::replace_record_field_expr(
314 ctx,
315 edit,
316 record_field,
317 make.expr_bin_op(
318 expr,
319 ast::BinaryOp::CmpOp(ast::CmpOp::Eq { negated: false }),
320 make.expr_path(make.path_from_text("Bool::True")),
321 ),
322 );
323 } else {
324 edit.replace(range, format!("{} == Bool::True", name.text()));
325 }
326 }
327
328 if let Some((scope, path)) = import_data {
330 delayed_mutations.push((vfs_file_id, scope, path));
331 }
332 },
333 )
334 }
335}
336
/// A reference to the converted definition, paired with the import it may require.
struct FileReferenceWithImport {
    /// Text range of the reference.
    range: TextRange,
    /// The referencing name node.
    name: ast::NameLike,
    /// Where to insert an import and the path to import, when this reference calls for one.
    import_data: Option<(ImportScope, ast::Path)>,
}
342
343fn augment_references_with_imports(
344 ctx: &AssistContext<'_, '_>,
345 references: Vec<FileReference>,
346 target_module: &hir::Module,
347 make: &SyntaxFactory,
348) -> Vec<FileReferenceWithImport> {
349 let mut visited_modules = FxHashSet::default();
350
351 let edition = target_module.krate(ctx.db()).edition(ctx.db());
352 references
353 .into_iter()
354 .filter_map(|FileReference { range, name, .. }| {
355 let name = name.into_name_like()?;
356 ctx.sema.scope(name.syntax()).map(|scope| (range, name, scope.module()))
357 })
358 .map(|(range, name, ref_module)| {
359 let import_data = if ref_module.nearest_non_block_module(ctx.db()) != *target_module
361 && !visited_modules.contains(&ref_module)
362 {
363 visited_modules.insert(ref_module);
364
365 ImportScope::find_insert_use_container(name.syntax(), &ctx.sema).and_then(
366 |import_scope| {
367 let cfg = ctx.config.find_path_config(
368 ctx.sema.is_nightly(target_module.krate(ctx.sema.db)),
369 );
370 let path = ref_module
371 .find_use_path(
372 ctx.sema.db,
373 ModuleDef::Module(*target_module),
374 ctx.config.insert_use.prefix_kind,
375 cfg,
376 )
377 .map(|mod_path| {
378 make.path_concat(
379 mod_path_to_ast_with_factory(make, &mod_path, edition),
380 make.path_from_text("Bool"),
381 )
382 })?;
383
384 Some((import_scope, path))
385 },
386 )
387 } else {
388 None
389 };
390
391 FileReferenceWithImport { range, name, import_data }
392 })
393 .collect()
394}
395
396fn find_assignment_usage(name: &ast::NameLike) -> Option<ast::Expr> {
397 let bin_expr = name.syntax().ancestors().find_map(ast::BinExpr::cast)?;
398
399 if !bin_expr.lhs()?.syntax().descendants().contains(name.syntax()) {
400 cov_mark::hit!(dont_assign_incorrect_ref);
401 return None;
402 }
403
404 if let Some(ast::BinaryOp::Assignment { op: None }) = bin_expr.op_kind() {
405 bin_expr.rhs()
406 } else {
407 None
408 }
409}
410
411fn find_negated_usage(name: &ast::NameLike) -> Option<(ast::PrefixExpr, ast::Expr)> {
412 let prefix_expr = name.syntax().ancestors().find_map(ast::PrefixExpr::cast)?;
413
414 if !matches!(prefix_expr.expr()?, ast::Expr::PathExpr(_) | ast::Expr::FieldExpr(_)) {
415 cov_mark::hit!(dont_overwrite_expression_inside_negation);
416 return None;
417 }
418
419 if let Some(ast::UnaryOp::Not) = prefix_expr.op_kind() {
420 let inner_expr = prefix_expr.expr()?;
421 Some((prefix_expr, inner_expr))
422 } else {
423 None
424 }
425}
426
427fn find_record_expr_usage(
428 name: &ast::NameLike,
429 got_field: hir::Field,
430 target_definition: Definition,
431) -> Option<(ast::RecordExprField, ast::Expr)> {
432 let name_ref = name.as_name_ref()?;
433 let record_field = ast::RecordExprField::for_field_name(name_ref)?;
434 let initializer = record_field.expr()?;
435
436 match target_definition {
437 Definition::Field(expected_field) if got_field == expected_field => {
438 Some((record_field, initializer))
439 }
440 _ => None,
441 }
442}
443
444fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
445 let record_pat_field = name.syntax().parent().and_then(ast::RecordPatField::cast)?;
446 let pat = record_pat_field.pat()?;
447
448 match pat {
449 ast::Pat::IdentPat(_) | ast::Pat::LiteralPat(_) | ast::Pat::WildcardPat(_) => Some(pat),
450 _ => None,
451 }
452}
453
454fn find_assoc_const_usage(name: &ast::NameLike) -> Option<(ast::Type, ast::Expr)> {
455 let const_ = name.syntax().parent().and_then(ast::Const::cast)?;
456 const_.syntax().parent().and_then(ast::AssocItemList::cast)?;
457
458 Some((const_.ty()?, const_.body()?))
459}
460
461fn find_method_call_expr_usage(name: &ast::NameLike) -> Option<ast::Expr> {
462 let method_call = name.syntax().ancestors().find_map(ast::MethodCallExpr::cast)?;
463 let receiver = method_call.receiver()?;
464
465 if !receiver.syntax().descendants().contains(name.syntax()) {
466 return None;
467 }
468
469 Some(receiver)
470}
471
472fn add_enum_def(
474 edit: &mut SourceChangeBuilder,
475 ctx: &AssistContext<'_, '_>,
476 usages: &UsageSearchResult,
477 target_node: SyntaxNode,
478 target_module: &hir::Module,
479 make: &SyntaxFactory,
480) -> Option<()> {
481 let insert_before = node_to_insert_before(target_node);
482
483 if ctx
484 .sema
485 .scope(&insert_before)?
486 .module()
487 .scope(ctx.db(), Some(*target_module))
488 .iter()
489 .any(|(name, _)| name.as_str() == "Bool")
490 {
491 return None;
492 }
493
494 let make_enum_pub = usages
495 .iter()
496 .flat_map(|(_, refs)| refs)
497 .filter_map(|FileReference { name, .. }| {
498 let name = name.clone().into_name_like()?;
499 ctx.sema.scope(name.syntax()).map(|scope| scope.module())
500 })
501 .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module);
502
503 let indent = IndentLevel::from_node(&insert_before);
504 let enum_def = make_bool_enum(make_enum_pub, make).reset_indent().indent(indent);
505
506 edit.insert(
507 insert_before.text_range().start(),
508 format!("{}\n\n{indent}", enum_def.syntax().text()),
509 );
510
511 Some(())
512}
513
514fn node_to_insert_before(target_node: SyntaxNode) -> SyntaxNode {
518 target_node
519 .ancestors()
520 .take_while(|it| !matches!(it.kind(), SyntaxKind::MODULE | SyntaxKind::SOURCE_FILE))
521 .filter(|it| ast::Item::can_cast(it.kind()))
522 .last()
523 .unwrap_or(target_node)
524}
525
526fn make_bool_enum(make_pub: bool, make: &SyntaxFactory) -> ast::Enum {
527 let derive_eq = make.attr_outer(make.meta_token_tree(
528 make.ident_path("derive"),
529 make.token_tree(
530 T!['('],
531 vec![
532 NodeOrToken::Token(make.ident("PartialEq")),
533 NodeOrToken::Token(make.token(T![,])),
534 NodeOrToken::Token(make.whitespace(" ")),
535 NodeOrToken::Token(make.ident("Eq")),
536 ],
537 ),
538 ));
539 make.item_enum(
540 [derive_eq],
541 if make_pub { Some(make.visibility_pub()) } else { None },
542 make.name("Bool"),
543 None,
544 None,
545 make.variant_list(vec![
546 make.variant(None, make.name("True"), None, None),
547 make.variant(None, make.name("False"), None, None),
548 ]),
549 )
550}
551
552#[cfg(test)]
553mod tests {
554 use super::*;
555
556 use crate::tests::{check_assist, check_assist_not_applicable};
557
558 #[test]
559 fn parameter_with_first_param_usage() {
560 check_assist(
561 convert_bool_to_enum,
562 r#"
563fn function($0foo: bool, bar: bool) {
564 if foo {
565 println!("foo");
566 }
567}
568"#,
569 r#"
570#[derive(PartialEq, Eq)]
571enum Bool { True, False }
572
573fn function(foo: Bool, bar: bool) {
574 if foo == Bool::True {
575 println!("foo");
576 }
577}
578"#,
579 )
580 }
581
582 #[test]
583 fn no_duplicate_enums() {
584 check_assist(
585 convert_bool_to_enum,
586 r#"
587#[derive(PartialEq, Eq)]
588enum Bool { True, False }
589
590fn function(foo: bool, $0bar: bool) {
591 if bar {
592 println!("bar");
593 }
594}
595"#,
596 r#"
597#[derive(PartialEq, Eq)]
598enum Bool { True, False }
599
600fn function(foo: bool, bar: Bool) {
601 if bar == Bool::True {
602 println!("bar");
603 }
604}
605"#,
606 )
607 }
608
609 #[test]
610 fn parameter_with_last_param_usage() {
611 check_assist(
612 convert_bool_to_enum,
613 r#"
614fn function(foo: bool, $0bar: bool) {
615 if bar {
616 println!("bar");
617 }
618}
619"#,
620 r#"
621#[derive(PartialEq, Eq)]
622enum Bool { True, False }
623
624fn function(foo: bool, bar: Bool) {
625 if bar == Bool::True {
626 println!("bar");
627 }
628}
629"#,
630 )
631 }
632
633 #[test]
634 fn parameter_with_middle_param_usage() {
635 check_assist(
636 convert_bool_to_enum,
637 r#"
638fn function(foo: bool, $0bar: bool, baz: bool) {
639 if bar {
640 println!("bar");
641 }
642}
643"#,
644 r#"
645#[derive(PartialEq, Eq)]
646enum Bool { True, False }
647
648fn function(foo: bool, bar: Bool, baz: bool) {
649 if bar == Bool::True {
650 println!("bar");
651 }
652}
653"#,
654 )
655 }
656
657 #[test]
658 fn parameter_with_closure_usage() {
659 check_assist(
660 convert_bool_to_enum,
661 r#"
662fn main() {
663 let foo = |$0bar: bool| bar;
664}
665"#,
666 r#"
667#[derive(PartialEq, Eq)]
668enum Bool { True, False }
669
670fn main() {
671 let foo = |bar: Bool| bar == Bool::True;
672}
673"#,
674 )
675 }
676
677 #[test]
678 fn local_variable_with_usage() {
679 check_assist(
680 convert_bool_to_enum,
681 r#"
682fn main() {
683 let $0foo = true;
684
685 if foo {
686 println!("foo");
687 }
688}
689"#,
690 r#"
691#[derive(PartialEq, Eq)]
692enum Bool { True, False }
693
694fn main() {
695 let foo = Bool::True;
696
697 if foo == Bool::True {
698 println!("foo");
699 }
700}
701"#,
702 )
703 }
704
705 #[test]
706 fn local_variable_with_usage_negated() {
707 cov_mark::check!(replaces_negation);
708 check_assist(
709 convert_bool_to_enum,
710 r#"
711fn main() {
712 let $0foo = true;
713
714 if !foo {
715 println!("foo");
716 }
717}
718"#,
719 r#"
720#[derive(PartialEq, Eq)]
721enum Bool { True, False }
722
723fn main() {
724 let foo = Bool::True;
725
726 if foo == Bool::False {
727 println!("foo");
728 }
729}
730"#,
731 )
732 }
733
734 #[test]
735 fn local_variable_with_type_annotation() {
736 cov_mark::check!(replaces_ty_annotation);
737 check_assist(
738 convert_bool_to_enum,
739 r#"
740fn main() {
741 let $0foo: bool = false;
742}
743"#,
744 r#"
745#[derive(PartialEq, Eq)]
746enum Bool { True, False }
747
748fn main() {
749 let foo: Bool = Bool::False;
750}
751"#,
752 )
753 }
754
755 #[test]
756 fn local_variable_with_non_literal_initializer() {
757 check_assist(
758 convert_bool_to_enum,
759 r#"
760fn main() {
761 let $0foo = 1 == 2;
762}
763"#,
764 r#"
765#[derive(PartialEq, Eq)]
766enum Bool { True, False }
767
768fn main() {
769 let foo = if 1 == 2 { Bool::True } else { Bool::False };
770}
771"#,
772 )
773 }
774
775 #[test]
776 fn local_variable_binexpr_usage() {
777 check_assist(
778 convert_bool_to_enum,
779 r#"
780fn main() {
781 let $0foo = false;
782 let bar = true;
783
784 if !foo && bar {
785 println!("foobar");
786 }
787}
788"#,
789 r#"
790#[derive(PartialEq, Eq)]
791enum Bool { True, False }
792
793fn main() {
794 let foo = Bool::False;
795 let bar = true;
796
797 if foo == Bool::False && bar {
798 println!("foobar");
799 }
800}
801"#,
802 )
803 }
804
805 #[test]
806 fn local_variable_unop_usage() {
807 check_assist(
808 convert_bool_to_enum,
809 r#"
810fn main() {
811 let $0foo = true;
812
813 if *&foo {
814 println!("foobar");
815 }
816}
817"#,
818 r#"
819#[derive(PartialEq, Eq)]
820enum Bool { True, False }
821
822fn main() {
823 let foo = Bool::True;
824
825 if *&foo == Bool::True {
826 println!("foobar");
827 }
828}
829"#,
830 )
831 }
832
833 #[test]
834 fn local_variable_assigned_later() {
835 cov_mark::check!(replaces_assignment);
836 check_assist(
837 convert_bool_to_enum,
838 r#"
839fn main() {
840 let $0foo: bool;
841 foo = true;
842}
843"#,
844 r#"
845#[derive(PartialEq, Eq)]
846enum Bool { True, False }
847
848fn main() {
849 let foo: Bool;
850 foo = Bool::True;
851}
852"#,
853 )
854 }
855
856 #[test]
857 fn local_variable_does_not_apply_recursively() {
858 check_assist(
859 convert_bool_to_enum,
860 r#"
861fn main() {
862 let $0foo = true;
863 let bar = !foo;
864
865 if bar {
866 println!("bar");
867 }
868}
869"#,
870 r#"
871#[derive(PartialEq, Eq)]
872enum Bool { True, False }
873
874fn main() {
875 let foo = Bool::True;
876 let bar = foo == Bool::False;
877
878 if bar {
879 println!("bar");
880 }
881}
882"#,
883 )
884 }
885
886 #[test]
887 fn local_variable_nested_in_negation() {
888 cov_mark::check!(dont_overwrite_expression_inside_negation);
889 check_assist(
890 convert_bool_to_enum,
891 r#"
892fn main() {
893 if !"foo".chars().any(|c| {
894 let $0foo = true;
895 foo
896 }) {
897 println!("foo");
898 }
899}
900"#,
901 r#"
902#[derive(PartialEq, Eq)]
903enum Bool { True, False }
904
905fn main() {
906 if !"foo".chars().any(|c| {
907 let foo = Bool::True;
908 foo == Bool::True
909 }) {
910 println!("foo");
911 }
912}
913"#,
914 )
915 }
916
917 #[test]
918 fn local_variable_non_bool() {
919 cov_mark::check!(not_applicable_non_bool_local);
920 check_assist_not_applicable(
921 convert_bool_to_enum,
922 r#"
923fn main() {
924 let $0foo = 1;
925}
926"#,
927 )
928 }
929
930 #[test]
931 fn local_variable_cursor_not_on_ident() {
932 check_assist_not_applicable(
933 convert_bool_to_enum,
934 r#"
935fn main() {
936 let foo = $0true;
937}
938"#,
939 )
940 }
941
942 #[test]
943 fn local_variable_non_ident_pat() {
944 check_assist_not_applicable(
945 convert_bool_to_enum,
946 r#"
947fn main() {
948 let ($0foo, bar) = (true, false);
949}
950"#,
951 )
952 }
953
954 #[test]
955 fn local_var_init_struct_usage() {
956 check_assist(
957 convert_bool_to_enum,
958 r#"
959struct Foo {
960 foo: bool,
961}
962
963fn main() {
964 let $0foo = true;
965 let s = Foo { foo };
966}
967"#,
968 r#"
969struct Foo {
970 foo: bool,
971}
972
973#[derive(PartialEq, Eq)]
974enum Bool { True, False }
975
976fn main() {
977 let foo = Bool::True;
978 let s = Foo { foo: foo == Bool::True };
979}
980"#,
981 )
982 }
983
984 #[test]
985 fn local_var_init_struct_usage_in_macro() {
986 check_assist(
987 convert_bool_to_enum,
988 r#"
989struct Struct {
990 boolean: bool,
991}
992
993macro_rules! identity {
994 ($body:expr) => {
995 $body
996 }
997}
998
999fn new() -> Struct {
1000 let $0boolean = true;
1001 identity![Struct { boolean }]
1002}
1003"#,
1004 r#"
1005struct Struct {
1006 boolean: bool,
1007}
1008
1009macro_rules! identity {
1010 ($body:expr) => {
1011 $body
1012 }
1013}
1014
1015#[derive(PartialEq, Eq)]
1016enum Bool { True, False }
1017
1018fn new() -> Struct {
1019 let boolean = Bool::True;
1020 identity![Struct { boolean: boolean == Bool::True }]
1021}
1022"#,
1023 )
1024 }
1025
1026 #[test]
1027 fn field_struct_basic() {
1028 cov_mark::check!(replaces_record_expr);
1029 check_assist(
1030 convert_bool_to_enum,
1031 r#"
1032struct Foo {
1033 $0bar: bool,
1034 baz: bool,
1035}
1036
1037fn main() {
1038 let foo = Foo { bar: true, baz: false };
1039
1040 if foo.bar {
1041 println!("foo");
1042 }
1043}
1044"#,
1045 r#"
1046#[derive(PartialEq, Eq)]
1047enum Bool { True, False }
1048
1049struct Foo {
1050 bar: Bool,
1051 baz: bool,
1052}
1053
1054fn main() {
1055 let foo = Foo { bar: Bool::True, baz: false };
1056
1057 if foo.bar == Bool::True {
1058 println!("foo");
1059 }
1060}
1061"#,
1062 )
1063 }
1064
1065 #[test]
1066 fn field_enum_basic() {
1067 cov_mark::check!(replaces_record_pat);
1068 check_assist(
1069 convert_bool_to_enum,
1070 r#"
1071enum Foo {
1072 Foo,
1073 Bar { $0bar: bool },
1074}
1075
1076fn main() {
1077 let foo = Foo::Bar { bar: true };
1078
1079 if let Foo::Bar { bar: baz } = foo {
1080 if baz {
1081 println!("foo");
1082 }
1083 }
1084}
1085"#,
1086 r#"
1087#[derive(PartialEq, Eq)]
1088enum Bool { True, False }
1089
1090enum Foo {
1091 Foo,
1092 Bar { bar: Bool },
1093}
1094
1095fn main() {
1096 let foo = Foo::Bar { bar: Bool::True };
1097
1098 if let Foo::Bar { bar: baz } = foo {
1099 if baz == Bool::True {
1100 println!("foo");
1101 }
1102 }
1103}
1104"#,
1105 )
1106 }
1107
1108 #[test]
1109 fn field_enum_cross_file() {
1110 check_assist(
1112 convert_bool_to_enum,
1113 r#"
1114//- /foo.rs
1115pub enum Foo {
1116 Foo,
1117 Bar { $0bar: bool },
1118}
1119
1120fn foo() {
1121 let foo = Foo::Bar { bar: true };
1122}
1123
1124//- /main.rs
1125use foo::Foo;
1126
1127mod foo;
1128
1129fn main() {
1130 let foo = Foo::Bar { bar: false };
1131}
1132"#,
1133 r#"
1134//- /foo.rs
1135#[derive(PartialEq, Eq)]
1136pub enum Bool { True, False }
1137
1138pub enum Foo {
1139 Foo,
1140 Bar { bar: Bool },
1141}
1142
1143fn foo() {
1144 let foo = Foo::Bar { bar: Bool::True };
1145}
1146
1147//- /main.rs
1148use foo::{Bool, Foo};
1149
1150mod foo;
1151
1152fn main() {
1153 let foo = Foo::Bar { bar: Bool::False };
1154}
1155"#,
1156 )
1157 }
1158
1159 #[test]
1160 fn field_enum_shorthand() {
1161 cov_mark::check!(replaces_record_pat_shorthand);
1162 check_assist(
1163 convert_bool_to_enum,
1164 r#"
1165enum Foo {
1166 Foo,
1167 Bar { $0bar: bool },
1168}
1169
1170fn main() {
1171 let foo = Foo::Bar { bar: true };
1172
1173 match foo {
1174 Foo::Bar { bar } => {
1175 if bar {
1176 println!("foo");
1177 }
1178 }
1179 _ => (),
1180 }
1181}
1182"#,
1183 r#"
1184#[derive(PartialEq, Eq)]
1185enum Bool { True, False }
1186
1187enum Foo {
1188 Foo,
1189 Bar { bar: Bool },
1190}
1191
1192fn main() {
1193 let foo = Foo::Bar { bar: Bool::True };
1194
1195 match foo {
1196 Foo::Bar { bar } => {
1197 if bar == Bool::True {
1198 println!("foo");
1199 }
1200 }
1201 _ => (),
1202 }
1203}
1204"#,
1205 )
1206 }
1207
1208 #[test]
1209 fn field_enum_replaces_literal_patterns() {
1210 cov_mark::check!(replaces_literal_pat);
1211 check_assist(
1212 convert_bool_to_enum,
1213 r#"
1214enum Foo {
1215 Foo,
1216 Bar { $0bar: bool },
1217}
1218
1219fn main() {
1220 let foo = Foo::Bar { bar: true };
1221
1222 if let Foo::Bar { bar: true } = foo {
1223 println!("foo");
1224 }
1225}
1226"#,
1227 r#"
1228#[derive(PartialEq, Eq)]
1229enum Bool { True, False }
1230
1231enum Foo {
1232 Foo,
1233 Bar { bar: Bool },
1234}
1235
1236fn main() {
1237 let foo = Foo::Bar { bar: Bool::True };
1238
1239 if let Foo::Bar { bar: Bool::True } = foo {
1240 println!("foo");
1241 }
1242}
1243"#,
1244 )
1245 }
1246
1247 #[test]
1248 fn field_enum_keeps_wildcard_patterns() {
1249 check_assist(
1250 convert_bool_to_enum,
1251 r#"
1252enum Foo {
1253 Foo,
1254 Bar { $0bar: bool },
1255}
1256
1257fn main() {
1258 let foo = Foo::Bar { bar: true };
1259
1260 if let Foo::Bar { bar: _ } = foo {
1261 println!("foo");
1262 }
1263}
1264"#,
1265 r#"
1266#[derive(PartialEq, Eq)]
1267enum Bool { True, False }
1268
1269enum Foo {
1270 Foo,
1271 Bar { bar: Bool },
1272}
1273
1274fn main() {
1275 let foo = Foo::Bar { bar: Bool::True };
1276
1277 if let Foo::Bar { bar: _ } = foo {
1278 println!("foo");
1279 }
1280}
1281"#,
1282 )
1283 }
1284
1285 #[test]
1286 fn field_union_basic() {
1287 check_assist(
1288 convert_bool_to_enum,
1289 r#"
1290union Foo {
1291 $0foo: bool,
1292 bar: usize,
1293}
1294
1295fn main() {
1296 let foo = Foo { foo: true };
1297
1298 if unsafe { foo.foo } {
1299 println!("foo");
1300 }
1301}
1302"#,
1303 r#"
1304#[derive(PartialEq, Eq)]
1305enum Bool { True, False }
1306
1307union Foo {
1308 foo: Bool,
1309 bar: usize,
1310}
1311
1312fn main() {
1313 let foo = Foo { foo: Bool::True };
1314
1315 if unsafe { foo.foo == Bool::True } {
1316 println!("foo");
1317 }
1318}
1319"#,
1320 )
1321 }
1322
1323 #[test]
1324 fn field_negated() {
1325 check_assist(
1326 convert_bool_to_enum,
1327 r#"
1328struct Foo {
1329 $0bar: bool,
1330}
1331
1332fn main() {
1333 let foo = Foo { bar: false };
1334
1335 if !foo.bar {
1336 println!("foo");
1337 }
1338}
1339"#,
1340 r#"
1341#[derive(PartialEq, Eq)]
1342enum Bool { True, False }
1343
1344struct Foo {
1345 bar: Bool,
1346}
1347
1348fn main() {
1349 let foo = Foo { bar: Bool::False };
1350
1351 if foo.bar == Bool::False {
1352 println!("foo");
1353 }
1354}
1355"#,
1356 )
1357 }
1358
1359 #[test]
1360 fn field_in_mod_properly_indented() {
1361 check_assist(
1362 convert_bool_to_enum,
1363 r#"
1364mod foo {
1365 struct Bar {
1366 $0baz: bool,
1367 }
1368
1369 impl Bar {
1370 fn new(baz: bool) -> Self {
1371 Self { baz }
1372 }
1373 }
1374}
1375"#,
1376 r#"
1377mod foo {
1378 #[derive(PartialEq, Eq)]
1379 enum Bool { True, False }
1380
1381 struct Bar {
1382 baz: Bool,
1383 }
1384
1385 impl Bar {
1386 fn new(baz: bool) -> Self {
1387 Self { baz: if baz { Bool::True } else { Bool::False } }
1388 }
1389 }
1390}
1391"#,
1392 )
1393 }
1394
1395 #[test]
1396 fn field_multiple_initializations() {
1397 check_assist(
1398 convert_bool_to_enum,
1399 r#"
1400struct Foo {
1401 $0bar: bool,
1402 baz: bool,
1403}
1404
1405fn main() {
1406 let foo1 = Foo { bar: true, baz: false };
1407 let foo2 = Foo { bar: false, baz: false };
1408
1409 if foo1.bar && foo2.bar {
1410 println!("foo");
1411 }
1412}
1413"#,
1414 r#"
1415#[derive(PartialEq, Eq)]
1416enum Bool { True, False }
1417
1418struct Foo {
1419 bar: Bool,
1420 baz: bool,
1421}
1422
1423fn main() {
1424 let foo1 = Foo { bar: Bool::True, baz: false };
1425 let foo2 = Foo { bar: Bool::False, baz: false };
1426
1427 if foo1.bar == Bool::True && foo2.bar == Bool::True {
1428 println!("foo");
1429 }
1430}
1431"#,
1432 )
1433 }
1434
1435 #[test]
1436 fn field_assigned_to_another() {
1437 cov_mark::check!(dont_assign_incorrect_ref);
1438 check_assist(
1439 convert_bool_to_enum,
1440 r#"
1441struct Foo {
1442 $0foo: bool,
1443}
1444
1445struct Bar {
1446 bar: bool,
1447}
1448
1449fn main() {
1450 let foo = Foo { foo: true };
1451 let mut bar = Bar { bar: true };
1452
1453 bar.bar = foo.foo;
1454}
1455"#,
1456 r#"
1457#[derive(PartialEq, Eq)]
1458enum Bool { True, False }
1459
1460struct Foo {
1461 foo: Bool,
1462}
1463
1464struct Bar {
1465 bar: bool,
1466}
1467
1468fn main() {
1469 let foo = Foo { foo: Bool::True };
1470 let mut bar = Bar { bar: true };
1471
1472 bar.bar = foo.foo == Bool::True;
1473}
1474"#,
1475 )
1476 }
1477
1478 #[test]
1479 fn field_initialized_with_other() {
1480 check_assist(
1481 convert_bool_to_enum,
1482 r#"
1483struct Foo {
1484 $0foo: bool,
1485}
1486
1487struct Bar {
1488 bar: bool,
1489}
1490
1491fn main() {
1492 let foo = Foo { foo: true };
1493 let bar = Bar { bar: foo.foo };
1494}
1495"#,
1496 r#"
1497#[derive(PartialEq, Eq)]
1498enum Bool { True, False }
1499
1500struct Foo {
1501 foo: Bool,
1502}
1503
1504struct Bar {
1505 bar: bool,
1506}
1507
1508fn main() {
1509 let foo = Foo { foo: Bool::True };
1510 let bar = Bar { bar: foo.foo == Bool::True };
1511}
1512"#,
1513 )
1514 }
1515
1516 #[test]
1517 fn field_method_chain_usage() {
1518 check_assist(
1519 convert_bool_to_enum,
1520 r#"
1521struct Foo {
1522 $0bool: bool,
1523}
1524
1525fn main() {
1526 let foo = Foo { bool: true };
1527
1528 foo.bool.then(|| 2);
1529}
1530"#,
1531 r#"
1532#[derive(PartialEq, Eq)]
1533enum Bool { True, False }
1534
1535struct Foo {
1536 bool: Bool,
1537}
1538
1539fn main() {
1540 let foo = Foo { bool: Bool::True };
1541
1542 (foo.bool == Bool::True).then(|| 2);
1543}
1544"#,
1545 )
1546 }
1547
// The `$0`-marked field is initialized inside a macro call
// (`identity!(Struct { boolean: true })`) and also bound by a shorthand record
// pattern. The expected output rewrites the literal inside the macro argument
// to `Bool::True` and leaves the `let Struct { boolean } = x;` pattern as is.
1549 #[test]
1550 fn field_in_macro() {
1551 check_assist(
1552 convert_bool_to_enum,
1553 r#"
1554struct Struct {
1555 $0boolean: bool,
1556}
1557
1558fn boolean(x: Struct) {
1559 let Struct { boolean } = x;
1560}
1561
1562macro_rules! identity { ($body:expr) => { $body } }
1563
1564fn new() -> Struct {
1565 identity!(Struct { boolean: true })
1566}
1567"#,
1568 r#"
1569#[derive(PartialEq, Eq)]
1570enum Bool { True, False }
1571
1572struct Struct {
1573 boolean: Bool,
1574}
1575
1576fn boolean(x: Struct) {
1577 let Struct { boolean } = x;
1578}
1579
1580macro_rules! identity { ($body:expr) => { $body } }
1581
1582fn new() -> Struct {
1583 identity!(Struct { boolean: Bool::True })
1584}
1585"#,
1586 )
1587 }
1587
// The assist must not be offered when the `$0`-marked record field is not a
// bool (here `usize`).
// NOTE(review): `cov_mark::check!` presumably pairs with a matching
// `cov_mark::hit!(not_applicable_non_bool_field)` in the assist's bail-out
// path, which is outside this chunk — confirm.
1589 #[test]
1590 fn field_non_bool() {
1591 cov_mark::check!(not_applicable_non_bool_field);
1592 check_assist_not_applicable(
1593 convert_bool_to_enum,
1594 r#"
1595struct Foo {
1596 $0bar: usize,
1597}
1598
1599fn main() {
1600 let foo = Foo { bar: 1 };
1601}
1602"#,
1603 )
1604 }
1604
// Simplest `const` case: the `$0`-marked bool const is used as an `if`
// condition in the same file. The expected output changes the type annotation
// to `Bool`, the initializer to `Bool::False`, and the condition to
// `FOO == Bool::True`.
1606 #[test]
1607 fn const_basic() {
1608 check_assist(
1609 convert_bool_to_enum,
1610 r#"
1611const $0FOO: bool = false;
1612
1613fn main() {
1614 if FOO {
1615 println!("foo");
1616 }
1617}
1618"#,
1619 r#"
1620#[derive(PartialEq, Eq)]
1621enum Bool { True, False }
1622
1623const FOO: Bool = Bool::False;
1624
1625fn main() {
1626 if FOO == Bool::True {
1627 println!("foo");
1628 }
1629}
1630"#,
1631 )
1632 }
1632
// The `$0`-marked const lives in `mod foo` and is used from the parent module
// through the path `foo::FOO`. The expected output generates the enum as
// `pub enum Bool` inside `mod foo`, adds `use foo::Bool;` at the top of the
// file, and compares the usage against `Bool::True`.
1634 #[test]
1635 fn const_in_module() {
1636 check_assist(
1637 convert_bool_to_enum,
1638 r#"
1639fn main() {
1640 if foo::FOO {
1641 println!("foo");
1642 }
1643}
1644
1645mod foo {
1646 pub const $0FOO: bool = true;
1647}
1648"#,
1649 r#"
1650use foo::Bool;
1651
1652fn main() {
1653 if foo::FOO == Bool::True {
1654 println!("foo");
1655 }
1656}
1657
1658mod foo {
1659 #[derive(PartialEq, Eq)]
1660 pub enum Bool { True, False }
1661
1662 pub const FOO: Bool = Bool::True;
1663}
1664"#,
1665 )
1666 }
1666
// Same setup as a const inside `mod foo`, but the usage site reaches it via a
// function-local `use foo::FOO;`. The expected output keeps that local import
// untouched and adds `use foo::Bool;` at the top of the file (not inside
// `main`).
1668 #[test]
1669 fn const_in_module_with_import() {
1670 check_assist(
1671 convert_bool_to_enum,
1672 r#"
1673fn main() {
1674 use foo::FOO;
1675
1676 if FOO {
1677 println!("foo");
1678 }
1679}
1680
1681mod foo {
1682 pub const $0FOO: bool = true;
1683}
1684"#,
1685 r#"
1686use foo::Bool;
1687
1688fn main() {
1689 use foo::FOO;
1690
1691 if FOO == Bool::True {
1692 println!("foo");
1693 }
1694}
1695
1696mod foo {
1697 #[derive(PartialEq, Eq)]
1698 pub enum Bool { True, False }
1699
1700 pub const FOO: Bool = Bool::True;
1701}
1702"#,
1703 )
1704 }
1704
// Two-file fixture (sections introduced by `//- /main.rs` and `//- /foo.rs`):
// the `$0`-marked const is declared in `foo.rs` and used from `main.rs`.
// The expected output puts `pub enum Bool` in `foo.rs`, and in `main.rs` adds
// `use foo::Bool;` above `mod foo;` and compares the usage to `Bool::True`.
1706 #[test]
1707 fn const_cross_file() {
1708 check_assist(
1709 convert_bool_to_enum,
1710 r#"
1711//- /main.rs
1712mod foo;
1713
1714fn main() {
1715 if foo::FOO {
1716 println!("foo");
1717 }
1718}
1719
1720//- /foo.rs
1721pub const $0FOO: bool = true;
1722"#,
1723 r#"
1724//- /main.rs
1725use foo::Bool;
1726
1727mod foo;
1728
1729fn main() {
1730 if foo::FOO == Bool::True {
1731 println!("foo");
1732 }
1733}
1734
1735//- /foo.rs
1736#[derive(PartialEq, Eq)]
1737pub enum Bool { True, False }
1738
1739pub const FOO: Bool = Bool::True;
1740"#,
1741 )
1742 }
1742
// Two-file fixture where the `$0`-marked const sits in a nested inline module
// (`foo.rs` -> `pub mod bar`) and is used in `main.rs` through a locally
// imported module (`use foo::bar;` then `bar::BAR`). The expected output
// generates the enum inside `bar` and imports it in `main.rs` by its full
// path, `use foo::bar::Bool;`.
1744 #[test]
1745 fn const_cross_file_and_module() {
1746 check_assist(
1747 convert_bool_to_enum,
1748 r#"
1749//- /main.rs
1750mod foo;
1751
1752fn main() {
1753 use foo::bar;
1754
1755 if bar::BAR {
1756 println!("foo");
1757 }
1758}
1759
1760//- /foo.rs
1761pub mod bar {
1762 pub const $0BAR: bool = false;
1763}
1764"#,
1765 r#"
1766//- /main.rs
1767use foo::bar::Bool;
1768
1769mod foo;
1770
1771fn main() {
1772 use foo::bar;
1773
1774 if bar::BAR == Bool::True {
1775 println!("foo");
1776 }
1777}
1778
1779//- /foo.rs
1780pub mod bar {
1781 #[derive(PartialEq, Eq)]
1782 pub enum Bool { True, False }
1783
1784 pub const BAR: Bool = Bool::False;
1785}
1786"#,
1787 )
1788 }
1788
// The `$0`-marked item is an associated const in an inherent `impl Foo` block
// in `main.rs`, used from `foo.rs` as `Foo::BOOL`. The expected output inserts
// `pub enum Bool` before the `impl` block (not inside it) and, in `foo.rs`,
// extends the existing import to `use crate::{Bool, Foo};`. The surrounding
// function keeps returning `bool` via `Foo::BOOL == Bool::True`.
1790 #[test]
1791 fn const_in_impl_cross_file() {
1792 check_assist(
1793 convert_bool_to_enum,
1794 r#"
1795//- /main.rs
1796mod foo;
1797
1798struct Foo;
1799
1800impl Foo {
1801 pub const $0BOOL: bool = true;
1802}
1803
1804//- /foo.rs
1805use crate::Foo;
1806
1807fn foo() -> bool {
1808 Foo::BOOL
1809}
1810"#,
1811 r#"
1812//- /main.rs
1813mod foo;
1814
1815struct Foo;
1816
1817#[derive(PartialEq, Eq)]
1818pub enum Bool { True, False }
1819
1820impl Foo {
1821 pub const BOOL: Bool = Bool::True;
1822}
1823
1824//- /foo.rs
1825use crate::{Bool, Foo};
1826
1827fn foo() -> bool {
1828 Foo::BOOL == Bool::True
1829}
1830"#,
1831 )
1832 }
1832
// The `$0`-marked item is a const declared in a trait, without an initializer.
// The expected output changes the declaration's type to `Bool`, updates the
// implementing const in `impl Foo for usize` (both type and initializer), and
// compares the qualified usage `<usize as Foo>::BOOL` against `Bool::True`.
1834 #[test]
1835 fn const_in_trait() {
1836 check_assist(
1837 convert_bool_to_enum,
1838 r#"
1839trait Foo {
1840 const $0BOOL: bool;
1841}
1842
1843impl Foo for usize {
1844 const BOOL: bool = true;
1845}
1846
1847fn main() {
1848 if <usize as Foo>::BOOL {
1849 println!("foo");
1850 }
1851}
1852"#,
1853 r#"
1854#[derive(PartialEq, Eq)]
1855enum Bool { True, False }
1856
1857trait Foo {
1858 const BOOL: Bool;
1859}
1860
1861impl Foo for usize {
1862 const BOOL: Bool = Bool::True;
1863}
1864
1865fn main() {
1866 if <usize as Foo>::BOOL == Bool::True {
1867 println!("foo");
1868 }
1869}
1870"#,
1871 )
1872 }
1872
// The assist must not be offered when the `$0`-marked const is not a bool
// (here `&str`).
// NOTE(review): `cov_mark::check!` presumably pairs with a matching
// `cov_mark::hit!(not_applicable_non_bool_const)` in the assist's bail-out
// path, which is outside this chunk — confirm.
1874 #[test]
1875 fn const_non_bool() {
1876 cov_mark::check!(not_applicable_non_bool_const);
1877 check_assist_not_applicable(
1878 convert_bool_to_enum,
1879 r#"
1880const $0FOO: &str = "foo";
1881
1882fn main() {
1883 println!("{FOO}");
1884}
1885"#,
1886 )
1887 }
1887
// The `$0`-marked item is a `static mut` bool that is both written and read.
// The expected output turns the assignment `BOOL = false` into
// `BOOL = Bool::False`, and places the read comparison inside the `unsafe`
// block: `unsafe { BOOL == Bool::True }`.
1889 #[test]
1890 fn static_basic() {
1891 check_assist(
1892 convert_bool_to_enum,
1893 r#"
1894static mut $0BOOL: bool = true;
1895
1896fn main() {
1897 unsafe { BOOL = false };
1898 if unsafe { BOOL } {
1899 println!("foo");
1900 }
1901}
1902"#,
1903 r#"
1904#[derive(PartialEq, Eq)]
1905enum Bool { True, False }
1906
1907static mut BOOL: Bool = Bool::True;
1908
1909fn main() {
1910 unsafe { BOOL = Bool::False };
1911 if unsafe { BOOL == Bool::True } {
1912 println!("foo");
1913 }
1914}
1915"#,
1916 )
1917 }
1917
// The assist must not be offered when the `$0`-marked static is not a bool
// (here `usize`).
// NOTE(review): `cov_mark::check!` presumably pairs with a matching
// `cov_mark::hit!(not_applicable_non_bool_static)` in the assist's bail-out
// path, which is outside this chunk — confirm.
1919 #[test]
1920 fn static_non_bool() {
1921 cov_mark::check!(not_applicable_non_bool_static);
1922 check_assist_not_applicable(
1923 convert_bool_to_enum,
1924 r#"
1925static mut $0FOO: usize = 0;
1926
1927fn main() {
1928 if unsafe { FOO } == 0 {
1929 println!("foo");
1930 }
1931}
1932"#,
1933 )
1934 }
1934
// The assist must not be offered when `$0` sits on a name that is not a
// local, field, const or static — here the name of a function.
1936 #[test]
1937 fn not_applicable_to_other_names() {
1938 check_assist_not_applicable(convert_bool_to_enum, "fn $0main() {}")
1939 }
1939}