1use either::Either;
2use hir::ModuleDef;
3use ide_db::text_edit::TextRange;
4use ide_db::{
5 FxHashSet,
6 assists::AssistId,
7 defs::Definition,
8 helpers::mod_path_to_ast,
9 imports::insert_use::{ImportScope, insert_use},
10 search::{FileReference, UsageSearchResult},
11 source_change::SourceChangeBuilder,
12};
13use itertools::Itertools;
14use syntax::{
15 AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T,
16 ast::{self, HasName, edit::IndentLevel, edit_in_place::Indent, make},
17};
18
19use crate::{
20 assist_context::{AssistContext, Assists},
21 utils,
22};
23
24pub(crate) fn convert_bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
54 let BoolNodeData { target_node, name, ty_annotation, initializer, definition } =
55 find_bool_node(ctx)?;
56 let target_module = ctx.sema.scope(&target_node)?.module().nearest_non_block_module(ctx.db());
57
58 let target = name.syntax().text_range();
59 acc.add(
60 AssistId::refactor_rewrite("convert_bool_to_enum"),
61 "Convert boolean to enum",
62 target,
63 |edit| {
64 if let Some(ty) = &ty_annotation {
65 cov_mark::hit!(replaces_ty_annotation);
66 edit.replace(ty.syntax().text_range(), "Bool");
67 }
68
69 if let Some(initializer) = initializer {
70 replace_bool_expr(edit, initializer);
71 }
72
73 let usages = definition.usages(&ctx.sema).all();
74 add_enum_def(edit, ctx, &usages, target_node, &target_module);
75 let mut delayed_mutations = Vec::new();
76 replace_usages(edit, ctx, usages, definition, &target_module, &mut delayed_mutations);
77 for (scope, path) in delayed_mutations {
78 insert_use(&scope, path, &ctx.config.insert_use);
79 }
80 },
81 )
82}
83
84struct BoolNodeData {
85 target_node: SyntaxNode,
86 name: ast::Name,
87 ty_annotation: Option<ast::Type>,
88 initializer: Option<ast::Expr>,
89 definition: Definition,
90}
91
92fn find_bool_node(ctx: &AssistContext<'_>) -> Option<BoolNodeData> {
94 let name = ctx.find_node_at_offset::<ast::Name>()?;
95
96 if let Some(ident_pat) = name.syntax().parent().and_then(ast::IdentPat::cast) {
97 let def = ctx.sema.to_def(&ident_pat)?;
98 if !def.ty(ctx.db()).is_bool() {
99 cov_mark::hit!(not_applicable_non_bool_local);
100 return None;
101 }
102
103 let local_definition = Definition::Local(def);
104 match ident_pat.syntax().parent().and_then(Either::<ast::Param, ast::LetStmt>::cast)? {
105 Either::Left(param) => Some(BoolNodeData {
106 target_node: param.syntax().clone(),
107 name,
108 ty_annotation: param.ty(),
109 initializer: None,
110 definition: local_definition,
111 }),
112 Either::Right(let_stmt) => Some(BoolNodeData {
113 target_node: let_stmt.syntax().clone(),
114 name,
115 ty_annotation: let_stmt.ty(),
116 initializer: let_stmt.initializer(),
117 definition: local_definition,
118 }),
119 }
120 } else if let Some(const_) = name.syntax().parent().and_then(ast::Const::cast) {
121 let def = ctx.sema.to_def(&const_)?;
122 if !def.ty(ctx.db()).is_bool() {
123 cov_mark::hit!(not_applicable_non_bool_const);
124 return None;
125 }
126
127 Some(BoolNodeData {
128 target_node: const_.syntax().clone(),
129 name,
130 ty_annotation: const_.ty(),
131 initializer: const_.body(),
132 definition: Definition::Const(def),
133 })
134 } else if let Some(static_) = name.syntax().parent().and_then(ast::Static::cast) {
135 let def = ctx.sema.to_def(&static_)?;
136 if !def.ty(ctx.db()).is_bool() {
137 cov_mark::hit!(not_applicable_non_bool_static);
138 return None;
139 }
140
141 Some(BoolNodeData {
142 target_node: static_.syntax().clone(),
143 name,
144 ty_annotation: static_.ty(),
145 initializer: static_.body(),
146 definition: Definition::Static(def),
147 })
148 } else {
149 let field = name.syntax().parent().and_then(ast::RecordField::cast)?;
150 if field.name()? != name {
151 return None;
152 }
153
154 let adt = field.syntax().ancestors().find_map(ast::Adt::cast)?;
155 let def = ctx.sema.to_def(&field)?;
156 if !def.ty(ctx.db()).is_bool() {
157 cov_mark::hit!(not_applicable_non_bool_field);
158 return None;
159 }
160 Some(BoolNodeData {
161 target_node: adt.syntax().clone(),
162 name,
163 ty_annotation: field.ty(),
164 initializer: None,
165 definition: Definition::Field(def),
166 })
167 }
168}
169
170fn replace_bool_expr(edit: &mut SourceChangeBuilder, expr: ast::Expr) {
171 let expr_range = expr.syntax().text_range();
172 let enum_expr = bool_expr_to_enum_expr(expr);
173 edit.replace(expr_range, enum_expr.syntax().text())
174}
175
176fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
178 let true_expr = make::expr_path(make::path_from_text("Bool::True"));
179 let false_expr = make::expr_path(make::path_from_text("Bool::False"));
180
181 if let ast::Expr::Literal(literal) = &expr {
182 match literal.kind() {
183 ast::LiteralKind::Bool(true) => true_expr,
184 ast::LiteralKind::Bool(false) => false_expr,
185 _ => expr,
186 }
187 } else {
188 make::expr_if(
189 expr,
190 make::tail_only_block_expr(true_expr),
191 Some(ast::ElseBranch::Block(make::tail_only_block_expr(false_expr))),
192 )
193 .into()
194 }
195}
196
/// Rewrites every reference in `usages` so that it operates on the `Bool` enum
/// instead of `bool`.
///
/// References are handled file by file and, within a file, in reverse order. Each
/// reference is matched against the cases below in order, and only the first
/// matching case is applied:
/// - the reference has an `IdentPat` ancestor (e.g. record-pattern shorthand
///   `Foo { bar }`): recurse into the usages of that local binding;
/// - it is the target of a plain assignment: convert the assigned right-hand side;
/// - it is the operand of `!`: replace the negation with `<operand> == Bool::False`;
/// - it names a record-expression field that resolves to the target field: convert the
///   field's initializer;
/// - it is a record-pattern field: recurse for ident patterns, convert literal patterns,
///   leave wildcards alone;
/// - it names an associated `const`: replace its type with `Bool` and convert its body;
/// - it is a method-call receiver: replace with `(<receiver> == Bool::True)`;
/// - otherwise, unless inside a `use` tree: compare against `Bool::True` (expanding a
///   record-expression shorthand `S { foo }` to `S { foo: foo == Bool::True }`).
///
/// Imports that a reference needs are not inserted here. Their (mutable) scope and path
/// are pushed onto `delayed_mutations` for the caller to apply afterwards.
fn replace_usages(
    edit: &mut SourceChangeBuilder,
    ctx: &AssistContext<'_>,
    usages: UsageSearchResult,
    target_definition: Definition,
    target_module: &hir::Module,
    delayed_mutations: &mut Vec<(ImportScope, ast::Path)>,
) {
    for (file_id, references) in usages {
        // Subsequent edits are recorded against this file.
        edit.edit_file(file_id.file_id(ctx.db()));

        let refs_with_imports = augment_references_with_imports(ctx, references, target_module);

        refs_with_imports.into_iter().rev().for_each(
            |FileReferenceWithImport { range, name, import_data }| {
                if let Some(ident_pat) = name.syntax().ancestors().find_map(ast::IdentPat::cast) {
                    // The value is re-bound to a new local (e.g. `Foo::Bar { bar }`):
                    // rewrite that local's usages instead of the pattern itself.
                    cov_mark::hit!(replaces_record_pat_shorthand);

                    let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local);
                    if let Some(def) = definition {
                        replace_usages(
                            edit,
                            ctx,
                            def.usages(&ctx.sema).all(),
                            target_definition,
                            target_module,
                            delayed_mutations,
                        )
                    }
                } else if let Some(initializer) = find_assignment_usage(&name) {
                    // `foo = <expr>`: only the assigned value needs converting.
                    cov_mark::hit!(replaces_assignment);

                    replace_bool_expr(edit, initializer);
                } else if let Some((prefix_expr, inner_expr)) = find_negated_usage(&name) {
                    // `!foo` becomes `foo == Bool::False` rather than `!(foo == Bool::True)`.
                    cov_mark::hit!(replaces_negation);

                    edit.replace(
                        prefix_expr.syntax().text_range(),
                        format!("{inner_expr} == Bool::False"),
                    );
                } else if let Some((record_field, initializer)) = name
                    .as_name_ref()
                    .and_then(ast::RecordExprField::for_field_name)
                    .and_then(|record_field| ctx.sema.resolve_record_field(&record_field))
                    .and_then(|(got_field, _, _)| {
                        find_record_expr_usage(&name, got_field, target_definition)
                    })
                {
                    // `S { field: <expr> }` where `field` is the converted field itself.
                    cov_mark::hit!(replaces_record_expr);

                    let enum_expr = bool_expr_to_enum_expr(initializer);
                    utils::replace_record_field_expr(ctx, edit, record_field, enum_expr);
                } else if let Some(pat) = find_record_pat_field_usage(&name) {
                    match pat {
                        ast::Pat::IdentPat(ident_pat) => {
                            // `S { field: binding }`: rewrite usages of `binding`.
                            cov_mark::hit!(replaces_record_pat);

                            let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local);
                            if let Some(def) = definition {
                                replace_usages(
                                    edit,
                                    ctx,
                                    def.usages(&ctx.sema).all(),
                                    target_definition,
                                    target_module,
                                    delayed_mutations,
                                )
                            }
                        }
                        ast::Pat::LiteralPat(literal_pat) => {
                            // `S { field: true }`: `ancestors()` starts at the literal node
                            // itself, so the literal is what gets converted.
                            cov_mark::hit!(replaces_literal_pat);

                            if let Some(expr) = literal_pat.literal().and_then(|literal| {
                                literal.syntax().ancestors().find_map(ast::Expr::cast)
                            }) {
                                replace_bool_expr(edit, expr);
                            }
                        }
                        // Wildcards (`field: _`) need no change.
                        _ => (),
                    }
                } else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&name) {
                    edit.replace(ty_annotation.syntax().text_range(), "Bool");
                    replace_bool_expr(edit, initializer);
                } else if let Some(receiver) = find_method_call_expr_usage(&name) {
                    // Parenthesized so the method still applies to a `bool`.
                    edit.replace(
                        receiver.syntax().text_range(),
                        format!("({receiver} == Bool::True)"),
                    );
                } else if name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
                    // Generic value usage (paths in `use` items are left untouched).
                    if let Some((record_field, expr)) =
                        name.as_name_ref().and_then(ast::RecordExprField::for_field_name).and_then(
                            |record_field| record_field.expr().map(|expr| (record_field, expr)),
                        )
                    {
                        // The reference is a record-expression field name, e.g. shorthand
                        // `S { foo }`; rewrite the field's expression to `<expr> == Bool::True`.
                        utils::replace_record_field_expr(
                            ctx,
                            edit,
                            record_field,
                            make::expr_bin_op(
                                expr,
                                ast::BinaryOp::CmpOp(ast::CmpOp::Eq { negated: false }),
                                make::expr_path(make::path_from_text("Bool::True")),
                            ),
                        );
                    } else {
                        edit.replace(range, format!("{} == Bool::True", name.text()));
                    }
                }

                // Defer the import; the scope is converted to its mutable form now and
                // the actual insertion is done by the caller.
                if let Some((scope, path)) = import_data {
                    let scope = edit.make_import_scope_mut(scope);
                    delayed_mutations.push((scope, path));
                }
            },
        )
    }
}
318
319struct FileReferenceWithImport {
320 range: TextRange,
321 name: ast::NameLike,
322 import_data: Option<(ImportScope, ast::Path)>,
323}
324
325fn augment_references_with_imports(
326 ctx: &AssistContext<'_>,
327 references: Vec<FileReference>,
328 target_module: &hir::Module,
329) -> Vec<FileReferenceWithImport> {
330 let mut visited_modules = FxHashSet::default();
331
332 let edition = target_module.krate(ctx.db()).edition(ctx.db());
333 references
334 .into_iter()
335 .filter_map(|FileReference { range, name, .. }| {
336 let name = name.into_name_like()?;
337 ctx.sema.scope(name.syntax()).map(|scope| (range, name, scope.module()))
338 })
339 .map(|(range, name, ref_module)| {
340 let import_data = if ref_module.nearest_non_block_module(ctx.db()) != *target_module
342 && !visited_modules.contains(&ref_module)
343 {
344 visited_modules.insert(ref_module);
345
346 ImportScope::find_insert_use_container(name.syntax(), &ctx.sema).and_then(
347 |import_scope| {
348 let cfg = ctx.config.find_path_config(
349 ctx.sema.is_nightly(target_module.krate(ctx.sema.db)),
350 );
351 let path = ref_module
352 .find_use_path(
353 ctx.sema.db,
354 ModuleDef::Module(*target_module),
355 ctx.config.insert_use.prefix_kind,
356 cfg,
357 )
358 .map(|mod_path| {
359 make::path_concat(
360 mod_path_to_ast(&mod_path, edition),
361 make::path_from_text("Bool"),
362 )
363 })?;
364
365 Some((import_scope, path))
366 },
367 )
368 } else {
369 None
370 };
371
372 FileReferenceWithImport { range, name, import_data }
373 })
374 .collect()
375}
376
377fn find_assignment_usage(name: &ast::NameLike) -> Option<ast::Expr> {
378 let bin_expr = name.syntax().ancestors().find_map(ast::BinExpr::cast)?;
379
380 if !bin_expr.lhs()?.syntax().descendants().contains(name.syntax()) {
381 cov_mark::hit!(dont_assign_incorrect_ref);
382 return None;
383 }
384
385 if let Some(ast::BinaryOp::Assignment { op: None }) = bin_expr.op_kind() {
386 bin_expr.rhs()
387 } else {
388 None
389 }
390}
391
392fn find_negated_usage(name: &ast::NameLike) -> Option<(ast::PrefixExpr, ast::Expr)> {
393 let prefix_expr = name.syntax().ancestors().find_map(ast::PrefixExpr::cast)?;
394
395 if !matches!(prefix_expr.expr()?, ast::Expr::PathExpr(_) | ast::Expr::FieldExpr(_)) {
396 cov_mark::hit!(dont_overwrite_expression_inside_negation);
397 return None;
398 }
399
400 if let Some(ast::UnaryOp::Not) = prefix_expr.op_kind() {
401 let inner_expr = prefix_expr.expr()?;
402 Some((prefix_expr, inner_expr))
403 } else {
404 None
405 }
406}
407
408fn find_record_expr_usage(
409 name: &ast::NameLike,
410 got_field: hir::Field,
411 target_definition: Definition,
412) -> Option<(ast::RecordExprField, ast::Expr)> {
413 let name_ref = name.as_name_ref()?;
414 let record_field = ast::RecordExprField::for_field_name(name_ref)?;
415 let initializer = record_field.expr()?;
416
417 match target_definition {
418 Definition::Field(expected_field) if got_field == expected_field => {
419 Some((record_field, initializer))
420 }
421 _ => None,
422 }
423}
424
425fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
426 let record_pat_field = name.syntax().parent().and_then(ast::RecordPatField::cast)?;
427 let pat = record_pat_field.pat()?;
428
429 match pat {
430 ast::Pat::IdentPat(_) | ast::Pat::LiteralPat(_) | ast::Pat::WildcardPat(_) => Some(pat),
431 _ => None,
432 }
433}
434
435fn find_assoc_const_usage(name: &ast::NameLike) -> Option<(ast::Type, ast::Expr)> {
436 let const_ = name.syntax().parent().and_then(ast::Const::cast)?;
437 const_.syntax().parent().and_then(ast::AssocItemList::cast)?;
438
439 Some((const_.ty()?, const_.body()?))
440}
441
442fn find_method_call_expr_usage(name: &ast::NameLike) -> Option<ast::Expr> {
443 let method_call = name.syntax().ancestors().find_map(ast::MethodCallExpr::cast)?;
444 let receiver = method_call.receiver()?;
445
446 if !receiver.syntax().descendants().contains(name.syntax()) {
447 return None;
448 }
449
450 Some(receiver)
451}
452
453fn add_enum_def(
455 edit: &mut SourceChangeBuilder,
456 ctx: &AssistContext<'_>,
457 usages: &UsageSearchResult,
458 target_node: SyntaxNode,
459 target_module: &hir::Module,
460) -> Option<()> {
461 let insert_before = node_to_insert_before(target_node);
462
463 if ctx
464 .sema
465 .scope(&insert_before)?
466 .module()
467 .scope(ctx.db(), Some(*target_module))
468 .iter()
469 .any(|(name, _)| name.as_str() == "Bool")
470 {
471 return None;
472 }
473
474 let make_enum_pub = usages
475 .iter()
476 .flat_map(|(_, refs)| refs)
477 .filter_map(|FileReference { name, .. }| {
478 let name = name.clone().into_name_like()?;
479 ctx.sema.scope(name.syntax()).map(|scope| scope.module())
480 })
481 .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module);
482 let enum_def = make_bool_enum(make_enum_pub);
483
484 let indent = IndentLevel::from_node(&insert_before);
485 enum_def.reindent_to(indent);
486
487 edit.insert(
488 insert_before.text_range().start(),
489 format!("{}\n\n{indent}", enum_def.syntax().text()),
490 );
491
492 Some(())
493}
494
495fn node_to_insert_before(target_node: SyntaxNode) -> SyntaxNode {
499 target_node
500 .ancestors()
501 .take_while(|it| !matches!(it.kind(), SyntaxKind::MODULE | SyntaxKind::SOURCE_FILE))
502 .filter(|it| ast::Item::can_cast(it.kind()))
503 .last()
504 .unwrap_or(target_node)
505}
506
507fn make_bool_enum(make_pub: bool) -> ast::Enum {
508 let derive_eq = make::attr_outer(make::meta_token_tree(
509 make::ext::ident_path("derive"),
510 make::token_tree(
511 T!['('],
512 vec![
513 NodeOrToken::Token(make::tokens::ident("PartialEq")),
514 NodeOrToken::Token(make::token(T![,])),
515 NodeOrToken::Token(make::tokens::single_space()),
516 NodeOrToken::Token(make::tokens::ident("Eq")),
517 ],
518 ),
519 ));
520 make::enum_(
521 [derive_eq],
522 if make_pub { Some(make::visibility_pub()) } else { None },
523 make::name("Bool"),
524 None,
525 None,
526 make::variant_list(vec![
527 make::variant(None, make::name("True"), None, None),
528 make::variant(None, make::name("False"), None, None),
529 ]),
530 )
531 .clone_for_update()
532}
533
534#[cfg(test)]
535mod tests {
536 use super::*;
537
538 use crate::tests::{check_assist, check_assist_not_applicable};
539
540 #[test]
541 fn parameter_with_first_param_usage() {
542 check_assist(
543 convert_bool_to_enum,
544 r#"
545fn function($0foo: bool, bar: bool) {
546 if foo {
547 println!("foo");
548 }
549}
550"#,
551 r#"
552#[derive(PartialEq, Eq)]
553enum Bool { True, False }
554
555fn function(foo: Bool, bar: bool) {
556 if foo == Bool::True {
557 println!("foo");
558 }
559}
560"#,
561 )
562 }
563
564 #[test]
565 fn no_duplicate_enums() {
566 check_assist(
567 convert_bool_to_enum,
568 r#"
569#[derive(PartialEq, Eq)]
570enum Bool { True, False }
571
572fn function(foo: bool, $0bar: bool) {
573 if bar {
574 println!("bar");
575 }
576}
577"#,
578 r#"
579#[derive(PartialEq, Eq)]
580enum Bool { True, False }
581
582fn function(foo: bool, bar: Bool) {
583 if bar == Bool::True {
584 println!("bar");
585 }
586}
587"#,
588 )
589 }
590
591 #[test]
592 fn parameter_with_last_param_usage() {
593 check_assist(
594 convert_bool_to_enum,
595 r#"
596fn function(foo: bool, $0bar: bool) {
597 if bar {
598 println!("bar");
599 }
600}
601"#,
602 r#"
603#[derive(PartialEq, Eq)]
604enum Bool { True, False }
605
606fn function(foo: bool, bar: Bool) {
607 if bar == Bool::True {
608 println!("bar");
609 }
610}
611"#,
612 )
613 }
614
615 #[test]
616 fn parameter_with_middle_param_usage() {
617 check_assist(
618 convert_bool_to_enum,
619 r#"
620fn function(foo: bool, $0bar: bool, baz: bool) {
621 if bar {
622 println!("bar");
623 }
624}
625"#,
626 r#"
627#[derive(PartialEq, Eq)]
628enum Bool { True, False }
629
630fn function(foo: bool, bar: Bool, baz: bool) {
631 if bar == Bool::True {
632 println!("bar");
633 }
634}
635"#,
636 )
637 }
638
639 #[test]
640 fn parameter_with_closure_usage() {
641 check_assist(
642 convert_bool_to_enum,
643 r#"
644fn main() {
645 let foo = |$0bar: bool| bar;
646}
647"#,
648 r#"
649#[derive(PartialEq, Eq)]
650enum Bool { True, False }
651
652fn main() {
653 let foo = |bar: Bool| bar == Bool::True;
654}
655"#,
656 )
657 }
658
659 #[test]
660 fn local_variable_with_usage() {
661 check_assist(
662 convert_bool_to_enum,
663 r#"
664fn main() {
665 let $0foo = true;
666
667 if foo {
668 println!("foo");
669 }
670}
671"#,
672 r#"
673#[derive(PartialEq, Eq)]
674enum Bool { True, False }
675
676fn main() {
677 let foo = Bool::True;
678
679 if foo == Bool::True {
680 println!("foo");
681 }
682}
683"#,
684 )
685 }
686
687 #[test]
688 fn local_variable_with_usage_negated() {
689 cov_mark::check!(replaces_negation);
690 check_assist(
691 convert_bool_to_enum,
692 r#"
693fn main() {
694 let $0foo = true;
695
696 if !foo {
697 println!("foo");
698 }
699}
700"#,
701 r#"
702#[derive(PartialEq, Eq)]
703enum Bool { True, False }
704
705fn main() {
706 let foo = Bool::True;
707
708 if foo == Bool::False {
709 println!("foo");
710 }
711}
712"#,
713 )
714 }
715
716 #[test]
717 fn local_variable_with_type_annotation() {
718 cov_mark::check!(replaces_ty_annotation);
719 check_assist(
720 convert_bool_to_enum,
721 r#"
722fn main() {
723 let $0foo: bool = false;
724}
725"#,
726 r#"
727#[derive(PartialEq, Eq)]
728enum Bool { True, False }
729
730fn main() {
731 let foo: Bool = Bool::False;
732}
733"#,
734 )
735 }
736
737 #[test]
738 fn local_variable_with_non_literal_initializer() {
739 check_assist(
740 convert_bool_to_enum,
741 r#"
742fn main() {
743 let $0foo = 1 == 2;
744}
745"#,
746 r#"
747#[derive(PartialEq, Eq)]
748enum Bool { True, False }
749
750fn main() {
751 let foo = if 1 == 2 { Bool::True } else { Bool::False };
752}
753"#,
754 )
755 }
756
757 #[test]
758 fn local_variable_binexpr_usage() {
759 check_assist(
760 convert_bool_to_enum,
761 r#"
762fn main() {
763 let $0foo = false;
764 let bar = true;
765
766 if !foo && bar {
767 println!("foobar");
768 }
769}
770"#,
771 r#"
772#[derive(PartialEq, Eq)]
773enum Bool { True, False }
774
775fn main() {
776 let foo = Bool::False;
777 let bar = true;
778
779 if foo == Bool::False && bar {
780 println!("foobar");
781 }
782}
783"#,
784 )
785 }
786
787 #[test]
788 fn local_variable_unop_usage() {
789 check_assist(
790 convert_bool_to_enum,
791 r#"
792fn main() {
793 let $0foo = true;
794
795 if *&foo {
796 println!("foobar");
797 }
798}
799"#,
800 r#"
801#[derive(PartialEq, Eq)]
802enum Bool { True, False }
803
804fn main() {
805 let foo = Bool::True;
806
807 if *&foo == Bool::True {
808 println!("foobar");
809 }
810}
811"#,
812 )
813 }
814
815 #[test]
816 fn local_variable_assigned_later() {
817 cov_mark::check!(replaces_assignment);
818 check_assist(
819 convert_bool_to_enum,
820 r#"
821fn main() {
822 let $0foo: bool;
823 foo = true;
824}
825"#,
826 r#"
827#[derive(PartialEq, Eq)]
828enum Bool { True, False }
829
830fn main() {
831 let foo: Bool;
832 foo = Bool::True;
833}
834"#,
835 )
836 }
837
838 #[test]
839 fn local_variable_does_not_apply_recursively() {
840 check_assist(
841 convert_bool_to_enum,
842 r#"
843fn main() {
844 let $0foo = true;
845 let bar = !foo;
846
847 if bar {
848 println!("bar");
849 }
850}
851"#,
852 r#"
853#[derive(PartialEq, Eq)]
854enum Bool { True, False }
855
856fn main() {
857 let foo = Bool::True;
858 let bar = foo == Bool::False;
859
860 if bar {
861 println!("bar");
862 }
863}
864"#,
865 )
866 }
867
868 #[test]
869 fn local_variable_nested_in_negation() {
870 cov_mark::check!(dont_overwrite_expression_inside_negation);
871 check_assist(
872 convert_bool_to_enum,
873 r#"
874fn main() {
875 if !"foo".chars().any(|c| {
876 let $0foo = true;
877 foo
878 }) {
879 println!("foo");
880 }
881}
882"#,
883 r#"
884#[derive(PartialEq, Eq)]
885enum Bool { True, False }
886
887fn main() {
888 if !"foo".chars().any(|c| {
889 let foo = Bool::True;
890 foo == Bool::True
891 }) {
892 println!("foo");
893 }
894}
895"#,
896 )
897 }
898
899 #[test]
900 fn local_variable_non_bool() {
901 cov_mark::check!(not_applicable_non_bool_local);
902 check_assist_not_applicable(
903 convert_bool_to_enum,
904 r#"
905fn main() {
906 let $0foo = 1;
907}
908"#,
909 )
910 }
911
912 #[test]
913 fn local_variable_cursor_not_on_ident() {
914 check_assist_not_applicable(
915 convert_bool_to_enum,
916 r#"
917fn main() {
918 let foo = $0true;
919}
920"#,
921 )
922 }
923
924 #[test]
925 fn local_variable_non_ident_pat() {
926 check_assist_not_applicable(
927 convert_bool_to_enum,
928 r#"
929fn main() {
930 let ($0foo, bar) = (true, false);
931}
932"#,
933 )
934 }
935
936 #[test]
937 fn local_var_init_struct_usage() {
938 check_assist(
939 convert_bool_to_enum,
940 r#"
941struct Foo {
942 foo: bool,
943}
944
945fn main() {
946 let $0foo = true;
947 let s = Foo { foo };
948}
949"#,
950 r#"
951struct Foo {
952 foo: bool,
953}
954
955#[derive(PartialEq, Eq)]
956enum Bool { True, False }
957
958fn main() {
959 let foo = Bool::True;
960 let s = Foo { foo: foo == Bool::True };
961}
962"#,
963 )
964 }
965
966 #[test]
967 fn local_var_init_struct_usage_in_macro() {
968 check_assist(
969 convert_bool_to_enum,
970 r#"
971struct Struct {
972 boolean: bool,
973}
974
975macro_rules! identity {
976 ($body:expr) => {
977 $body
978 }
979}
980
981fn new() -> Struct {
982 let $0boolean = true;
983 identity![Struct { boolean }]
984}
985"#,
986 r#"
987struct Struct {
988 boolean: bool,
989}
990
991macro_rules! identity {
992 ($body:expr) => {
993 $body
994 }
995}
996
997#[derive(PartialEq, Eq)]
998enum Bool { True, False }
999
1000fn new() -> Struct {
1001 let boolean = Bool::True;
1002 identity![Struct { boolean: boolean == Bool::True }]
1003}
1004"#,
1005 )
1006 }
1007
1008 #[test]
1009 fn field_struct_basic() {
1010 cov_mark::check!(replaces_record_expr);
1011 check_assist(
1012 convert_bool_to_enum,
1013 r#"
1014struct Foo {
1015 $0bar: bool,
1016 baz: bool,
1017}
1018
1019fn main() {
1020 let foo = Foo { bar: true, baz: false };
1021
1022 if foo.bar {
1023 println!("foo");
1024 }
1025}
1026"#,
1027 r#"
1028#[derive(PartialEq, Eq)]
1029enum Bool { True, False }
1030
1031struct Foo {
1032 bar: Bool,
1033 baz: bool,
1034}
1035
1036fn main() {
1037 let foo = Foo { bar: Bool::True, baz: false };
1038
1039 if foo.bar == Bool::True {
1040 println!("foo");
1041 }
1042}
1043"#,
1044 )
1045 }
1046
1047 #[test]
1048 fn field_enum_basic() {
1049 cov_mark::check!(replaces_record_pat);
1050 check_assist(
1051 convert_bool_to_enum,
1052 r#"
1053enum Foo {
1054 Foo,
1055 Bar { $0bar: bool },
1056}
1057
1058fn main() {
1059 let foo = Foo::Bar { bar: true };
1060
1061 if let Foo::Bar { bar: baz } = foo {
1062 if baz {
1063 println!("foo");
1064 }
1065 }
1066}
1067"#,
1068 r#"
1069#[derive(PartialEq, Eq)]
1070enum Bool { True, False }
1071
1072enum Foo {
1073 Foo,
1074 Bar { bar: Bool },
1075}
1076
1077fn main() {
1078 let foo = Foo::Bar { bar: Bool::True };
1079
1080 if let Foo::Bar { bar: baz } = foo {
1081 if baz == Bool::True {
1082 println!("foo");
1083 }
1084 }
1085}
1086"#,
1087 )
1088 }
1089
1090 #[test]
1091 fn field_enum_cross_file() {
1092 check_assist(
1094 convert_bool_to_enum,
1095 r#"
1096//- /foo.rs
1097pub enum Foo {
1098 Foo,
1099 Bar { $0bar: bool },
1100}
1101
1102fn foo() {
1103 let foo = Foo::Bar { bar: true };
1104}
1105
1106//- /main.rs
1107use foo::Foo;
1108
1109mod foo;
1110
1111fn main() {
1112 let foo = Foo::Bar { bar: false };
1113}
1114"#,
1115 r#"
1116//- /foo.rs
1117#[derive(PartialEq, Eq)]
1118pub enum Bool { True, False }
1119
1120pub enum Foo {
1121 Foo,
1122 Bar { bar: Bool },
1123}
1124
1125fn foo() {
1126 let foo = Foo::Bar { bar: Bool::True };
1127}
1128
1129//- /main.rs
1130use foo::{Bool, Foo};
1131
1132mod foo;
1133
1134fn main() {
1135 let foo = Foo::Bar { bar: Bool::False };
1136}
1137"#,
1138 )
1139 }
1140
1141 #[test]
1142 fn field_enum_shorthand() {
1143 cov_mark::check!(replaces_record_pat_shorthand);
1144 check_assist(
1145 convert_bool_to_enum,
1146 r#"
1147enum Foo {
1148 Foo,
1149 Bar { $0bar: bool },
1150}
1151
1152fn main() {
1153 let foo = Foo::Bar { bar: true };
1154
1155 match foo {
1156 Foo::Bar { bar } => {
1157 if bar {
1158 println!("foo");
1159 }
1160 }
1161 _ => (),
1162 }
1163}
1164"#,
1165 r#"
1166#[derive(PartialEq, Eq)]
1167enum Bool { True, False }
1168
1169enum Foo {
1170 Foo,
1171 Bar { bar: Bool },
1172}
1173
1174fn main() {
1175 let foo = Foo::Bar { bar: Bool::True };
1176
1177 match foo {
1178 Foo::Bar { bar } => {
1179 if bar == Bool::True {
1180 println!("foo");
1181 }
1182 }
1183 _ => (),
1184 }
1185}
1186"#,
1187 )
1188 }
1189
1190 #[test]
1191 fn field_enum_replaces_literal_patterns() {
1192 cov_mark::check!(replaces_literal_pat);
1193 check_assist(
1194 convert_bool_to_enum,
1195 r#"
1196enum Foo {
1197 Foo,
1198 Bar { $0bar: bool },
1199}
1200
1201fn main() {
1202 let foo = Foo::Bar { bar: true };
1203
1204 if let Foo::Bar { bar: true } = foo {
1205 println!("foo");
1206 }
1207}
1208"#,
1209 r#"
1210#[derive(PartialEq, Eq)]
1211enum Bool { True, False }
1212
1213enum Foo {
1214 Foo,
1215 Bar { bar: Bool },
1216}
1217
1218fn main() {
1219 let foo = Foo::Bar { bar: Bool::True };
1220
1221 if let Foo::Bar { bar: Bool::True } = foo {
1222 println!("foo");
1223 }
1224}
1225"#,
1226 )
1227 }
1228
1229 #[test]
1230 fn field_enum_keeps_wildcard_patterns() {
1231 check_assist(
1232 convert_bool_to_enum,
1233 r#"
1234enum Foo {
1235 Foo,
1236 Bar { $0bar: bool },
1237}
1238
1239fn main() {
1240 let foo = Foo::Bar { bar: true };
1241
1242 if let Foo::Bar { bar: _ } = foo {
1243 println!("foo");
1244 }
1245}
1246"#,
1247 r#"
1248#[derive(PartialEq, Eq)]
1249enum Bool { True, False }
1250
1251enum Foo {
1252 Foo,
1253 Bar { bar: Bool },
1254}
1255
1256fn main() {
1257 let foo = Foo::Bar { bar: Bool::True };
1258
1259 if let Foo::Bar { bar: _ } = foo {
1260 println!("foo");
1261 }
1262}
1263"#,
1264 )
1265 }
1266
1267 #[test]
1268 fn field_union_basic() {
1269 check_assist(
1270 convert_bool_to_enum,
1271 r#"
1272union Foo {
1273 $0foo: bool,
1274 bar: usize,
1275}
1276
1277fn main() {
1278 let foo = Foo { foo: true };
1279
1280 if unsafe { foo.foo } {
1281 println!("foo");
1282 }
1283}
1284"#,
1285 r#"
1286#[derive(PartialEq, Eq)]
1287enum Bool { True, False }
1288
1289union Foo {
1290 foo: Bool,
1291 bar: usize,
1292}
1293
1294fn main() {
1295 let foo = Foo { foo: Bool::True };
1296
1297 if unsafe { foo.foo == Bool::True } {
1298 println!("foo");
1299 }
1300}
1301"#,
1302 )
1303 }
1304
1305 #[test]
1306 fn field_negated() {
1307 check_assist(
1308 convert_bool_to_enum,
1309 r#"
1310struct Foo {
1311 $0bar: bool,
1312}
1313
1314fn main() {
1315 let foo = Foo { bar: false };
1316
1317 if !foo.bar {
1318 println!("foo");
1319 }
1320}
1321"#,
1322 r#"
1323#[derive(PartialEq, Eq)]
1324enum Bool { True, False }
1325
1326struct Foo {
1327 bar: Bool,
1328}
1329
1330fn main() {
1331 let foo = Foo { bar: Bool::False };
1332
1333 if foo.bar == Bool::False {
1334 println!("foo");
1335 }
1336}
1337"#,
1338 )
1339 }
1340
1341 #[test]
1342 fn field_in_mod_properly_indented() {
1343 check_assist(
1344 convert_bool_to_enum,
1345 r#"
1346mod foo {
1347 struct Bar {
1348 $0baz: bool,
1349 }
1350
1351 impl Bar {
1352 fn new(baz: bool) -> Self {
1353 Self { baz }
1354 }
1355 }
1356}
1357"#,
1358 r#"
1359mod foo {
1360 #[derive(PartialEq, Eq)]
1361 enum Bool { True, False }
1362
1363 struct Bar {
1364 baz: Bool,
1365 }
1366
1367 impl Bar {
1368 fn new(baz: bool) -> Self {
1369 Self { baz: if baz { Bool::True } else { Bool::False } }
1370 }
1371 }
1372}
1373"#,
1374 )
1375 }
1376
1377 #[test]
1378 fn field_multiple_initializations() {
1379 check_assist(
1380 convert_bool_to_enum,
1381 r#"
1382struct Foo {
1383 $0bar: bool,
1384 baz: bool,
1385}
1386
1387fn main() {
1388 let foo1 = Foo { bar: true, baz: false };
1389 let foo2 = Foo { bar: false, baz: false };
1390
1391 if foo1.bar && foo2.bar {
1392 println!("foo");
1393 }
1394}
1395"#,
1396 r#"
1397#[derive(PartialEq, Eq)]
1398enum Bool { True, False }
1399
1400struct Foo {
1401 bar: Bool,
1402 baz: bool,
1403}
1404
1405fn main() {
1406 let foo1 = Foo { bar: Bool::True, baz: false };
1407 let foo2 = Foo { bar: Bool::False, baz: false };
1408
1409 if foo1.bar == Bool::True && foo2.bar == Bool::True {
1410 println!("foo");
1411 }
1412}
1413"#,
1414 )
1415 }
1416
1417 #[test]
1418 fn field_assigned_to_another() {
1419 cov_mark::check!(dont_assign_incorrect_ref);
1420 check_assist(
1421 convert_bool_to_enum,
1422 r#"
1423struct Foo {
1424 $0foo: bool,
1425}
1426
1427struct Bar {
1428 bar: bool,
1429}
1430
1431fn main() {
1432 let foo = Foo { foo: true };
1433 let mut bar = Bar { bar: true };
1434
1435 bar.bar = foo.foo;
1436}
1437"#,
1438 r#"
1439#[derive(PartialEq, Eq)]
1440enum Bool { True, False }
1441
1442struct Foo {
1443 foo: Bool,
1444}
1445
1446struct Bar {
1447 bar: bool,
1448}
1449
1450fn main() {
1451 let foo = Foo { foo: Bool::True };
1452 let mut bar = Bar { bar: true };
1453
1454 bar.bar = foo.foo == Bool::True;
1455}
1456"#,
1457 )
1458 }
1459
1460 #[test]
1461 fn field_initialized_with_other() {
1462 check_assist(
1463 convert_bool_to_enum,
1464 r#"
1465struct Foo {
1466 $0foo: bool,
1467}
1468
1469struct Bar {
1470 bar: bool,
1471}
1472
1473fn main() {
1474 let foo = Foo { foo: true };
1475 let bar = Bar { bar: foo.foo };
1476}
1477"#,
1478 r#"
1479#[derive(PartialEq, Eq)]
1480enum Bool { True, False }
1481
1482struct Foo {
1483 foo: Bool,
1484}
1485
1486struct Bar {
1487 bar: bool,
1488}
1489
1490fn main() {
1491 let foo = Foo { foo: Bool::True };
1492 let bar = Bar { bar: foo.foo == Bool::True };
1493}
1494"#,
1495 )
1496 }
1497
1498 #[test]
1499 fn field_method_chain_usage() {
1500 check_assist(
1501 convert_bool_to_enum,
1502 r#"
1503struct Foo {
1504 $0bool: bool,
1505}
1506
1507fn main() {
1508 let foo = Foo { bool: true };
1509
1510 foo.bool.then(|| 2);
1511}
1512"#,
1513 r#"
1514#[derive(PartialEq, Eq)]
1515enum Bool { True, False }
1516
1517struct Foo {
1518 bool: Bool,
1519}
1520
1521fn main() {
1522 let foo = Foo { bool: Bool::True };
1523
1524 (foo.bool == Bool::True).then(|| 2);
1525}
1526"#,
1527 )
1528 }
1529
1530 #[test]
1531 fn field_in_macro() {
1532 check_assist(
1533 convert_bool_to_enum,
1534 r#"
1535struct Struct {
1536 $0boolean: bool,
1537}
1538
1539fn boolean(x: Struct) {
1540 let Struct { boolean } = x;
1541}
1542
1543macro_rules! identity { ($body:expr) => { $body } }
1544
1545fn new() -> Struct {
1546 identity!(Struct { boolean: true })
1547}
1548"#,
1549 r#"
1550#[derive(PartialEq, Eq)]
1551enum Bool { True, False }
1552
1553struct Struct {
1554 boolean: Bool,
1555}
1556
1557fn boolean(x: Struct) {
1558 let Struct { boolean } = x;
1559}
1560
1561macro_rules! identity { ($body:expr) => { $body } }
1562
1563fn new() -> Struct {
1564 identity!(Struct { boolean: Bool::True })
1565}
1566"#,
1567 )
1568 }
1569
1570 #[test]
1571 fn field_non_bool() {
1572 cov_mark::check!(not_applicable_non_bool_field);
1573 check_assist_not_applicable(
1574 convert_bool_to_enum,
1575 r#"
1576struct Foo {
1577 $0bar: usize,
1578}
1579
1580fn main() {
1581 let foo = Foo { bar: 1 };
1582}
1583"#,
1584 )
1585 }
1586
1587 #[test]
1588 fn const_basic() {
1589 check_assist(
1590 convert_bool_to_enum,
1591 r#"
1592const $0FOO: bool = false;
1593
1594fn main() {
1595 if FOO {
1596 println!("foo");
1597 }
1598}
1599"#,
1600 r#"
1601#[derive(PartialEq, Eq)]
1602enum Bool { True, False }
1603
1604const FOO: Bool = Bool::False;
1605
1606fn main() {
1607 if FOO == Bool::True {
1608 println!("foo");
1609 }
1610}
1611"#,
1612 )
1613 }
1614
1615 #[test]
1616 fn const_in_module() {
1617 check_assist(
1618 convert_bool_to_enum,
1619 r#"
1620fn main() {
1621 if foo::FOO {
1622 println!("foo");
1623 }
1624}
1625
1626mod foo {
1627 pub const $0FOO: bool = true;
1628}
1629"#,
1630 r#"
1631use foo::Bool;
1632
1633fn main() {
1634 if foo::FOO == Bool::True {
1635 println!("foo");
1636 }
1637}
1638
1639mod foo {
1640 #[derive(PartialEq, Eq)]
1641 pub enum Bool { True, False }
1642
1643 pub const FOO: Bool = Bool::True;
1644}
1645"#,
1646 )
1647 }
1648
1649 #[test]
1650 fn const_in_module_with_import() {
1651 check_assist(
1652 convert_bool_to_enum,
1653 r#"
1654fn main() {
1655 use foo::FOO;
1656
1657 if FOO {
1658 println!("foo");
1659 }
1660}
1661
1662mod foo {
1663 pub const $0FOO: bool = true;
1664}
1665"#,
1666 r#"
1667use foo::Bool;
1668
1669fn main() {
1670 use foo::FOO;
1671
1672 if FOO == Bool::True {
1673 println!("foo");
1674 }
1675}
1676
1677mod foo {
1678 #[derive(PartialEq, Eq)]
1679 pub enum Bool { True, False }
1680
1681 pub const FOO: Bool = Bool::True;
1682}
1683"#,
1684 )
1685 }
1686
1687 #[test]
1688 fn const_cross_file() {
1689 check_assist(
1690 convert_bool_to_enum,
1691 r#"
1692//- /main.rs
1693mod foo;
1694
1695fn main() {
1696 if foo::FOO {
1697 println!("foo");
1698 }
1699}
1700
1701//- /foo.rs
1702pub const $0FOO: bool = true;
1703"#,
1704 r#"
1705//- /main.rs
1706use foo::Bool;
1707
1708mod foo;
1709
1710fn main() {
1711 if foo::FOO == Bool::True {
1712 println!("foo");
1713 }
1714}
1715
1716//- /foo.rs
1717#[derive(PartialEq, Eq)]
1718pub enum Bool { True, False }
1719
1720pub const FOO: Bool = Bool::True;
1721"#,
1722 )
1723 }
1724
1725 #[test]
1726 fn const_cross_file_and_module() {
1727 check_assist(
1728 convert_bool_to_enum,
1729 r#"
1730//- /main.rs
1731mod foo;
1732
1733fn main() {
1734 use foo::bar;
1735
1736 if bar::BAR {
1737 println!("foo");
1738 }
1739}
1740
1741//- /foo.rs
1742pub mod bar {
1743 pub const $0BAR: bool = false;
1744}
1745"#,
1746 r#"
1747//- /main.rs
1748use foo::bar::Bool;
1749
1750mod foo;
1751
1752fn main() {
1753 use foo::bar;
1754
1755 if bar::BAR == Bool::True {
1756 println!("foo");
1757 }
1758}
1759
1760//- /foo.rs
1761pub mod bar {
1762 #[derive(PartialEq, Eq)]
1763 pub enum Bool { True, False }
1764
1765 pub const BAR: Bool = Bool::False;
1766}
1767"#,
1768 )
1769 }
1770
1771 #[test]
1772 fn const_in_impl_cross_file() {
1773 check_assist(
1774 convert_bool_to_enum,
1775 r#"
1776//- /main.rs
1777mod foo;
1778
1779struct Foo;
1780
1781impl Foo {
1782 pub const $0BOOL: bool = true;
1783}
1784
1785//- /foo.rs
1786use crate::Foo;
1787
1788fn foo() -> bool {
1789 Foo::BOOL
1790}
1791"#,
1792 r#"
1793//- /main.rs
1794mod foo;
1795
1796struct Foo;
1797
1798#[derive(PartialEq, Eq)]
1799pub enum Bool { True, False }
1800
1801impl Foo {
1802 pub const BOOL: Bool = Bool::True;
1803}
1804
1805//- /foo.rs
1806use crate::{Bool, Foo};
1807
1808fn foo() -> bool {
1809 Foo::BOOL == Bool::True
1810}
1811"#,
1812 )
1813 }
1814
1815 #[test]
1816 fn const_in_trait() {
1817 check_assist(
1818 convert_bool_to_enum,
1819 r#"
1820trait Foo {
1821 const $0BOOL: bool;
1822}
1823
1824impl Foo for usize {
1825 const BOOL: bool = true;
1826}
1827
1828fn main() {
1829 if <usize as Foo>::BOOL {
1830 println!("foo");
1831 }
1832}
1833"#,
1834 r#"
1835#[derive(PartialEq, Eq)]
1836enum Bool { True, False }
1837
1838trait Foo {
1839 const BOOL: Bool;
1840}
1841
1842impl Foo for usize {
1843 const BOOL: Bool = Bool::True;
1844}
1845
1846fn main() {
1847 if <usize as Foo>::BOOL == Bool::True {
1848 println!("foo");
1849 }
1850}
1851"#,
1852 )
1853 }
1854
1855 #[test]
1856 fn const_non_bool() {
1857 cov_mark::check!(not_applicable_non_bool_const);
1858 check_assist_not_applicable(
1859 convert_bool_to_enum,
1860 r#"
1861const $0FOO: &str = "foo";
1862
1863fn main() {
1864 println!("{FOO}");
1865}
1866"#,
1867 )
1868 }
1869
1870 #[test]
1871 fn static_basic() {
1872 check_assist(
1873 convert_bool_to_enum,
1874 r#"
1875static mut $0BOOL: bool = true;
1876
1877fn main() {
1878 unsafe { BOOL = false };
1879 if unsafe { BOOL } {
1880 println!("foo");
1881 }
1882}
1883"#,
1884 r#"
1885#[derive(PartialEq, Eq)]
1886enum Bool { True, False }
1887
1888static mut BOOL: Bool = Bool::True;
1889
1890fn main() {
1891 unsafe { BOOL = Bool::False };
1892 if unsafe { BOOL == Bool::True } {
1893 println!("foo");
1894 }
1895}
1896"#,
1897 )
1898 }
1899
1900 #[test]
1901 fn static_non_bool() {
1902 cov_mark::check!(not_applicable_non_bool_static);
1903 check_assist_not_applicable(
1904 convert_bool_to_enum,
1905 r#"
1906static mut $0FOO: usize = 0;
1907
1908fn main() {
1909 if unsafe { FOO } == 0 {
1910 println!("foo");
1911 }
1912}
1913"#,
1914 )
1915 }
1916
1917 #[test]
1918 fn not_applicable_to_other_names() {
1919 check_assist_not_applicable(convert_bool_to_enum, "fn $0main() {}")
1920 }
1921}