Skip to main content

ide_assists/handlers/
apply_demorgan.rs

1use std::collections::VecDeque;
2
3use ide_db::{
4    assists::GroupLabel,
5    famous_defs::FamousDefs,
6    syntax_helpers::node_ext::{for_each_tail_expr, is_pattern_cond, walk_expr},
7};
8use syntax::{
9    NodeOrToken, SyntaxKind, T,
10    ast::{
11        self, AstNode,
12        Expr::BinExpr,
13        HasArgList,
14        prec::{ExprPrecedence, precedence},
15        syntax_factory::SyntaxFactory,
16    },
17    syntax_editor::{Position, SyntaxEditor},
18};
19
20use crate::{AssistContext, AssistId, Assists, utils::invert_boolean_expression};
21
22// Assist: apply_demorgan
23//
24// Apply [De Morgan's law](https://en.wikipedia.org/wiki/De_Morgan%27s_laws).
25// This transforms expressions of the form `!l || !r` into `!(l && r)`.
26// This also works with `&&`. This assist can only be applied with the cursor
27// on either `||` or `&&`.
28//
29// ```
30// fn main() {
31//     if x != 4 ||$0 y < 3.14 {}
32// }
33// ```
34// ->
35// ```
36// fn main() {
37//     if !(x == 4 && y >= 3.14) {}
38// }
39// ```
40pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_, '_>) -> Option<()> {
41    let mut bin_expr = if let Some(not) = ctx.find_token_syntax_at_offset(T![!])
42        && let Some(NodeOrToken::Node(next)) = not.next_sibling_or_token()
43        && let Some(paren) = ast::ParenExpr::cast(next)
44        && let Some(ast::Expr::BinExpr(bin_expr)) = paren.expr()
45    {
46        bin_expr
47    } else {
48        let bin_expr = ctx.find_node_at_offset::<ast::BinExpr>()?;
49        let op_range = bin_expr.op_token()?.text_range();
50
51        // Is the cursor on the expression's logical operator?
52        if !op_range.contains_range(ctx.selection_trimmed()) {
53            return None;
54        }
55
56        bin_expr
57    };
58
59    let op = bin_expr.op_kind()?;
60    let op_range = bin_expr.op_token()?.text_range();
61
62    // Walk up the tree while we have the same binary operator
63    while let Some(parent_expr) = bin_expr.syntax().parent().and_then(ast::BinExpr::cast) {
64        match parent_expr.op_kind() {
65            Some(parent_op) if parent_op == op => {
66                bin_expr = parent_expr;
67            }
68            _ => break,
69        }
70    }
71
72    if is_pattern_cond(bin_expr.clone().into()) {
73        return None;
74    }
75
76    let op = bin_expr.op_kind()?;
77    let (inv_token, prec) = match op {
78        ast::BinaryOp::LogicOp(ast::LogicOp::And) => (SyntaxKind::PIPE2, ExprPrecedence::LOr),
79        ast::BinaryOp::LogicOp(ast::LogicOp::Or) => (SyntaxKind::AMP2, ExprPrecedence::LAnd),
80        _ => return None,
81    };
82
83    let (editor, demorganed) = SyntaxEditor::with_ast_node(&bin_expr);
84    let make = editor.make();
85    editor.replace(demorganed.op_token()?, make.token(inv_token));
86
87    let mut exprs = VecDeque::from([
88        (bin_expr.lhs()?, demorganed.lhs()?, prec),
89        (bin_expr.rhs()?, demorganed.rhs()?, prec),
90    ]);
91
92    while let Some((expr, demorganed, prec)) = exprs.pop_front() {
93        if let BinExpr(bin_expr) = &expr {
94            if let BinExpr(cbin_expr) = &demorganed {
95                if op == bin_expr.op_kind()? {
96                    editor.replace(cbin_expr.op_token()?, make.token(inv_token));
97                    exprs.push_back((bin_expr.lhs()?, cbin_expr.lhs()?, prec));
98                    exprs.push_back((bin_expr.rhs()?, cbin_expr.rhs()?, prec));
99                } else {
100                    let mut inv = invert_boolean_expression(make, expr);
101                    if precedence(&inv).needs_parentheses_in(prec) {
102                        inv = make.expr_paren(inv).into();
103                    }
104                    editor.replace(demorganed.syntax(), inv.syntax());
105                }
106            } else {
107                return None;
108            }
109        } else {
110            let mut inv = invert_boolean_expression(make, demorganed.clone());
111            if precedence(&inv).needs_parentheses_in(prec) {
112                inv = make.expr_paren(inv).into();
113            }
114            editor.replace(demorganed.syntax(), inv.syntax());
115        }
116    }
117
118    let edit = editor.finish();
119    let demorganed = ast::Expr::cast(edit.new_root().clone())?;
120
121    acc.add_group(
122        &GroupLabel("Apply De Morgan's law".to_owned()),
123        AssistId::refactor_rewrite("apply_demorgan"),
124        "Apply De Morgan's law",
125        op_range,
126        |builder| {
127            let editor = builder.make_editor(bin_expr.syntax());
128            let make = editor.make();
129
130            let (target_node, result_expr) = if let Some(neg_expr) = bin_expr
131                .syntax()
132                .parent()
133                .and_then(ast::ParenExpr::cast)
134                .and_then(|paren_expr| paren_expr.syntax().parent())
135                .and_then(ast::PrefixExpr::cast)
136                .filter(|prefix_expr| matches!(prefix_expr.op_kind(), Some(ast::UnaryOp::Not)))
137            {
138                cov_mark::hit!(demorgan_double_negation);
139                (ast::Expr::from(neg_expr).syntax().clone(), demorganed)
140            } else if let Some(paren_expr) =
141                bin_expr.syntax().parent().and_then(ast::ParenExpr::cast)
142            {
143                cov_mark::hit!(demorgan_double_parens);
144                (paren_expr.syntax().clone(), add_bang_paren(make, demorganed))
145            } else {
146                (bin_expr.syntax().clone(), add_bang_paren(make, demorganed))
147            };
148
149            let final_expr = if target_node
150                .parent()
151                .is_some_and(|p| result_expr.needs_parens_in_place_of(&p, &target_node))
152            {
153                cov_mark::hit!(demorgan_keep_parens_for_op_precedence2);
154                make.expr_paren(result_expr).into()
155            } else {
156                result_expr
157            };
158
159            editor.replace(&target_node, final_expr.syntax());
160            builder.add_file_edits(ctx.vfs_file_id(), editor);
161        },
162    )
163}
164
165// Assist: apply_demorgan_iterator
166//
167// Apply [De Morgan's law](https://en.wikipedia.org/wiki/De_Morgan%27s_laws) to
168// `Iterator::all` and `Iterator::any`.
169//
170// This transforms expressions of the form `!iter.any(|x| predicate(x))` into
171// `iter.all(|x| !predicate(x))` and vice versa. This also works the other way for
172// `Iterator::all` into `Iterator::any`.
173//
174// ```
175// # //- minicore: iterator
176// fn main() {
177//     let arr = [1, 2, 3];
178//     if !arr.into_iter().$0any(|num| num == 4) {
179//         println!("foo");
180//     }
181// }
182// ```
183// ->
184// ```
185// fn main() {
186//     let arr = [1, 2, 3];
187//     if arr.into_iter().all(|num| num != 4) {
188//         println!("foo");
189//     }
190// }
191// ```
192pub(crate) fn apply_demorgan_iterator(
193    acc: &mut Assists,
194    ctx: &AssistContext<'_, '_>,
195) -> Option<()> {
196    let method_call: ast::MethodCallExpr = ctx.find_node_at_offset().or_else(|| {
197        let parent = ctx.find_token_syntax_at_offset(T![!])?.parent()?;
198        match ast::PrefixExpr::cast(parent)?.expr()? {
199            ast::Expr::MethodCallExpr(method_call) => Some(method_call),
200            _ => None,
201        }
202    })?;
203    let (name, arg_expr) = validate_method_call_expr(ctx, &method_call)?;
204
205    let ast::Expr::ClosureExpr(closure_expr) = arg_expr else { return None };
206    let closure_body = closure_expr.body()?;
207
208    let op_range = method_call.syntax().text_range();
209    let label = format!("Apply De Morgan's law to `Iterator::{}`", name.text().as_str());
210    acc.add_group(
211        &GroupLabel("Apply De Morgan's law".to_owned()),
212        AssistId::refactor_rewrite("apply_demorgan_iterator"),
213        label,
214        op_range,
215        |builder| {
216            let editor = builder.make_editor(method_call.syntax());
217            let make = editor.make();
218            // replace the method name
219            let new_name = match name.text().as_str() {
220                "all" => make.name_ref("any"),
221                "any" => make.name_ref("all"),
222                "is_some_and" => make.name_ref("is_none_or"),
223                "is_none_or" => make.name_ref("is_some_and"),
224                _ => unreachable!(),
225            };
226            editor.replace(name.syntax(), new_name.syntax());
227
228            // negate all tail expressions in the closure body
229            let tail_cb = &mut |e: &_| tail_cb_impl(&editor, e);
230            walk_expr(&closure_body, &mut |expr| {
231                if let ast::Expr::ReturnExpr(ret_expr) = expr
232                    && let Some(ret_expr_arg) = &ret_expr.expr()
233                {
234                    for_each_tail_expr(ret_expr_arg, tail_cb);
235                }
236            });
237            for_each_tail_expr(&closure_body, tail_cb);
238
239            // negate the whole method call
240            if let Some(prefix_expr) = method_call
241                .syntax()
242                .parent()
243                .and_then(ast::PrefixExpr::cast)
244                .filter(|prefix_expr| matches!(prefix_expr.op_kind(), Some(ast::UnaryOp::Not)))
245            {
246                editor.delete(
247                    prefix_expr.op_token().expect("prefix expression always has an operator"),
248                );
249            } else {
250                editor.insert(Position::before(method_call.syntax()), make.token(SyntaxKind::BANG));
251            }
252            builder.add_file_edits(ctx.vfs_file_id(), editor);
253        },
254    )
255}
256
257/// Ensures that the method call is to `Iterator::all` or `Iterator::any`.
258fn validate_method_call_expr(
259    ctx: &AssistContext<'_, '_>,
260    method_call: &ast::MethodCallExpr,
261) -> Option<(ast::NameRef, ast::Expr)> {
262    let name_ref = method_call.name_ref()?;
263    let arg_expr = method_call.arg_list()?.args().next()?;
264    if name_ref.text() == "is_some_and" || name_ref.text() == "is_none_or" {
265        return Some((name_ref, arg_expr));
266    }
267    if name_ref.text() != "all" && name_ref.text() != "any" {
268        return None;
269    }
270
271    let sema = &ctx.sema;
272
273    let receiver = method_call.receiver()?;
274    let it_type = sema.type_of_expr(&receiver)?.adjusted();
275    let module = sema.scope(receiver.syntax())?.module();
276    let krate = module.krate(ctx.db());
277
278    let iter_trait = FamousDefs(sema, krate).core_iter_Iterator()?;
279    it_type.impls_trait(sema.db, iter_trait, &[]).then_some((name_ref, arg_expr))
280}
281
282fn tail_cb_impl(editor: &SyntaxEditor, e: &ast::Expr) {
283    match e {
284        ast::Expr::BreakExpr(break_expr) => {
285            if let Some(break_expr_arg) = break_expr.expr() {
286                for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(editor, e))
287            }
288        }
289        ast::Expr::ReturnExpr(_) => {
290            // all return expressions have already been handled by the walk loop
291        }
292        e => {
293            let inverted_body = invert_boolean_expression(editor.make(), e.clone());
294            editor.replace(e.syntax(), inverted_body.syntax());
295        }
296    }
297}
298
299/// Add bang and parentheses to the expression.
300fn add_bang_paren(make: &SyntaxFactory, expr: ast::Expr) -> ast::Expr {
301    make.expr_prefix(T![!], make.expr_paren(expr).into()).into()
302}
303
#[cfg(test)]
mod tests {
    // Each test runs an assist handler on a fixture (`$0` marks the cursor) and either
    // compares the result with the expected text or checks that the assist is not offered.
    use super::*;
    use crate::tests::{check_assist, check_assist_not_applicable};

