1use hir::{AsAssocItem, Semantics, sym};
2use ide_db::{
3 RootDatabase,
4 famous_defs::FamousDefs,
5 syntax_helpers::node_ext::{
6 block_as_lone_tail, for_each_tail_expr, is_pattern_cond, preorder_expr,
7 },
8};
9use itertools::Itertools;
10use syntax::{
11 AstNode, SyntaxNode,
12 ast::{self, HasArgList, edit::AstNodeEdit, syntax_factory::SyntaxFactory},
13 syntax_editor::SyntaxEditor,
14};
15
16use crate::{
17 AssistContext, AssistId, Assists,
18 utils::{invert_boolean_expression, unwrap_trivial_block},
19};
20
21pub(crate) fn convert_if_to_bool_then(
42 acc: &mut Assists,
43 ctx: &AssistContext<'_, '_>,
44) -> Option<()> {
45 let expr = ctx.find_node_at_offset::<ast::IfExpr>()?;
47 if !expr.if_token()?.text_range().contains_inclusive(ctx.offset()) {
48 return None;
49 }
50
51 let cond = expr.condition().filter(|cond| !is_pattern_cond(cond.clone()))?;
52 let then = expr.then_branch()?;
53 let else_ = match expr.else_branch()? {
54 ast::ElseBranch::Block(b) => b,
55 ast::ElseBranch::IfExpr(_) => {
56 cov_mark::hit!(convert_if_to_bool_then_chain);
57 return None;
58 }
59 };
60
61 let (none_variant, some_variant) = option_variants(&ctx.sema, expr.syntax())?;
62
63 let (invert_cond, closure_body) = match (
64 block_is_none_variant(&ctx.sema, &then, none_variant),
65 block_is_none_variant(&ctx.sema, &else_, none_variant),
66 ) {
67 (invert @ true, false) => (invert, ast::Expr::BlockExpr(else_)),
68 (invert @ false, true) => (invert, ast::Expr::BlockExpr(then)),
69 _ => return None,
70 };
71
72 if is_invalid_body(&ctx.sema, some_variant, &closure_body) {
73 cov_mark::hit!(convert_if_to_bool_then_pattern_invalid_body);
74 return None;
75 }
76
77 let target = expr.syntax().text_range();
78 acc.add(
79 AssistId::refactor_rewrite("convert_if_to_bool_then"),
80 "Convert `if` expression to `bool::then` call",
81 target,
82 |builder| {
83 let (editor, closure_body) = SyntaxEditor::with_ast_node(&closure_body);
84 for_each_tail_expr(&closure_body, &mut |e| {
86 let e = match e {
87 ast::Expr::BreakExpr(e) => e.expr(),
88 e @ ast::Expr::CallExpr(_) => Some(e.clone()),
89 _ => None,
90 };
91 if let Some(ast::Expr::CallExpr(call)) = e
92 && let Some(arg_list) = call.arg_list()
93 && let Some(arg) = arg_list.args().next()
94 {
95 editor.replace(call.syntax(), arg.syntax());
96 }
97 });
98 let edit = editor.finish();
99 let closure_body = ast::Expr::cast(edit.new_root().clone()).unwrap();
100
101 let editor = builder.make_editor(expr.syntax());
102 let make = editor.make();
103 let closure_body = match closure_body {
104 ast::Expr::BlockExpr(block) => unwrap_trivial_block(block),
105 e => e,
106 };
107 let cond = if invert_cond { invert_boolean_expression(make, cond) } else { cond };
108
109 let parenthesize = matches!(
110 cond,
111 ast::Expr::BinExpr(_)
112 | ast::Expr::BlockExpr(_)
113 | ast::Expr::BreakExpr(_)
114 | ast::Expr::CastExpr(_)
115 | ast::Expr::ClosureExpr(_)
116 | ast::Expr::ContinueExpr(_)
117 | ast::Expr::ForExpr(_)
118 | ast::Expr::IfExpr(_)
119 | ast::Expr::LoopExpr(_)
120 | ast::Expr::MacroExpr(_)
121 | ast::Expr::MatchExpr(_)
122 | ast::Expr::PrefixExpr(_)
123 | ast::Expr::RangeExpr(_)
124 | ast::Expr::RefExpr(_)
125 | ast::Expr::ReturnExpr(_)
126 | ast::Expr::WhileExpr(_)
127 | ast::Expr::YieldExpr(_)
128 );
129
130 let cond = if parenthesize { make.expr_paren(cond).into() } else { cond };
131 let arg_list = make.arg_list(Some(make.expr_closure(None, closure_body).into()));
132 let mcall = make.expr_method_call(cond, make.name_ref("then"), arg_list);
133 editor.replace(expr.syntax(), mcall.syntax());
134 builder.add_file_edits(ctx.vfs_file_id(), editor);
135 },
136 )
137}
138
139pub(crate) fn convert_bool_then_to_if(
160 acc: &mut Assists,
161 ctx: &AssistContext<'_, '_>,
162) -> Option<()> {
163 let name_ref = ctx.find_node_at_offset::<ast::NameRef>()?;
164 let mcall = name_ref.syntax().parent().and_then(ast::MethodCallExpr::cast)?;
165 let receiver = mcall.receiver()?;
166 let closure_body = Itertools::exactly_one(mcall.arg_list()?.args()).ok()?;
168 let closure_body = match closure_body {
169 ast::Expr::ClosureExpr(expr) => expr.body()?,
170 _ => return None,
171 };
172 let func = ctx.sema.resolve_method_call(&mcall)?;
174 if func.name(ctx.sema.db) != sym::then {
175 return None;
176 }
177 let assoc = func.as_assoc_item(ctx.sema.db)?;
178 if !assoc.implementing_ty(ctx.sema.db)?.is_bool() {
179 return None;
180 }
181
182 let target = mcall.syntax().text_range();
183 acc.add(
184 AssistId::refactor_rewrite("convert_bool_then_to_if"),
185 "Convert `bool::then` call to `if`",
186 target,
187 |builder| {
188 let mapless_make = SyntaxFactory::without_mappings();
189 let closure_body = match closure_body.reset_indent() {
190 ast::Expr::BlockExpr(block) => block,
191 e => mapless_make.block_expr(None, Some(e)),
192 };
193
194 let (editor, closure_body) = SyntaxEditor::with_ast_node(&closure_body);
195 let none_path = mapless_make.expr_path(mapless_make.ident_path("None"));
197 let some_path = mapless_make.expr_path(mapless_make.ident_path("Some"));
198 for_each_tail_expr(&ast::Expr::BlockExpr(closure_body), &mut |e| {
199 let e = match e {
200 ast::Expr::BreakExpr(e) => e.expr(),
201 ast::Expr::ReturnExpr(e) => e.expr(),
202 _ => Some(e.clone()),
203 };
204 if let Some(expr) = e {
205 editor.replace(
206 expr.syntax().clone(),
207 mapless_make
208 .expr_call(some_path.clone(), mapless_make.arg_list(Some(expr)))
209 .syntax()
210 .clone(),
211 );
212 }
213 });
214 let edit = editor.finish();
215 let closure_body = ast::BlockExpr::cast(edit.new_root().clone()).unwrap();
216
217 let editor = builder.make_editor(mcall.syntax());
218 let make = editor.make();
219
220 let cond = match &receiver {
221 ast::Expr::ParenExpr(expr) => expr.expr().unwrap_or(receiver),
222 _ => receiver,
223 };
224 let if_expr = make
225 .expr_if(
226 cond,
227 closure_body,
228 Some(ast::ElseBranch::Block(make.block_expr(None, Some(none_path)))),
229 )
230 .indent(mcall.indent_level());
231 editor.replace(mcall.syntax().clone(), if_expr.syntax().clone());
232 builder.add_file_edits(ctx.vfs_file_id(), editor);
233 },
234 )
235}
236
237fn option_variants(
238 sema: &Semantics<'_, RootDatabase>,
239 expr: &SyntaxNode,
240) -> Option<(hir::EnumVariant, hir::EnumVariant)> {
241 let fam = FamousDefs(sema, sema.scope(expr)?.krate());
242 let option_variants = fam.core_option_Option()?.variants(sema.db);
243 match &*option_variants {
244 &[variant0, variant1] => Some(if variant0.name(sema.db) == sym::None {
245 (variant0, variant1)
246 } else {
247 (variant1, variant0)
248 }),
249 _ => None,
250 }
251}
252
253fn is_invalid_body(
256 sema: &Semantics<'_, RootDatabase>,
257 some_variant: hir::EnumVariant,
258 expr: &ast::Expr,
259) -> bool {
260 let mut invalid = false;
261 preorder_expr(expr, &mut |e| {
262 invalid |=
263 matches!(e, syntax::WalkEvent::Enter(ast::Expr::TryExpr(_) | ast::Expr::ReturnExpr(_)));
264 invalid
265 });
266 if !invalid {
267 for_each_tail_expr(expr, &mut |e| {
268 if invalid {
269 return;
270 }
271 let e = match e {
272 ast::Expr::BreakExpr(e) => e.expr(),
273 e @ ast::Expr::CallExpr(_) => Some(e.clone()),
274 _ => None,
275 };
276 if let Some(ast::Expr::CallExpr(call)) = e
277 && let Some(ast::Expr::PathExpr(p)) = call.expr()
278 {
279 let res = p.path().and_then(|p| sema.resolve_path(&p));
280 if let Some(hir::PathResolution::Def(hir::ModuleDef::EnumVariant(v))) = res {
281 return invalid |= v != some_variant;
282 }
283 }
284 invalid = true
285 });
286 }
287 invalid
288}
289
290fn block_is_none_variant(
291 sema: &Semantics<'_, RootDatabase>,
292 block: &ast::BlockExpr,
293 none_variant: hir::EnumVariant,
294) -> bool {
295 block_as_lone_tail(block).and_then(|e| match e {
296 ast::Expr::PathExpr(pat) => match sema.resolve_path(&pat.path()?)? {
297 hir::PathResolution::Def(hir::ModuleDef::EnumVariant(v)) => Some(v),
298 _ => None,
299 },
300 _ => None,
301 }) == Some(none_variant)
302}
303
// Fixture-based tests for both assists. `$0` marks the cursor position in the
// input fixture; the second raw string (where present) is the expected output.
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    // `Some` in the `then` branch, `None` in the `else` branch: condition is kept as is.
    #[test]
    fn convert_if_to_bool_then_simple() {
        check_assist(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else {
        None
    }
}
",
            r"
fn main() {
    true.then(|| 15)
}
",
        );
    }

    // `None` in the `then` branch: the condition is inverted (`true` becomes `false`).
    #[test]
    fn convert_if_to_bool_then_invert() {
        check_assist(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        None
    } else {
        Some(15)
    }
}
",
            r"
