1use hir::{AsAssocItem, Semantics, sym};
2use ide_db::{
3 RootDatabase,
4 famous_defs::FamousDefs,
5 syntax_helpers::node_ext::{
6 block_as_lone_tail, for_each_tail_expr, is_pattern_cond, preorder_expr,
7 },
8};
9use itertools::Itertools;
10use syntax::{
11 AstNode, SyntaxNode,
12 ast::{self, HasArgList, edit::AstNodeEdit, syntax_factory::SyntaxFactory},
13 syntax_editor::SyntaxEditor,
14};
15
16use crate::{
17 AssistContext, AssistId, Assists,
18 utils::{invert_boolean_expression, unwrap_trivial_block},
19};
20
21pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
42 let expr = ctx.find_node_at_offset::<ast::IfExpr>()?;
44 if !expr.if_token()?.text_range().contains_inclusive(ctx.offset()) {
45 return None;
46 }
47
48 let cond = expr.condition().filter(|cond| !is_pattern_cond(cond.clone()))?;
49 let then = expr.then_branch()?;
50 let else_ = match expr.else_branch()? {
51 ast::ElseBranch::Block(b) => b,
52 ast::ElseBranch::IfExpr(_) => {
53 cov_mark::hit!(convert_if_to_bool_then_chain);
54 return None;
55 }
56 };
57
58 let (none_variant, some_variant) = option_variants(&ctx.sema, expr.syntax())?;
59
60 let (invert_cond, closure_body) = match (
61 block_is_none_variant(&ctx.sema, &then, none_variant),
62 block_is_none_variant(&ctx.sema, &else_, none_variant),
63 ) {
64 (invert @ true, false) => (invert, ast::Expr::BlockExpr(else_)),
65 (invert @ false, true) => (invert, ast::Expr::BlockExpr(then)),
66 _ => return None,
67 };
68
69 if is_invalid_body(&ctx.sema, some_variant, &closure_body) {
70 cov_mark::hit!(convert_if_to_bool_then_pattern_invalid_body);
71 return None;
72 }
73
74 let target = expr.syntax().text_range();
75 acc.add(
76 AssistId::refactor_rewrite("convert_if_to_bool_then"),
77 "Convert `if` expression to `bool::then` call",
78 target,
79 |builder| {
80 let closure_body = closure_body.clone_subtree();
81 let mut editor = SyntaxEditor::new(closure_body.syntax().clone());
82 for_each_tail_expr(&closure_body, &mut |e| {
84 let e = match e {
85 ast::Expr::BreakExpr(e) => e.expr(),
86 e @ ast::Expr::CallExpr(_) => Some(e.clone()),
87 _ => None,
88 };
89 if let Some(ast::Expr::CallExpr(call)) = e
90 && let Some(arg_list) = call.arg_list()
91 && let Some(arg) = arg_list.args().next()
92 {
93 editor.replace(call.syntax(), arg.syntax());
94 }
95 });
96 let edit = editor.finish();
97 let closure_body = ast::Expr::cast(edit.new_root().clone()).unwrap();
98
99 let mut editor = builder.make_editor(expr.syntax());
100 let make = SyntaxFactory::with_mappings();
101 let closure_body = match closure_body {
102 ast::Expr::BlockExpr(block) => unwrap_trivial_block(block),
103 e => e,
104 };
105
106 let parenthesize = matches!(
107 cond,
108 ast::Expr::BinExpr(_)
109 | ast::Expr::BlockExpr(_)
110 | ast::Expr::BreakExpr(_)
111 | ast::Expr::CastExpr(_)
112 | ast::Expr::ClosureExpr(_)
113 | ast::Expr::ContinueExpr(_)
114 | ast::Expr::ForExpr(_)
115 | ast::Expr::IfExpr(_)
116 | ast::Expr::LoopExpr(_)
117 | ast::Expr::MacroExpr(_)
118 | ast::Expr::MatchExpr(_)
119 | ast::Expr::PrefixExpr(_)
120 | ast::Expr::RangeExpr(_)
121 | ast::Expr::RefExpr(_)
122 | ast::Expr::ReturnExpr(_)
123 | ast::Expr::WhileExpr(_)
124 | ast::Expr::YieldExpr(_)
125 );
126 let cond = if invert_cond {
127 invert_boolean_expression(&make, cond)
128 } else {
129 cond.clone_for_update()
130 };
131 let cond = if parenthesize { make.expr_paren(cond).into() } else { cond };
132 let arg_list = make.arg_list(Some(make.expr_closure(None, closure_body).into()));
133 let mcall = make.expr_method_call(cond, make.name_ref("then"), arg_list);
134 editor.replace(expr.syntax(), mcall.syntax());
135
136 editor.add_mappings(make.finish_with_mappings());
137 builder.add_file_edits(ctx.vfs_file_id(), editor);
138 },
139 )
140}
141
142pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
163 let name_ref = ctx.find_node_at_offset::<ast::NameRef>()?;
164 let mcall = name_ref.syntax().parent().and_then(ast::MethodCallExpr::cast)?;
165 let receiver = mcall.receiver()?;
166 let closure_body = Itertools::exactly_one(mcall.arg_list()?.args()).ok()?;
168 let closure_body = match closure_body {
169 ast::Expr::ClosureExpr(expr) => expr.body()?,
170 _ => return None,
171 };
172 let func = ctx.sema.resolve_method_call(&mcall)?;
174 if func.name(ctx.sema.db) != sym::then {
175 return None;
176 }
177 let assoc = func.as_assoc_item(ctx.sema.db)?;
178 if !assoc.implementing_ty(ctx.sema.db)?.is_bool() {
179 return None;
180 }
181
182 let target = mcall.syntax().text_range();
183 acc.add(
184 AssistId::refactor_rewrite("convert_bool_then_to_if"),
185 "Convert `bool::then` call to `if`",
186 target,
187 |builder| {
188 let mapless_make = SyntaxFactory::without_mappings();
189 let closure_body = match closure_body.reset_indent() {
190 ast::Expr::BlockExpr(block) => block,
191 e => mapless_make.block_expr(None, Some(e)),
192 };
193
194 let closure_body = closure_body.clone_subtree();
195 let mut editor = SyntaxEditor::new(closure_body.syntax().clone());
196 let none_path = mapless_make.expr_path(mapless_make.ident_path("None"));
198 let some_path = mapless_make.expr_path(mapless_make.ident_path("Some"));
199 for_each_tail_expr(&ast::Expr::BlockExpr(closure_body), &mut |e| {
200 let e = match e {
201 ast::Expr::BreakExpr(e) => e.expr(),
202 ast::Expr::ReturnExpr(e) => e.expr(),
203 _ => Some(e.clone()),
204 };
205 if let Some(expr) = e {
206 editor.replace(
207 expr.syntax().clone(),
208 mapless_make
209 .expr_call(some_path.clone(), mapless_make.arg_list(Some(expr)))
210 .syntax()
211 .clone(),
212 );
213 }
214 });
215 let edit = editor.finish();
216 let closure_body = ast::BlockExpr::cast(edit.new_root().clone()).unwrap();
217
218 let mut editor = builder.make_editor(mcall.syntax());
219 let make = SyntaxFactory::with_mappings();
220
221 let cond = match &receiver {
222 ast::Expr::ParenExpr(expr) => expr.expr().unwrap_or(receiver),
223 _ => receiver,
224 };
225 let if_expr = make
226 .expr_if(
227 cond,
228 closure_body,
229 Some(ast::ElseBranch::Block(make.block_expr(None, Some(none_path)))),
230 )
231 .indent(mcall.indent_level());
232 editor.replace(mcall.syntax().clone(), if_expr.syntax().clone());
233
234 editor.add_mappings(make.finish_with_mappings());
235 builder.add_file_edits(ctx.vfs_file_id(), editor);
236 },
237 )
238}
239
240fn option_variants(
241 sema: &Semantics<'_, RootDatabase>,
242 expr: &SyntaxNode,
243) -> Option<(hir::Variant, hir::Variant)> {
244 let fam = FamousDefs(sema, sema.scope(expr)?.krate());
245 let option_variants = fam.core_option_Option()?.variants(sema.db);
246 match &*option_variants {
247 &[variant0, variant1] => Some(if variant0.name(sema.db) == sym::None {
248 (variant0, variant1)
249 } else {
250 (variant1, variant0)
251 }),
252 _ => None,
253 }
254}
255
256fn is_invalid_body(
259 sema: &Semantics<'_, RootDatabase>,
260 some_variant: hir::Variant,
261 expr: &ast::Expr,
262) -> bool {
263 let mut invalid = false;
264 preorder_expr(expr, &mut |e| {
265 invalid |=
266 matches!(e, syntax::WalkEvent::Enter(ast::Expr::TryExpr(_) | ast::Expr::ReturnExpr(_)));
267 invalid
268 });
269 if !invalid {
270 for_each_tail_expr(expr, &mut |e| {
271 if invalid {
272 return;
273 }
274 let e = match e {
275 ast::Expr::BreakExpr(e) => e.expr(),
276 e @ ast::Expr::CallExpr(_) => Some(e.clone()),
277 _ => None,
278 };
279 if let Some(ast::Expr::CallExpr(call)) = e
280 && let Some(ast::Expr::PathExpr(p)) = call.expr()
281 {
282 let res = p.path().and_then(|p| sema.resolve_path(&p));
283 if let Some(hir::PathResolution::Def(hir::ModuleDef::Variant(v))) = res {
284 return invalid |= v != some_variant;
285 }
286 }
287 invalid = true
288 });
289 }
290 invalid
291}
292
293fn block_is_none_variant(
294 sema: &Semantics<'_, RootDatabase>,
295 block: &ast::BlockExpr,
296 none_variant: hir::Variant,
297) -> bool {
298 block_as_lone_tail(block).and_then(|e| match e {
299 ast::Expr::PathExpr(pat) => match sema.resolve_path(&pat.path()?)? {
300 hir::PathResolution::Def(hir::ModuleDef::Variant(v)) => Some(v),
301 _ => None,
302 },
303 _ => None,
304 }) == Some(none_variant)
305}
306
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    // `Some(..)` in the then-branch and `None` in the else-branch: the condition is kept
    // and the `Some` payload becomes the closure body.
    #[test]
    fn convert_if_to_bool_then_simple() {
        check_assist(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else {
        None
    }
}
",
            r"
fn main() {
    true.then(|| 15)
}
",
        );
    }

    // `None` in the then-branch: the condition is inverted (`true` becomes `false`).
    #[test]
    fn convert_if_to_bool_then_invert() {
        check_assist(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        None
    } else {
        Some(15)
    }
}
",
            r"