    // Comparison operators are inverted individually when distributing the negation.
    #[test]
    fn demorgan_handles_leq() {
        check_assist(
            apply_demorgan,
            r#"
struct S;
fn f() { S < S &&$0 S <= S }
"#,
            r#"
struct S;
fn f() { !(S >= S || S > S) }
"#,
        );
    }

    #[test]
    fn demorgan_handles_geq() {
        check_assist(
            apply_demorgan,
            r#"
struct S;
fn f() { S > S &&$0 S >= S }
"#,
            r#"
struct S;
fn f() { !(S <= S || S < S) }
"#,
        );
    }

    #[test]
    fn demorgan_turns_and_into_or() {
        check_assist(apply_demorgan, "fn f() { !x &&$0 !x }", "fn f() { !(x || x) }")
    }

    #[test]
    fn demorgan_turns_or_into_and() {
        check_assist(apply_demorgan, "fn f() { !x ||$0 !x }", "fn f() { !(x && x) }")
    }

    #[test]
    fn demorgan_removes_inequality() {
        check_assist(apply_demorgan, "fn f() { x != x ||$0 !x }", "fn f() { !(x == x && x) }")
    }

    #[test]
    fn demorgan_general_case() {
        check_assist(apply_demorgan, "fn f() { x ||$0 x }", "fn f() { !(!x && !x) }")
    }