fn main() {
    false.then(|| 15)
}
",
        );
    }

    // Both branches are `None`: not applicable.
    #[test]
    fn convert_if_to_bool_then_none_none() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        None
    } else {
        None
    }
}
",
        );
    }

    // Neither branch is `None`: not applicable.
    #[test]
    fn convert_if_to_bool_then_some_some() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else {
        Some(15)
    }
}
",
        );
    }

    // A nested `if` whose own tails mix `Some(..)` and `None`: not applicable.
    #[test]
    fn convert_if_to_bool_then_mixed() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        if true {
            Some(15)
        } else {
            None
        }
    } else {
        None
    }
}
",
        );
    }

    // `else if` chains are rejected; the coverage mark confirms which path rejected it.
    #[test]
    fn convert_if_to_bool_then_chain() {
        cov_mark::check!(convert_if_to_bool_then_chain);
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else if true {
        None
    } else {
        None
    }
}
",
        );
    }

    // `if let` conditions are rejected.
    #[test]
    fn convert_if_to_bool_then_pattern_cond() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 let true = true {
        Some(15)
    } else {
        None
    }
}
",
        );
    }

    // Two fixtures, each expected to hit the invalid-body mark once:
    // a tail that calls a function instead of `Some`, and a body containing `return`.
    #[test]
    fn convert_if_to_bool_then_pattern_invalid_body() {
        cov_mark::check_count!(convert_if_to_bool_then_pattern_invalid_body, 2);
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn make_me_an_option() -> Option<i32> { None }
fn main() {
    if$0 true {
        if true {
            make_me_an_option()
        } else {
            Some(15)
        }
    } else {
        None
    }
}
",
        );
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        if true {
            return;
        }
        Some(15)
    } else {
        None
    }
}
",
        );
    }

    // Not applicable for: a non-`bool` receiver, a non-closure argument,
    // and a call with more than one argument.
    #[test]
    fn convert_bool_then_to_if_inapplicable() {
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    0.t$0hen(|| 15);
}
",
        );
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(15);
}
",
        );
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| 15, 15);
}
",
        );
    }

    // Expression-bodied and block-bodied closures produce the same `if` expression.
    #[test]
    fn convert_bool_then_to_if_simple() {
        check_assist(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| 15)
}
",
            r"
fn main() {
    if true {
        Some(15)
    } else {
        None
    }
}
",
        );
        check_assist(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| {
        15
    })
}
",
            r"
fn main() {
    if true {
        Some(15)
    } else {
        None
    }
}
",
        );
    }

    // Every `break` value inside a tail `loop` gets wrapped in `Some(..)`.
    #[test]
    fn convert_bool_then_to_if_tails() {
        check_assist(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| {
        loop {
            if false {
                break 0;
            }
            break 15;
        }
    })
}
",
            r"
fn main() {
    if true {
        loop {
            if false {
                break Some(0);
            }
            break Some(15);
        }
    } else {
        None
    }
}
",
        );
    }

    // An inverted method-call condition becomes a prefix `!` expression and is parenthesized.
    #[test]
    fn convert_if_to_bool_then_invert_method_call() {
        check_assist(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    let test = &[()];
    let value = if$0 test.is_empty() { None } else { Some(()) };
}
",
            r"
fn main() {
    let test = &[()];
    let value = (!test.is_empty()).then(|| ());
}
",
        );
    }
}