1use std::collections::VecDeque;
2
3use ide_db::{
4 assists::GroupLabel,
5 famous_defs::FamousDefs,
6 syntax_helpers::node_ext::{for_each_tail_expr, walk_expr},
7};
8use syntax::{
9 NodeOrToken, SyntaxKind, T,
10 ast::{
11 self, AstNode,
12 Expr::BinExpr,
13 HasArgList,
14 prec::{ExprPrecedence, precedence},
15 syntax_factory::SyntaxFactory,
16 },
17 syntax_editor::{Position, SyntaxEditor},
18};
19
20use crate::{AssistContext, AssistId, Assists, utils::invert_boolean_expression};
21
22pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
41 let mut bin_expr = if let Some(not) = ctx.find_token_syntax_at_offset(T![!])
42 && let Some(NodeOrToken::Node(next)) = not.next_sibling_or_token()
43 && let Some(paren) = ast::ParenExpr::cast(next)
44 && let Some(ast::Expr::BinExpr(bin_expr)) = paren.expr()
45 {
46 bin_expr
47 } else {
48 let bin_expr = ctx.find_node_at_offset::<ast::BinExpr>()?;
49 let op_range = bin_expr.op_token()?.text_range();
50
51 if !op_range.contains_range(ctx.selection_trimmed()) {
53 return None;
54 }
55
56 bin_expr
57 };
58
59 let op = bin_expr.op_kind()?;
60 let op_range = bin_expr.op_token()?.text_range();
61
62 while let Some(parent_expr) = bin_expr.syntax().parent().and_then(ast::BinExpr::cast) {
64 match parent_expr.op_kind() {
65 Some(parent_op) if parent_op == op => {
66 bin_expr = parent_expr;
67 }
68 _ => break,
69 }
70 }
71
72 let op = bin_expr.op_kind()?;
73 let (inv_token, prec) = match op {
74 ast::BinaryOp::LogicOp(ast::LogicOp::And) => (SyntaxKind::PIPE2, ExprPrecedence::LOr),
75 ast::BinaryOp::LogicOp(ast::LogicOp::Or) => (SyntaxKind::AMP2, ExprPrecedence::LAnd),
76 _ => return None,
77 };
78
79 let make = SyntaxFactory::with_mappings();
80
81 let demorganed = bin_expr.clone_subtree();
82 let mut editor = SyntaxEditor::new(demorganed.syntax().clone());
83 editor.replace(demorganed.op_token()?, make.token(inv_token));
84
85 let mut exprs = VecDeque::from([
86 (bin_expr.lhs()?, demorganed.lhs()?, prec),
87 (bin_expr.rhs()?, demorganed.rhs()?, prec),
88 ]);
89
90 while let Some((expr, demorganed, prec)) = exprs.pop_front() {
91 if let BinExpr(bin_expr) = &expr {
92 if let BinExpr(cbin_expr) = &demorganed {
93 if op == bin_expr.op_kind()? {
94 editor.replace(cbin_expr.op_token()?, make.token(inv_token));
95 exprs.push_back((bin_expr.lhs()?, cbin_expr.lhs()?, prec));
96 exprs.push_back((bin_expr.rhs()?, cbin_expr.rhs()?, prec));
97 } else {
98 let mut inv = invert_boolean_expression(&make, expr);
99 if precedence(&inv).needs_parentheses_in(prec) {
100 inv = make.expr_paren(inv).into();
101 }
102 editor.replace(demorganed.syntax(), inv.syntax());
103 }
104 } else {
105 return None;
106 }
107 } else {
108 let mut inv = invert_boolean_expression(&make, demorganed.clone());
109 if precedence(&inv).needs_parentheses_in(prec) {
110 inv = make.expr_paren(inv).into();
111 }
112 editor.replace(demorganed.syntax(), inv.syntax());
113 }
114 }
115
116 editor.add_mappings(make.finish_with_mappings());
117 let edit = editor.finish();
118 let demorganed = ast::Expr::cast(edit.new_root().clone())?;
119
120 acc.add_group(
121 &GroupLabel("Apply De Morgan's law".to_owned()),
122 AssistId::refactor_rewrite("apply_demorgan"),
123 "Apply De Morgan's law",
124 op_range,
125 |builder| {
126 let make = SyntaxFactory::with_mappings();
127 let (target_node, result_expr) = if let Some(neg_expr) = bin_expr
128 .syntax()
129 .parent()
130 .and_then(ast::ParenExpr::cast)
131 .and_then(|paren_expr| paren_expr.syntax().parent())
132 .and_then(ast::PrefixExpr::cast)
133 .filter(|prefix_expr| matches!(prefix_expr.op_kind(), Some(ast::UnaryOp::Not)))
134 {
135 cov_mark::hit!(demorgan_double_negation);
136 (ast::Expr::from(neg_expr).syntax().clone(), demorganed)
137 } else if let Some(paren_expr) =
138 bin_expr.syntax().parent().and_then(ast::ParenExpr::cast)
139 {
140 cov_mark::hit!(demorgan_double_parens);
141 (paren_expr.syntax().clone(), add_bang_paren(&make, demorganed))
142 } else {
143 (bin_expr.syntax().clone(), add_bang_paren(&make, demorganed))
144 };
145
146 let final_expr = if target_node
147 .parent()
148 .is_some_and(|p| result_expr.needs_parens_in_place_of(&p, &target_node))
149 {
150 cov_mark::hit!(demorgan_keep_parens_for_op_precedence2);
151 make.expr_paren(result_expr).into()
152 } else {
153 result_expr
154 };
155
156 let mut editor = builder.make_editor(&target_node);
157 editor.replace(&target_node, final_expr.syntax());
158 editor.add_mappings(make.finish_with_mappings());
159 builder.add_file_edits(ctx.vfs_file_id(), editor);
160 },
161 )
162}
163
164pub(crate) fn apply_demorgan_iterator(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
192 let method_call: ast::MethodCallExpr = ctx.find_node_at_offset()?;
193 let (name, arg_expr) = validate_method_call_expr(ctx, &method_call)?;
194
195 let ast::Expr::ClosureExpr(closure_expr) = arg_expr else { return None };
196 let closure_body = closure_expr.body()?.clone_for_update();
197
198 let op_range = method_call.syntax().text_range();
199 let label = format!("Apply De Morgan's law to `Iterator::{}`", name.text().as_str());
200 acc.add_group(
201 &GroupLabel("Apply De Morgan's law".to_owned()),
202 AssistId::refactor_rewrite("apply_demorgan_iterator"),
203 label,
204 op_range,
205 |builder| {
206 let make = SyntaxFactory::with_mappings();
207 let mut editor = builder.make_editor(method_call.syntax());
208 let new_name = match name.text().as_str() {
210 "all" => make.name_ref("any"),
211 "any" => make.name_ref("all"),
212 _ => unreachable!(),
213 };
214 editor.replace(name.syntax(), new_name.syntax());
215
216 let tail_cb = &mut |e: &_| tail_cb_impl(&mut editor, &make, e);
218 walk_expr(&closure_body, &mut |expr| {
219 if let ast::Expr::ReturnExpr(ret_expr) = expr
220 && let Some(ret_expr_arg) = &ret_expr.expr()
221 {
222 for_each_tail_expr(ret_expr_arg, tail_cb);
223 }
224 });
225 for_each_tail_expr(&closure_body, tail_cb);
226
227 if let Some(prefix_expr) = method_call
229 .syntax()
230 .parent()
231 .and_then(ast::PrefixExpr::cast)
232 .filter(|prefix_expr| matches!(prefix_expr.op_kind(), Some(ast::UnaryOp::Not)))
233 {
234 editor.delete(
235 prefix_expr.op_token().expect("prefix expression always has an operator"),
236 );
237 } else {
238 editor.insert(Position::before(method_call.syntax()), make.token(SyntaxKind::BANG));
239 }
240
241 editor.add_mappings(make.finish_with_mappings());
242 builder.add_file_edits(ctx.vfs_file_id(), editor);
243 },
244 )
245}
246
247fn validate_method_call_expr(
249 ctx: &AssistContext<'_>,
250 method_call: &ast::MethodCallExpr,
251) -> Option<(ast::NameRef, ast::Expr)> {
252 let name_ref = method_call.name_ref()?;
253 if name_ref.text() != "all" && name_ref.text() != "any" {
254 return None;
255 }
256 let arg_expr = method_call.arg_list()?.args().next()?;
257
258 let sema = &ctx.sema;
259
260 let receiver = method_call.receiver()?;
261 let it_type = sema.type_of_expr(&receiver)?.adjusted();
262 let module = sema.scope(receiver.syntax())?.module();
263 let krate = module.krate(ctx.db());
264
265 let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?;
266 it_type.impls_trait(sema.db, iter_trait, &[]).then_some((name_ref, arg_expr))
267}
268
269fn tail_cb_impl(editor: &mut SyntaxEditor, make: &SyntaxFactory, e: &ast::Expr) {
270 match e {
271 ast::Expr::BreakExpr(break_expr) => {
272 if let Some(break_expr_arg) = break_expr.expr() {
273 for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(editor, make, e))
274 }
275 }
276 ast::Expr::ReturnExpr(_) => {
277 }
279 e => {
280 let inverted_body = invert_boolean_expression(make, e.clone());
281 editor.replace(e.syntax(), inverted_body.syntax());
282 }
283 }
284}
285
286fn add_bang_paren(make: &SyntaxFactory, expr: ast::Expr) -> ast::Expr {
288 make.expr_prefix(T![!], make.expr_paren(expr).into()).into()
289}
290
// Tests for `apply_demorgan` and `apply_demorgan_iterator`. In each fixture `$0` marks the
// cursor position; the first string is the input and the second (if any) the expected output.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{check_assist, check_assist_not_applicable};

    // ---- `apply_demorgan` ----

    // Comparison operands are inverted by flipping the comparison operator.
    #[test]
    fn demorgan_handles_leq() {
        check_assist(
            apply_demorgan,
            r#"
struct S;
fn f() { S < S &&$0 S <= S }
"#,
            r#"
struct S;
fn f() { !(S >= S || S > S) }
"#,
        );
    }

    #[test]
    fn demorgan_handles_geq() {
        check_assist(
            apply_demorgan,
            r#"
struct S;
fn f() { S > S &&$0 S >= S }
"#,
            r#"
struct S;
fn f() { !(S <= S || S < S) }
"#,
        );
    }

    // Negated operands lose their `!`.
    #[test]
    fn demorgan_turns_and_into_or() {
        check_assist(apply_demorgan, "fn f() { !x &&$0 !x }", "fn f() { !(x || x) }")
    }

    #[test]
    fn demorgan_turns_or_into_and() {
        check_assist(apply_demorgan, "fn f() { !x ||$0 !x }", "fn f() { !(x && x) }")
    }

    #[test]
    fn demorgan_removes_inequality() {
        check_assist(apply_demorgan, "fn f() { x != x ||$0 !x }", "fn f() { !(x == x && x) }")
    }

    // Plain operands gain a `!`.
    #[test]
    fn demorgan_general_case() {
        check_assist(apply_demorgan, "fn f() { x ||$0 x }", "fn f() { !(!x && !x) }")
    }

    // The whole chain of identical operators is rewritten, whichever operator the cursor is on.
    #[test]
    fn demorgan_multiple_terms() {
        check_assist(apply_demorgan, "fn f() { x ||$0 y || z }", "fn f() { !(!x && !y && !z) }");
        check_assist(apply_demorgan, "fn f() { x || y ||$0 z }", "fn f() { !(!x && !y && !z) }");
    }

    // The cursor must be on the operator token itself.
    #[test]
    fn demorgan_doesnt_apply_with_cursor_not_on_op() {
        check_assist_not_applicable(apply_demorgan, "fn f() { $0 !x || !x }")
    }

    // An existing `!( … )` around the chain is removed instead of negating twice.
    #[test]
    fn demorgan_doesnt_double_negation() {
        cov_mark::check!(demorgan_double_negation);
        check_assist(apply_demorgan, "fn f() { !(x ||$0 x) }", "fn f() { !x && !x }")
    }

    // Existing parentheses around the chain are replaced by `!( … )`, not nested.
    #[test]
    fn demorgan_doesnt_double_parens() {
        cov_mark::check!(demorgan_double_parens);
        check_assist(apply_demorgan, "fn f() { (x ||$0 x) }", "fn f() { !(!x && !x) }")
    }

    // Mixed operators: only the chain containing the cursor's operator is rewritten.
    #[test]
    fn demorgan_doesnt_hang() {
        check_assist(
            apply_demorgan,
            "fn f() { 1 || 3 &&$0 4 || 5 }",
            "fn f() { 1 || !(!3 || !4) || 5 }",
        )
    }

    // With the cursor on `!`, the parenthesized chain that follows it is rewritten.
    #[test]
    fn demorgan_on_not() {
        check_assist(
            apply_demorgan,
            "fn f() { $0!(1 || 3 && 4 || 5) }",
            "fn f() { !1 && !(3 && 4) && !5 }",
        )
    }

    // Parentheses are kept or added where operator precedence requires them.
    #[test]
    fn demorgan_keep_pars_for_op_precedence() {
        check_assist(
            apply_demorgan,
            "fn main() {
 let _ = !(!a ||$0 !(b || c));
}
",
            "fn main() {
 let _ = a && (b || c);
}
",
        );
    }

    #[test]
    fn demorgan_keep_pars_for_op_precedence2() {
        cov_mark::check!(demorgan_keep_parens_for_op_precedence2);
        check_assist(
            apply_demorgan,
            "fn f() { (a && !(b &&$0 c); }",
            "fn f() { (a && (!b || !c); }",
        );
    }

    #[test]
    fn demorgan_keep_pars_for_op_precedence3() {
        check_assist(
            apply_demorgan,
            "fn f() { (a || !(b &&$0 c); }",
            "fn f() { (a || (!b || !c); }",
        );
    }

    #[test]
    fn demorgan_keeps_pars_in_eq_precedence() {
        check_assist(
            apply_demorgan,
            "fn() { let x = a && !(!b |$0| !c); }",
            "fn() { let x = a && (b && c); }",
        )
    }

    // ...and dropped where they are no longer needed.
    #[test]
    fn demorgan_removes_pars_for_op_precedence2() {
        check_assist(apply_demorgan, "fn f() { (a || !(b ||$0 c); }", "fn f() { (a || !b && !c; }");
    }

    // ---- `apply_demorgan_iterator` ----
    // NOTE(review): `//- minicore: iterator` presumably provides the `Iterator` trait that the
    // receiver-type check looks up; confirm against the test fixture docs.

    #[test]
    fn demorgan_iterator_any_all_reverse() {
        check_assist(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
 let arr = [1, 2, 3];
 if arr.into_iter().all(|num| num $0!= 4) {
 println!("foo");
 }
}
"#,
            r#"
fn main() {
 let arr = [1, 2, 3];
 if !arr.into_iter().any(|num| num == 4) {
 println!("foo");
 }
}
"#,
        );
    }

    // An existing `!` in front of the call is removed.
    #[test]
    fn demorgan_iterator_all_any() {
        check_assist(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
 let arr = [1, 2, 3];
 if !arr.into_iter().$0all(|num| num > 3) {
 println!("foo");
 }
}
"#,
            r#"
fn main() {
 let arr = [1, 2, 3];
 if arr.into_iter().any(|num| num <= 3) {
 println!("foo");
 }
}
"#,
        );
    }

    // A logical chain in the closure body is negated as a whole, not distributed.
    #[test]
    fn demorgan_iterator_multiple_terms() {
        check_assist(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
 let arr = [1, 2, 3];
 if !arr.into_iter().$0any(|num| num > 3 && num == 23 && num <= 30) {
 println!("foo");
 }
}
"#,
            r#"
fn main() {
 let arr = [1, 2, 3];
 if arr.into_iter().all(|num| !(num > 3 && num == 23 && num <= 30)) {
 println!("foo");
 }
}
"#,
        );
    }

    #[test]
    fn demorgan_iterator_double_negation() {
        check_assist(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
 let arr = [1, 2, 3];
 if !arr.into_iter().$0all(|num| !(num > 3)) {
 println!("foo");
 }
}
"#,
            r#"
fn main() {
 let arr = [1, 2, 3];
 if arr.into_iter().any(|num| num > 3) {
 println!("foo");
 }
}
"#,
        );
    }

    #[test]
    fn demorgan_iterator_double_parens() {
        check_assist(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
 let arr = [1, 2, 3];
 if !arr.into_iter().$0any(|num| (num > 3 && (num == 1 || num == 2))) {
 println!("foo");
 }
}
"#,
            r#"
fn main() {
 let arr = [1, 2, 3];
 if arr.into_iter().all(|num| !(num > 3 && (num == 1 || num == 2))) {
 println!("foo");
 }
}
"#,
        );
    }

    // The `!` is inserted before the start of a call chain that spans several lines.
    #[test]
    fn demorgan_iterator_multiline() {
        check_assist(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
 let arr = [1, 2, 3];
 if arr
 .into_iter()
 .all$0(|num| !num.is_negative())
 {
 println!("foo");
 }
}
"#,
            r#"
fn main() {
 let arr = [1, 2, 3];
 if !arr
 .into_iter()
 .any(|num| num.is_negative())
 {
 println!("foo");
 }
}
"#,
        );
    }

    // Every tail expression of a block closure (here both `if` branches) is negated.
    #[test]
    fn demorgan_iterator_block_closure() {
        check_assist(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
 let arr = [-1, 1, 2, 3];
 if arr.into_iter().all(|num: i32| {
 $0if num.is_positive() {
 num <= 3
 } else {
 num >= -1
 }
 }) {
 println!("foo");
 }
}
"#,
            r#"
fn main() {
 let arr = [-1, 1, 2, 3];
 if !arr.into_iter().any(|num: i32| {
 if num.is_positive() {
 num > 3
 } else {
 num < -1
 }
 }) {
 println!("foo");
 }
}
"#,
        );
    }

    // Only `all` and `any` are handled.
    #[test]
    fn demorgan_iterator_wrong_method() {
        check_assist_not_applicable(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
 let arr = [1, 2, 3];
 if !arr.into_iter().$0map(|num| num > 3) {
 println!("foo");
 }
}
"#,
        );
    }

    // ---- `apply_demorgan` in method-call receiver position ----
    // The negated result must stay parenthesized so the method still applies to all of it.

    #[test]
    fn demorgan_method_call_receiver() {
        check_assist(
            apply_demorgan,
            "fn f() { (x ||$0 !y).then_some(42) }",
            "fn f() { (!(!x && y)).then_some(42) }",
        );
    }

    #[test]
    fn demorgan_method_call_receiver_complex() {
        check_assist(
            apply_demorgan,
            "fn f() { (a && b ||$0 c && d).then_some(42) }",
            "fn f() { (!(!(a && b) && !(c && d))).then_some(42) }",
        );
    }

    #[test]
    fn demorgan_method_call_receiver_chained() {
        check_assist(
            apply_demorgan,
            "fn f() { (a ||$0 b).then_some(42).or(Some(0)) }",
            "fn f() { (!(!a && !b)).then_some(42).or(Some(0)) }",
        );
    }
}