    // The whole same-operator chain is transformed, wherever the cursor is in it.
    #[test]
    fn demorgan_multiple_terms() {
        check_assist(apply_demorgan, "fn f() { x ||$0 y || z }", "fn f() { !(!x && !y && !z) }");
        check_assist(apply_demorgan, "fn f() { x || y ||$0 z }", "fn f() { !(!x && !y && !z) }");
    }

    #[test]
    fn demorgan_doesnt_apply_with_cursor_not_on_op() {
        check_assist_not_applicable(apply_demorgan, "fn f() { $0 !x || !x }")
    }

    // An enclosing `!( ... )` cancels with the new negation.
    #[test]
    fn demorgan_doesnt_double_negation() {
        cov_mark::check!(demorgan_double_negation);
        check_assist(apply_demorgan, "fn f() { !(x ||$0 x) }", "fn f() { !x && !x }")
    }

    // Existing parentheses are reused rather than nested.
    #[test]
    fn demorgan_doesnt_double_parens() {
        cov_mark::check!(demorgan_double_parens);
        check_assist(apply_demorgan, "fn f() { (x ||$0 x) }", "fn f() { !(!x && !x) }")
    }

    #[test]
    fn demorgan_doesnt_hang() {
        check_assist(
            apply_demorgan,
            "fn f() { 1 || 3 &&$0 4 || 5 }",
            "fn f() { 1 || !(!3 || !4) || 5 }",
        )
    }

