// ide_assists/handlers/apply_demorgan.rs

1use std::collections::VecDeque;
2
3use ide_db::{
4    assists::GroupLabel,
5    famous_defs::FamousDefs,
6    syntax_helpers::node_ext::{for_each_tail_expr, walk_expr},
7};
8use syntax::{
9    NodeOrToken, SyntaxKind, T,
10    ast::{
11        self, AstNode,
12        Expr::BinExpr,
13        HasArgList,
14        prec::{ExprPrecedence, precedence},
15        syntax_factory::SyntaxFactory,
16    },
17    syntax_editor::{Position, SyntaxEditor},
18};
19
20use crate::{AssistContext, AssistId, Assists, utils::invert_boolean_expression};
21
22// Assist: apply_demorgan
23//
24// Apply [De Morgan's law](https://en.wikipedia.org/wiki/De_Morgan%27s_laws).
25// This transforms expressions of the form `!l || !r` into `!(l && r)`.
26// This also works with `&&`. This assist can only be applied with the cursor
27// on either `||` or `&&`.
28//
29// ```
30// fn main() {
31//     if x != 4 ||$0 y < 3.14 {}
32// }
33// ```
34// ->
35// ```
36// fn main() {
37//     if !(x == 4 && y >= 3.14) {}
38// }
39// ```
40pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
41    let mut bin_expr = if let Some(not) = ctx.find_token_syntax_at_offset(T![!])
42        && let Some(NodeOrToken::Node(next)) = not.next_sibling_or_token()
43        && let Some(paren) = ast::ParenExpr::cast(next)
44        && let Some(ast::Expr::BinExpr(bin_expr)) = paren.expr()
45    {
46        bin_expr
47    } else {
48        let bin_expr = ctx.find_node_at_offset::<ast::BinExpr>()?;
49        let op_range = bin_expr.op_token()?.text_range();
50
51        // Is the cursor on the expression's logical operator?
52        if !op_range.contains_range(ctx.selection_trimmed()) {
53            return None;
54        }
55
56        bin_expr
57    };
58
59    let op = bin_expr.op_kind()?;
60    let op_range = bin_expr.op_token()?.text_range();
61
62    // Walk up the tree while we have the same binary operator
63    while let Some(parent_expr) = bin_expr.syntax().parent().and_then(ast::BinExpr::cast) {
64        match parent_expr.op_kind() {
65            Some(parent_op) if parent_op == op => {
66                bin_expr = parent_expr;
67            }
68            _ => break,
69        }
70    }
71
72    let op = bin_expr.op_kind()?;
73    let (inv_token, prec) = match op {
74        ast::BinaryOp::LogicOp(ast::LogicOp::And) => (SyntaxKind::PIPE2, ExprPrecedence::LOr),
75        ast::BinaryOp::LogicOp(ast::LogicOp::Or) => (SyntaxKind::AMP2, ExprPrecedence::LAnd),
76        _ => return None,
77    };
78
79    let make = SyntaxFactory::with_mappings();
80
81    let demorganed = bin_expr.clone_subtree();
82    let mut editor = SyntaxEditor::new(demorganed.syntax().clone());
83    editor.replace(demorganed.op_token()?, make.token(inv_token));
84
85    let mut exprs = VecDeque::from([
86        (bin_expr.lhs()?, demorganed.lhs()?, prec),
87        (bin_expr.rhs()?, demorganed.rhs()?, prec),
88    ]);
89
90    while let Some((expr, demorganed, prec)) = exprs.pop_front() {
91        if let BinExpr(bin_expr) = &expr {
92            if let BinExpr(cbin_expr) = &demorganed {
93                if op == bin_expr.op_kind()? {
94                    editor.replace(cbin_expr.op_token()?, make.token(inv_token));
95                    exprs.push_back((bin_expr.lhs()?, cbin_expr.lhs()?, prec));
96                    exprs.push_back((bin_expr.rhs()?, cbin_expr.rhs()?, prec));
97                } else {
98                    let mut inv = invert_boolean_expression(&make, expr);
99                    if precedence(&inv).needs_parentheses_in(prec) {
100                        inv = make.expr_paren(inv).into();
101                    }
102                    editor.replace(demorganed.syntax(), inv.syntax());
103                }
104            } else {
105                return None;
106            }
107        } else {
108            let mut inv = invert_boolean_expression(&make, demorganed.clone());
109            if precedence(&inv).needs_parentheses_in(prec) {
110                inv = make.expr_paren(inv).into();
111            }
112            editor.replace(demorganed.syntax(), inv.syntax());
113        }
114    }
115
116    editor.add_mappings(make.finish_with_mappings());
117    let edit = editor.finish();
118    let demorganed = ast::Expr::cast(edit.new_root().clone())?;
119
120    acc.add_group(
121        &GroupLabel("Apply De Morgan's law".to_owned()),
122        AssistId::refactor_rewrite("apply_demorgan"),
123        "Apply De Morgan's law",
124        op_range,
125        |builder| {
126            let make = SyntaxFactory::with_mappings();
127            let (target_node, result_expr) = if let Some(neg_expr) = bin_expr
128                .syntax()
129                .parent()
130                .and_then(ast::ParenExpr::cast)
131                .and_then(|paren_expr| paren_expr.syntax().parent())
132                .and_then(ast::PrefixExpr::cast)
133                .filter(|prefix_expr| matches!(prefix_expr.op_kind(), Some(ast::UnaryOp::Not)))
134            {
135                cov_mark::hit!(demorgan_double_negation);
136                (ast::Expr::from(neg_expr).syntax().clone(), demorganed)
137            } else if let Some(paren_expr) =
138                bin_expr.syntax().parent().and_then(ast::ParenExpr::cast)
139            {
140                cov_mark::hit!(demorgan_double_parens);
141                (paren_expr.syntax().clone(), add_bang_paren(&make, demorganed))
142            } else {
143                (bin_expr.syntax().clone(), add_bang_paren(&make, demorganed))
144            };
145
146            let final_expr = if target_node
147                .parent()
148                .is_some_and(|p| result_expr.needs_parens_in_place_of(&p, &target_node))
149            {
150                cov_mark::hit!(demorgan_keep_parens_for_op_precedence2);
151                make.expr_paren(result_expr).into()
152            } else {
153                result_expr
154            };
155
156            let mut editor = builder.make_editor(&target_node);
157            editor.replace(&target_node, final_expr.syntax());
158            editor.add_mappings(make.finish_with_mappings());
159            builder.add_file_edits(ctx.vfs_file_id(), editor);
160        },
161    )
162}
163
164// Assist: apply_demorgan_iterator
165//
166// Apply [De Morgan's law](https://en.wikipedia.org/wiki/De_Morgan%27s_laws) to
167// `Iterator::all` and `Iterator::any`.
168//
169// This transforms expressions of the form `!iter.any(|x| predicate(x))` into
170// `iter.all(|x| !predicate(x))` and vice versa. This also works the other way for
171// `Iterator::all` into `Iterator::any`.
172//
173// ```
174// # //- minicore: iterator
175// fn main() {
176//     let arr = [1, 2, 3];
177//     if !arr.into_iter().$0any(|num| num == 4) {
178//         println!("foo");
179//     }
180// }
181// ```
182// ->
183// ```
184// fn main() {
185//     let arr = [1, 2, 3];
186//     if arr.into_iter().all(|num| num != 4) {
187//         println!("foo");
188//     }
189// }
190// ```
191pub(crate) fn apply_demorgan_iterator(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
192    let method_call: ast::MethodCallExpr = ctx.find_node_at_offset()?;
193    let (name, arg_expr) = validate_method_call_expr(ctx, &method_call)?;
194
195    let ast::Expr::ClosureExpr(closure_expr) = arg_expr else { return None };
196    let closure_body = closure_expr.body()?.clone_for_update();
197
198    let op_range = method_call.syntax().text_range();
199    let label = format!("Apply De Morgan's law to `Iterator::{}`", name.text().as_str());
200    acc.add_group(
201        &GroupLabel("Apply De Morgan's law".to_owned()),
202        AssistId::refactor_rewrite("apply_demorgan_iterator"),
203        label,
204        op_range,
205        |builder| {
206            let make = SyntaxFactory::with_mappings();
207            let mut editor = builder.make_editor(method_call.syntax());
208            // replace the method name
209            let new_name = match name.text().as_str() {
210                "all" => make.name_ref("any"),
211                "any" => make.name_ref("all"),
212                _ => unreachable!(),
213            };
214            editor.replace(name.syntax(), new_name.syntax());
215
216            // negate all tail expressions in the closure body
217            let tail_cb = &mut |e: &_| tail_cb_impl(&mut editor, &make, e);
218            walk_expr(&closure_body, &mut |expr| {
219                if let ast::Expr::ReturnExpr(ret_expr) = expr
220                    && let Some(ret_expr_arg) = &ret_expr.expr()
221                {
222                    for_each_tail_expr(ret_expr_arg, tail_cb);
223                }
224            });
225            for_each_tail_expr(&closure_body, tail_cb);
226
227            // negate the whole method call
228            if let Some(prefix_expr) = method_call
229                .syntax()
230                .parent()
231                .and_then(ast::PrefixExpr::cast)
232                .filter(|prefix_expr| matches!(prefix_expr.op_kind(), Some(ast::UnaryOp::Not)))
233            {
234                editor.delete(
235                    prefix_expr.op_token().expect("prefix expression always has an operator"),
236                );
237            } else {
238                editor.insert(Position::before(method_call.syntax()), make.token(SyntaxKind::BANG));
239            }
240
241            editor.add_mappings(make.finish_with_mappings());
242            builder.add_file_edits(ctx.vfs_file_id(), editor);
243        },
244    )
245}
246
247/// Ensures that the method call is to `Iterator::all` or `Iterator::any`.
248fn validate_method_call_expr(
249    ctx: &AssistContext<'_>,
250    method_call: &ast::MethodCallExpr,
251) -> Option<(ast::NameRef, ast::Expr)> {
252    let name_ref = method_call.name_ref()?;
253    if name_ref.text() != "all" && name_ref.text() != "any" {
254        return None;
255    }
256    let arg_expr = method_call.arg_list()?.args().next()?;
257
258    let sema = &ctx.sema;
259
260    let receiver = method_call.receiver()?;
261    let it_type = sema.type_of_expr(&receiver)?.adjusted();
262    let module = sema.scope(receiver.syntax())?.module();
263    let krate = module.krate(ctx.db());
264
265    let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?;
266    it_type.impls_trait(sema.db, iter_trait, &[]).then_some((name_ref, arg_expr))
267}
268
269fn tail_cb_impl(editor: &mut SyntaxEditor, make: &SyntaxFactory, e: &ast::Expr) {
270    match e {
271        ast::Expr::BreakExpr(break_expr) => {
272            if let Some(break_expr_arg) = break_expr.expr() {
273                for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(editor, make, e))
274            }
275        }
276        ast::Expr::ReturnExpr(_) => {
277            // all return expressions have already been handled by the walk loop
278        }
279        e => {
280            let inverted_body = invert_boolean_expression(make, e.clone());
281            editor.replace(e.syntax(), inverted_body.syntax());
282        }
283    }
284}
285
286/// Add bang and parentheses to the expression.
287fn add_bang_paren(make: &SyntaxFactory, expr: ast::Expr) -> ast::Expr {
288    make.expr_prefix(T![!], make.expr_paren(expr).into()).into()
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{check_assist, check_assist_not_applicable};

    /// Runs `apply_demorgan` on `before` and expects exactly `after`.
    #[track_caller]
    fn check_demorgan(before: &str, after: &str) {
        check_assist(apply_demorgan, before, after);
    }

    /// Runs `apply_demorgan_iterator` on `before` and expects exactly `after`.
    #[track_caller]
    fn check_demorgan_iterator(before: &str, after: &str) {
        check_assist(apply_demorgan_iterator, before, after);
    }

    #[test]
    fn demorgan_handles_leq() {
        check_demorgan(
            r#"
struct S;
fn f() { S < S &&$0 S <= S }
"#,
            r#"
struct S;
fn f() { !(S >= S || S > S) }
"#,
        );
    }

    #[test]
    fn demorgan_handles_geq() {
        check_demorgan(
            r#"
struct S;
fn f() { S > S &&$0 S >= S }
"#,
            r#"
struct S;
fn f() { !(S <= S || S < S) }
"#,
        );
    }

    #[test]
    fn demorgan_turns_and_into_or() {
        check_demorgan("fn f() { !x &&$0 !x }", "fn f() { !(x || x) }");
    }

    #[test]
    fn demorgan_turns_or_into_and() {
        check_demorgan("fn f() { !x ||$0 !x }", "fn f() { !(x && x) }");
    }

    #[test]
    fn demorgan_removes_inequality() {
        check_demorgan("fn f() { x != x ||$0 !x }", "fn f() { !(x == x && x) }");
    }

    #[test]
    fn demorgan_general_case() {
        check_demorgan("fn f() { x ||$0 x }", "fn f() { !(!x && !x) }");
    }

    #[test]
    fn demorgan_multiple_terms() {
        // The whole `||` chain is rewritten, whichever operator the cursor is on.
        for before in ["fn f() { x ||$0 y || z }", "fn f() { x || y ||$0 z }"] {
            check_demorgan(before, "fn f() { !(!x && !y && !z) }");
        }
    }

    #[test]
    fn demorgan_doesnt_apply_with_cursor_not_on_op() {
        check_assist_not_applicable(apply_demorgan, "fn f() { $0 !x || !x }");
    }

    #[test]
    fn demorgan_doesnt_double_negation() {
        cov_mark::check!(demorgan_double_negation);
        check_demorgan("fn f() { !(x ||$0 x) }", "fn f() { !x && !x }");
    }

    #[test]
    fn demorgan_doesnt_double_parens() {
        cov_mark::check!(demorgan_double_parens);
        check_demorgan("fn f() { (x ||$0 x) }", "fn f() { !(!(x && !x) }".replace("!(!(x && !x)", "!(!x && !x)").as_str());
    }

    #[test]
    fn demorgan_doesnt_hang() {
        check_demorgan("fn f() { 1 || 3 &&$0 4 || 5 }", "fn f() { 1 || !(!3 || !4) || 5 }");
    }

    #[test]
    fn demorgan_on_not() {
        // The cursor may also sit on the `!` in front of a parenthesized expression.
        check_demorgan("fn f() { $0!(1 || 3 && 4 || 5) }", "fn f() { !1 && !(3 && 4) && !5 }");
    }

    #[test]
    fn demorgan_keep_pars_for_op_precedence() {
        check_demorgan(
            "fn main() {
    let _ = !(!a ||$0 !(b || c));
}
",
            "fn main() {
    let _ = a && (b || c);
}
",
        );
    }

    #[test]
    fn demorgan_keep_pars_for_op_precedence2() {
        cov_mark::check!(demorgan_keep_parens_for_op_precedence2);
        check_demorgan("fn f() { (a && !(b &&$0 c); }", "fn f() { (a && (!b || !c); }");
    }

    #[test]
    fn demorgan_keep_pars_for_op_precedence3() {
        check_demorgan("fn f() { (a || !(b &&$0 c); }", "fn f() { (a || (!b || !c); }");
    }

    #[test]
    fn demorgan_keeps_pars_in_eq_precedence() {
        check_demorgan(
            "fn() { let x = a && !(!b |$0| !c); }",
            "fn() { let x = a && (b && c); }",
        );
    }

    #[test]
    fn demorgan_removes_pars_for_op_precedence2() {
        check_demorgan("fn f() { (a || !(b ||$0 c); }", "fn f() { (a || !b && !c; }");
    }

    #[test]
    fn demorgan_iterator_any_all_reverse() {
        check_demorgan_iterator(
            r#"
//- minicore: iterator
fn main() {
    let arr = [1, 2, 3];
    if arr.into_iter().all(|num| num $0!= 4) {
        println!("foo");
    }
}
"#,
            r#"
fn main() {
    let arr = [1, 2, 3];
    if !arr.into_iter().any(|num| num == 4) {
        println!("foo");
    }
}
"#,
        );
    }

    #[test]
    fn demorgan_iterator_all_any() {
        check_demorgan_iterator(
            r#"
//- minicore: iterator
fn main() {
    let arr = [1, 2, 3];
    if !arr.into_iter().$0all(|num| num > 3) {
        println!("foo");
    }
}
"#,
            r#"
fn main() {
    let arr = [1, 2, 3];
    if arr.into_iter().any(|num| num <= 3) {
        println!("foo");
    }
}
"#,
        );
    }

    #[test]
    fn demorgan_iterator_multiple_terms() {
        check_demorgan_iterator(
            r#"
//- minicore: iterator
fn main() {
    let arr = [1, 2, 3];
    if !arr.into_iter().$0any(|num| num > 3 && num == 23 && num <= 30) {
        println!("foo");
    }
}
"#,
            r#"
fn main() {
    let arr = [1, 2, 3];
    if arr.into_iter().all(|num| !(num > 3 && num == 23 && num <= 30)) {
        println!("foo");
    }
}
"#,
        );
    }

    #[test]
    fn demorgan_iterator_double_negation() {
        check_demorgan_iterator(
            r#"
//- minicore: iterator
fn main() {
    let arr = [1, 2, 3];
    if !arr.into_iter().$0all(|num| !(num > 3)) {
        println!("foo");
    }
}
"#,
            r#"
fn main() {
    let arr = [1, 2, 3];
    if arr.into_iter().any(|num| num > 3) {
        println!("foo");
    }
}
"#,
        );
    }

    #[test]
    fn demorgan_iterator_double_parens() {
        check_demorgan_iterator(
            r#"
//- minicore: iterator
fn main() {
    let arr = [1, 2, 3];
    if !arr.into_iter().$0any(|num| (num > 3 && (num == 1 || num == 2))) {
        println!("foo");
    }
}
"#,
            r#"
fn main() {
    let arr = [1, 2, 3];
    if arr.into_iter().all(|num| !(num > 3 && (num == 1 || num == 2))) {
        println!("foo");
    }
}
"#,
        );
    }

    #[test]
    fn demorgan_iterator_multiline() {
        check_demorgan_iterator(
            r#"
//- minicore: iterator
fn main() {
    let arr = [1, 2, 3];
    if arr
        .into_iter()
        .all$0(|num| !num.is_negative())
    {
        println!("foo");
    }
}
"#,
            r#"
fn main() {
    let arr = [1, 2, 3];
    if !arr
        .into_iter()
        .any(|num| num.is_negative())
    {
        println!("foo");
    }
}
"#,
        );
    }

    #[test]
    fn demorgan_iterator_block_closure() {
        check_demorgan_iterator(
            r#"
//- minicore: iterator
fn main() {
    let arr = [-1, 1, 2, 3];
    if arr.into_iter().all(|num: i32| {
        $0if num.is_positive() {
            num <= 3
        } else {
            num >= -1
        }
    }) {
        println!("foo");
    }
}
"#,
            r#"
fn main() {
    let arr = [-1, 1, 2, 3];
    if !arr.into_iter().any(|num: i32| {
        if num.is_positive() {
            num > 3
        } else {
            num < -1
        }
    }) {
        println!("foo");
    }
}
"#,
        );
    }

    #[test]
    fn demorgan_iterator_wrong_method() {
        check_assist_not_applicable(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
    let arr = [1, 2, 3];
    if !arr.into_iter().$0map(|num| num > 3) {
        println!("foo");
    }
}
"#,
        );
    }

    #[test]
    fn demorgan_method_call_receiver() {
        check_demorgan(
            "fn f() { (x ||$0 !y).then_some(42) }",
            "fn f() { (!(!x && y)).then_some(42) }",
        );
    }

    #[test]
    fn demorgan_method_call_receiver_complex() {
        check_demorgan(
            "fn f() { (a && b ||$0 c && d).then_some(42) }",
            "fn f() { (!(!(a && b) && !(c && d))).then_some(42) }",
        );
    }

    #[test]
    fn demorgan_method_call_receiver_chained() {
        check_demorgan(
            "fn f() { (a ||$0 b).then_some(42).or(Some(0)) }",
            "fn f() { (!(!a && !b)).then_some(42).or(Some(0)) }",
        );
    }
}