fn main() {
    false.then(|| 15)
}
",
        );
    }

    // Both branches are `None`: there is no value to put into the closure.
    #[test]
    fn convert_if_to_bool_then_none_none() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        None
    } else {
        None
    }
}
",
        );
    }

    // Neither branch is `None`, so the result is not a `bool::then` shape.
    #[test]
    fn convert_if_to_bool_then_some_some() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else {
        Some(15)
    }
}
",
        );
    }

    // The candidate closure body has a `None` tail, which cannot be unwrapped.
    #[test]
    fn convert_if_to_bool_then_mixed() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        if true {
            Some(15)
        } else {
            None
        }
    } else {
        None
    }
}
",
        );
    }

    // `else if` chains are rejected.
    #[test]
    fn convert_if_to_bool_then_chain() {
        cov_mark::check!(convert_if_to_bool_then_chain);
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        Some(15)
    } else if true {
        None
    } else {
        None
    }
}
",
        );
    }

    // `if let` conditions are rejected.
    #[test]
    fn convert_if_to_bool_then_pattern_cond() {
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 let true = true {
        Some(15)
    } else {
        None
    }
}
",
        );
    }

    // Invalid bodies: a tail that is not a `Some(..)` call, and a body containing `return`.
    // Each fixture hits the invalid-body mark once, hence the expected count of 2.
    #[test]
    fn convert_if_to_bool_then_pattern_invalid_body() {
        cov_mark::check_count!(convert_if_to_bool_then_pattern_invalid_body, 2);
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn make_me_an_option() -> Option<i32> { None }
fn main() {
    if$0 true {
        if true {
            make_me_an_option()
        } else {
            Some(15)
        }
    } else {
        None
    }
}
",
        );
        check_assist_not_applicable(
            convert_if_to_bool_then,
            r"
//- minicore:option
fn main() {
    if$0 true {
        if true {
            return;
        }
        Some(15)
    } else {
        None
    }
}
",
        );
    }

    // Rejected: non-`bool` receiver, non-closure argument, and too many arguments.
    #[test]
    fn convert_bool_then_to_if_inapplicable() {
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    0.t$0hen(|| 15);
}
",
        );
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(15);
}
",
        );
        check_assist_not_applicable(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| 15, 15);
}
",
        );
    }

    // Expression-bodied and block-bodied closures produce the same `if` expression.
    #[test]
    fn convert_bool_then_to_if_simple() {
        check_assist(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| 15)
}
",
            r"
fn main() {
    if true {
        Some(15)
    } else {
        None
    }
}
",
        );
        check_assist(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| {
        15
    })
}
",
            r"
fn main() {
    if true {
        Some(15)
    } else {
        None
    }
}
",
        );
    }

    // `break` tails of a loop keep their `break` and only the break value is wrapped.
    #[test]
    fn convert_bool_then_to_if_tails() {
        check_assist(
            convert_bool_then_to_if,
            r"
//- minicore:bool_impl
fn main() {
    true.t$0hen(|| {
        loop {
            if false {
                break 0;
            }
            break 15;
        }
    })
}
",
            r"
fn main() {
    if true {
        loop {
            if false {
                break Some(0);
            }
            break Some(15);
        }
    } else {
        None
    }
}
",
        );
    }
}