    // `let` chains cannot be negated.
    #[test]
    fn demorgan_doesnt_handles_pattern() {
        check_assist_not_applicable(
            apply_demorgan,
            r#"
fn f() { if let 1 = 1 &&$0 true { } }
"#,
        );
    }

    // Cursor on the `!` of a negated, parenthesized logical expression.
    #[test]
    fn demorgan_on_not() {
        check_assist(
            apply_demorgan,
            "fn f() { $0!(1 || 3 && 4 || 5) }",
            "fn f() { !1 && !(3 && 4) && !5 }",
        )
    }

    #[test]
    fn demorgan_iterator_on_not() {
        check_assist(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
    let arr = [1, 2, 3];
    let cond = $0!arr.into_iter().all(|num| num != 4);
}
"#,
            r#"
fn main() {
    let arr = [1, 2, 3];
    let cond = arr.into_iter().any(|num| num == 4);
}
"#,
        );
    }

    // Precedence handling: parentheses are kept or added only where required.
    #[test]
    fn demorgan_keep_pars_for_op_precedence() {
        check_assist(
            apply_demorgan,
            "fn main() {
    let _ = !(!a ||$0 !(b || c));
}
",
            "fn main() {
    let _ = a && (b || c);
}
",
        );
    }

    // NOTE(review): this fixture (like the next two and
    // `demorgan_removes_pars_for_op_precedence2`) has an unbalanced `(`; it only parses
    // thanks to error recovery. Consider balancing it.
    #[test]
    fn demorgan_keep_pars_for_op_precedence2() {
        cov_mark::check!(demorgan_keep_parens_for_op_precedence2);
        check_assist(
            apply_demorgan,
            "fn f() { (a && !(b &&$0 c); }",
            "fn f() { (a && (!b || !c); }",
        );
    }

    #[test]
    fn demorgan_keep_pars_for_op_precedence3() {
        check_assist(
            apply_demorgan,
            "fn f() { (a || !(b &&$0 c); }",
            "fn f() { (a || (!b || !c); }",
        );
    }

    // NOTE(review): the fixture's `fn()` has no name; it relies on parser error recovery.
    #[test]
    fn demorgan_keeps_pars_in_eq_precedence() {
        check_assist(
            apply_demorgan,
            "fn() { let x = a && !(!b |$0| !c); }",
            "fn() { let x = a && (b && c); }",
        )
    }

    #[test]
    fn demorgan_removes_pars_for_op_precedence2() {
        check_assist(apply_demorgan, "fn f() { (a || !(b ||$0 c); }", "fn f() { (a || !b && !c; }");
    }

    // Iterator / Option variants: the method is swapped for its dual, closure tail
    // expressions are negated, and the call's own `!` is added or removed.
    #[test]
    fn demorgan_iterator_any_all_reverse() {
        check_assist(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
    let arr = [1, 2, 3];
    if arr.into_iter().all(|num| num $0!= 4) {
        println!("foo");
    }
}
"#,
            r#"
fn main() {
    let arr = [1, 2, 3];
    if !arr.into_iter().any(|num| num == 4) {
        println!("foo");
    }
}
"#,
        );
    }

    #[test]
    fn demorgan_iterator_all_any() {
        check_assist(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
    let arr = [1, 2, 3];
    if !arr.into_iter().$0all(|num| num > 3) {
        println!("foo");
    }
}
"#,
            r#"
fn main() {
    let arr = [1, 2, 3];
    if arr.into_iter().any(|num| num <= 3) {
        println!("foo");
    }
}
"#,
        );
    }

    #[test]
    fn demorgan_iterator_multiple_terms() {
        check_assist(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
    let arr = [1, 2, 3];
    if !arr.into_iter().$0any(|num| num > 3 && num == 23 && num <= 30) {
        println!("foo");
    }
}
"#,
            r#"
fn main() {
    let arr = [1, 2, 3];
    if arr.into_iter().all(|num| !(num > 3 && num == 23 && num <= 30)) {
        println!("foo");
    }
}
"#,
        );
    }

    #[test]
    fn demorgan_iterator_double_negation() {
        check_assist(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
    let arr = [1, 2, 3];
    if !arr.into_iter().$0all(|num| !(num > 3)) {
        println!("foo");
    }
}
"#,
            r#"
fn main() {
    let arr = [1, 2, 3];
    if arr.into_iter().any(|num| num > 3) {
        println!("foo");
    }
}
"#,
        );
    }

    #[test]
    fn demorgan_iterator_double_parens() {
        check_assist(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
    let arr = [1, 2, 3];
    if !arr.into_iter().$0any(|num| (num > 3 && (num == 1 || num == 2))) {
        println!("foo");
    }
}
"#,
            r#"
fn main() {
    let arr = [1, 2, 3];
    if arr.into_iter().all(|num| !(num > 3 && (num == 1 || num == 2))) {
        println!("foo");
    }
}
"#,
        );
    }

    #[test]
    fn demorgan_iterator_multiline() {
        check_assist(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
    let arr = [1, 2, 3];
    if arr
        .into_iter()
        .all$0(|num| !num.is_negative())
    {
        println!("foo");
    }
}
"#,
            r#"
fn main() {
    let arr = [1, 2, 3];
    if !arr
        .into_iter()
        .any(|num| num.is_negative())
    {
        println!("foo");
    }
}
"#,
        );
    }

    // Every tail expression of a block-bodied closure is negated.
    #[test]
    fn demorgan_iterator_block_closure() {
        check_assist(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
    let arr = [-1, 1, 2, 3];
    if arr.into_iter().all(|num: i32| {
        $0if num.is_positive() {
            num <= 3
        } else {
            num >= -1
        }
    }) {
        println!("foo");
    }
}
"#,
            r#"
fn main() {
    let arr = [-1, 1, 2, 3];
    if !arr.into_iter().any(|num: i32| {
        if num.is_positive() {
            num > 3
        } else {
            num < -1
        }
    }) {
        println!("foo");
    }
}
"#,
        );
    }

    #[test]
    fn demorgan_iterator_wrong_method() {
        check_assist_not_applicable(
            apply_demorgan_iterator,
            r#"
//- minicore: iterator
fn main() {
    let arr = [1, 2, 3];
    if !arr.into_iter().$0map(|num| num > 3) {
        println!("foo");
    }
}
"#,
        );
    }

    #[test]
    fn demorgan_option_is_some_and() {
        check_assist(
            apply_demorgan_iterator,
            r#"
//- minicore: option
fn main() {
    let cond = Some(2);
    if !cond.$0is_some_and(|num| num > 3) {
        println!("foo");
    }
}
"#,
            r#"
fn main() {
    let cond = Some(2);
    if cond.is_none_or(|num| num <= 3) {
        println!("foo");
    }
}
"#,
        );

        check_assist(
            apply_demorgan_iterator,
            r#"
//- minicore: option
fn main() {
    let cond = Some(2);
    if !cond.$0is_none_or(|num| num > 3) {
        println!("foo");
    }
}
"#,
            r#"
fn main() {
    let cond = Some(2);
    if cond.is_some_and(|num| num <= 3) {
        println!("foo");
    }
}
"#,
        );
    }

    // A negated expression used as a method receiver must stay parenthesized.
    #[test]
    fn demorgan_method_call_receiver() {
        check_assist(
            apply_demorgan,
            "fn f() { (x ||$0 !y).then_some(42) }",
            "fn f() { (!(!x && y)).then_some(42) }",
        );
    }

    #[test]
    fn demorgan_method_call_receiver_complex() {
        check_assist(
            apply_demorgan,
            "fn f() { (a && b ||$0 c && d).then_some(42) }",
            "fn f() { (!(!(a && b) && !(c && d))).then_some(42) }",
        );
    }

    #[test]
    fn demorgan_method_call_receiver_chained() {
        check_assist(
            apply_demorgan,
            "fn f() { (a ||$0 b).then_some(42).or(Some(0)) }",
            "fn f() { (!(!a && !b)).then_some(42).or(Some(0)) }",
        );
    }
}