use hir::{AsAssocItem, Semantics, sym};
use ide_db::{
RootDatabase,
famous_defs::FamousDefs,
syntax_helpers::node_ext::{
block_as_lone_tail, for_each_tail_expr, is_pattern_cond, preorder_expr,
},
};
use itertools::Itertools;
use syntax::{
AstNode, SyntaxNode,
ast::{self, HasArgList, edit::AstNodeEdit, syntax_factory::SyntaxFactory},
syntax_editor::SyntaxEditor,
};
use crate::{
AssistContext, AssistId, Assists,
utils::{invert_boolean_expression, unwrap_trivial_block},
};
/// Assist: rewrites an `if` expression into a `bool::then` call on its condition.
///
/// The assist is only registered when all of the following hold:
/// - the cursor offset lies within (inclusive) the range of the `if` keyword,
/// - the condition is not a pattern condition,
/// - the `else` branch is a plain block (an `else if` chain is rejected),
/// - exactly one of the two branches is a lone tail resolving to the `None` variant,
/// - the remaining branch is accepted by `is_invalid_body` (i.e. it returns `false`).
///
/// If the `None` branch is the `then` branch, the condition is inverted in the output.
/// Returns `None` whenever the assist is not applicable.
pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
    let if_expr = ctx.find_node_at_offset::<ast::IfExpr>()?;
    // Only offer the assist while the cursor touches the `if` keyword itself.
    let if_kw_range = if_expr.if_token()?.text_range();
    if !if_kw_range.contains_inclusive(ctx.offset()) {
        return None;
    }

    let cond = if_expr.condition().filter(|cond| !is_pattern_cond(cond.clone()))?;
    let then_block = if_expr.then_branch()?;
    let else_block = match if_expr.else_branch()? {
        ast::ElseBranch::IfExpr(_) => {
            cov_mark::hit!(convert_if_to_bool_then_chain);
            return None;
        }
        ast::ElseBranch::Block(block) => block,
    };

    let (none_variant, some_variant) = option_variants(&ctx.sema, if_expr.syntax())?;

    // Exactly one branch must be the bare `None`; the other one becomes the closure body.
    let then_is_none = block_is_none_variant(&ctx.sema, &then_block, none_variant);
    let else_is_none = block_is_none_variant(&ctx.sema, &else_block, none_variant);
    let (invert_cond, closure_body) = match (then_is_none, else_is_none) {
        (true, false) => (true, ast::Expr::BlockExpr(else_block)),
        (false, true) => (false, ast::Expr::BlockExpr(then_block)),
        _ => return None,
    };

    if is_invalid_body(&ctx.sema, some_variant, &closure_body) {
        cov_mark::hit!(convert_if_to_bool_then_pattern_invalid_body);
        return None;
    }

    let target = if_expr.syntax().text_range();
    acc.add(
        AssistId::refactor_rewrite("convert_if_to_bool_then"),
        "Convert `if` expression to `bool::then` call",
        target,
        |builder| {
            // Work on a detached copy of the body so it can be edited independently.
            let detached_body = closure_body.clone_subtree();
            let mut body_editor = SyntaxEditor::new(detached_body.syntax().clone());
            // Replace every call expression in tail position (or carried by a `break`)
            // with its first argument, i.e. `Some(e)` becomes `e`.
            for_each_tail_expr(&detached_body, &mut |tail| {
                let candidate = match tail {
                    ast::Expr::BreakExpr(brk) => brk.expr(),
                    call @ ast::Expr::CallExpr(_) => Some(call.clone()),
                    _ => None,
                };
                if let Some(ast::Expr::CallExpr(call)) = candidate {
                    let first_arg = call.arg_list().and_then(|list| list.args().next());
                    if let Some(arg) = first_arg {
                        body_editor.replace(call.syntax(), arg.syntax());
                    }
                }
            });
            let rewritten = body_editor.finish();
            let unwrapped_body = ast::Expr::cast(rewritten.new_root().clone()).unwrap();

            let mut file_editor = builder.make_editor(if_expr.syntax());
            let make = SyntaxFactory::new();
            let unwrapped_body = match unwrapped_body {
                ast::Expr::BlockExpr(block) => unwrap_trivial_block(block),
                other => other,
            };

            // Decided on the original condition, before any inversion is applied.
            let parenthesize = matches!(
                cond,
                ast::Expr::BinExpr(_)
                    | ast::Expr::BlockExpr(_)
                    | ast::Expr::BreakExpr(_)
                    | ast::Expr::CastExpr(_)
                    | ast::Expr::ClosureExpr(_)
                    | ast::Expr::ContinueExpr(_)
                    | ast::Expr::ForExpr(_)
                    | ast::Expr::IfExpr(_)
                    | ast::Expr::LoopExpr(_)
                    | ast::Expr::MacroExpr(_)
                    | ast::Expr::MatchExpr(_)
                    | ast::Expr::PrefixExpr(_)
                    | ast::Expr::RangeExpr(_)
                    | ast::Expr::RefExpr(_)
                    | ast::Expr::ReturnExpr(_)
                    | ast::Expr::WhileExpr(_)
                    | ast::Expr::YieldExpr(_)
            );
            let receiver = if invert_cond {
                invert_boolean_expression(&make, cond)
            } else {
                cond.clone_for_update()
            };
            let receiver = if parenthesize { make.expr_paren(receiver).into() } else { receiver };

            let closure = make.expr_closure(None, unwrapped_body);
            let then_args = make.arg_list(Some(closure.into()));
            let then_call = make.expr_method_call(receiver, make.name_ref("then"), then_args);
            file_editor.replace(if_expr.syntax(), then_call.syntax());
            file_editor.add_mappings(make.finish_with_mappings());
            builder.add_file_edits(ctx.file_id(), file_editor);
        },
    )
}
/// Assist: rewrites a `bool::then` method call into an `if`/`else` expression.
///
/// The assist is only registered when the name reference under the cursor is the method
/// name of a method call that
/// - has a receiver and exactly one argument, which is a closure with a body, and
/// - resolves to a function named `then` whose implementing type is `bool`.
///
/// The generated `if` uses the (paren-stripped) receiver as condition, the closure body with
/// its tail expressions wrapped in `Some(..)` as the `then` branch, and `None` as the `else`
/// branch. Returns `None` whenever the assist is not applicable.
pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
    let name_ref = ctx.find_node_at_offset::<ast::NameRef>()?;
    let mcall = name_ref.syntax().parent().and_then(ast::MethodCallExpr::cast)?;
    let receiver = mcall.receiver()?;

    // The call must take a single closure argument; its body becomes the `then` branch.
    let only_arg = mcall.arg_list()?.args().exactly_one().ok()?;
    let closure_body = if let ast::Expr::ClosureExpr(closure) = only_arg {
        closure.body()?
    } else {
        return None;
    };

    // Make sure the call really resolves to `then` implemented on `bool`.
    let db = ctx.sema.db;
    let method = ctx.sema.resolve_method_call(&mcall)?;
    if method.name(db) != sym::then {
        return None;
    }
    let implementing_ty = method.as_assoc_item(db)?.implementing_ty(db)?;
    if !implementing_ty.is_bool() {
        return None;
    }

    let target = mcall.syntax().text_range();
    acc.add(
        AssistId::refactor_rewrite("convert_bool_then_to_if"),
        "Convert `bool::then` call to `if`",
        target,
        |builder| {
            let mapless_make = SyntaxFactory::without_mappings();
            // Normalise the body into a block expression with its indentation reset.
            let then_block = match closure_body.reset_indent() {
                ast::Expr::BlockExpr(block) => block,
                other => mapless_make.block_expr(None, Some(other)),
            };
            let then_block = then_block.clone_subtree();
            let mut block_editor = SyntaxEditor::new(then_block.syntax().clone());

            let none_path = mapless_make.expr_path(mapless_make.ident_path("None"));
            let some_path = mapless_make.expr_path(mapless_make.ident_path("Some"));
            // Wrap every tail value (including `break`/`return` operands) in `Some(..)`.
            for_each_tail_expr(&ast::Expr::BlockExpr(then_block.clone()), &mut |tail| {
                let value = match tail {
                    ast::Expr::BreakExpr(brk) => brk.expr(),
                    ast::Expr::ReturnExpr(ret) => ret.expr(),
                    _ => Some(tail.clone()),
                };
                if let Some(value) = value {
                    let old_node = value.syntax().clone();
                    let wrapped =
                        mapless_make.expr_call(some_path.clone(), mapless_make.arg_list(Some(value)));
                    block_editor.replace(old_node, wrapped.syntax().clone());
                }
            });
            let rewritten = block_editor.finish();
            let then_block = ast::BlockExpr::cast(rewritten.new_root().clone()).unwrap();

            let mut file_editor = builder.make_editor(mcall.syntax());
            let make = SyntaxFactory::new();
            // A parenthesised receiver does not need its parentheses as an `if` condition.
            let cond = match &receiver {
                ast::Expr::ParenExpr(paren) => paren.expr().unwrap_or(receiver),
                _ => receiver,
            };
            let else_branch = ast::ElseBranch::Block(make.block_expr(None, Some(none_path)));
            let if_expr = make
                .expr_if(cond, then_block, Some(else_branch))
                .indent(mcall.indent_level())
                .clone_for_update();
            file_editor.replace(mcall.syntax().clone(), if_expr.syntax().clone());
            file_editor.add_mappings(make.finish_with_mappings());
            builder.add_file_edits(ctx.file_id(), file_editor);
        },
    )
}
/// Looks up the two variants of `core::option::Option` visible from the crate of `expr`.
///
/// Returns them ordered as `(none_variant, some_variant)`: whichever variant is named `None`
/// comes first. Yields `None` if the scope or the `Option` enum cannot be resolved, or if the
/// enum does not have exactly two variants.
fn option_variants(
    sema: &Semantics<'_, RootDatabase>,
    expr: &SyntaxNode,
) -> Option<(hir::Variant, hir::Variant)> {
    let krate = sema.scope(expr)?.krate();
    let option_enum = FamousDefs(sema, krate).core_option_Option()?;
    let variants = option_enum.variants(sema.db);
    if let &[first, second] = &*variants {
        // Put the variant called `None` first, regardless of declaration order.
        if first.name(sema.db) == sym::None.clone() {
            Some((first, second))
        } else {
            Some((second, first))
        }
    } else {
        None
    }
}
/// Decides whether `expr` cannot be used as the body of a `bool::then` closure.
///
/// The body is invalid (returns `true`) if
/// - it contains a `?` (try) expression or a `return` expression anywhere, or
/// - any of its tail expressions (looking through `break` operands) is not a call whose
///   callee is a path resolving to `some_variant`.
fn is_invalid_body(
    sema: &Semantics<'_, RootDatabase>,
    some_variant: hir::Variant,
    expr: &ast::Expr,
) -> bool {
    let mut invalid = false;
    // First pass: any `?` or `return` makes the body unusable. The callback's return value
    // is the running flag, exactly as handed back to `preorder_expr`.
    preorder_expr(expr, &mut |event| {
        if matches!(
            event,
            syntax::WalkEvent::Enter(ast::Expr::TryExpr(_) | ast::Expr::ReturnExpr(_))
        ) {
            invalid = true;
        }
        invalid
    });
    if invalid {
        return true;
    }

    // Second pass: every tail must be a call to the `Some` variant.
    for_each_tail_expr(expr, &mut |tail| {
        if invalid {
            return;
        }
        let candidate = match tail {
            ast::Expr::BreakExpr(brk) => brk.expr(),
            call @ ast::Expr::CallExpr(_) => Some(call.clone()),
            _ => None,
        };
        let resolution = match candidate {
            Some(ast::Expr::CallExpr(call)) => match call.expr() {
                Some(ast::Expr::PathExpr(callee)) => {
                    callee.path().and_then(|path| sema.resolve_path(&path))
                }
                _ => None,
            },
            _ => None,
        };
        // `invalid` is known to be false here, so a plain assignment is sufficient.
        invalid = match resolution {
            Some(hir::PathResolution::Def(hir::ModuleDef::Variant(variant))) => {
                variant != some_variant
            }
            _ => true,
        };
    });
    invalid
}
/// Returns `true` if `block` consists of a lone tail expression that is a path resolving to
/// `none_variant`; `false` in every other case (including unresolved paths).
fn block_is_none_variant(
    sema: &Semantics<'_, RootDatabase>,
    block: &ast::BlockExpr,
    none_variant: hir::Variant,
) -> bool {
    let lone_variant = block_as_lone_tail(block).and_then(|tail| {
        if let ast::Expr::PathExpr(path_expr) = tail {
            let path = path_expr.path()?;
            if let hir::PathResolution::Def(hir::ModuleDef::Variant(variant)) =
                sema.resolve_path(&path)?
            {
                return Some(variant);
            }
        }
        None
    });
    lone_variant == Some(none_variant)
}
// Tests for the `convert_if_to_bool_then` and `convert_bool_then_to_if` assists.
// Each fixture marks the cursor position with `$0`.
#[cfg(test)]
mod tests {
use crate::tests::{check_assist, check_assist_not_applicable};
use super::*;
// `Some(..)` in the `then` branch and `None` in the `else` branch: the condition is kept as is.
#[test]
fn convert_if_to_bool_then_simple() {
check_assist(
convert_if_to_bool_then,
r"
//- minicore:option
fn main() {
if$0 true {
Some(15)
} else {
None
}
}
",
r"
fn main() {
true.then(|| 15)
}
",
);
}
// `None` in the `then` branch: the condition is inverted in the generated call.
#[test]
fn convert_if_to_bool_then_invert() {
check_assist(
convert_if_to_bool_then,
r"
//- minicore:option
fn main() {
if$0 true {
None
} else {
Some(15)
}
}
",
r"
fn main() {
false.then(|| 15)
}
",
);
}
// Both branches are `None`: the assist must not be offered.
#[test]
fn convert_if_to_bool_then_none_none() {
check_assist_not_applicable(
convert_if_to_bool_then,
r"
//- minicore:option
fn main() {
if$0 true {
None
} else {
None
}
}
",
);
}
// Both branches are `Some(..)`: the assist must not be offered.
#[test]
fn convert_if_to_bool_then_some_some() {
check_assist_not_applicable(
convert_if_to_bool_then,
r"
//- minicore:option
fn main() {
if$0 true {
Some(15)
} else {
Some(15)
}
}
",
);
}
// The non-`None` branch has a nested `if` with a `None` tail: the assist must not be offered.
#[test]
fn convert_if_to_bool_then_mixed() {
check_assist_not_applicable(
convert_if_to_bool_then,
r"
//- minicore:option
fn main() {
if$0 true {
if true {
Some(15)
} else {
None
}
} else {
None
}
}
",
);
}
// An `else if` chain is rejected; the coverage mark checks that this exact path is taken.
#[test]
fn convert_if_to_bool_then_chain() {
cov_mark::check!(convert_if_to_bool_then_chain);
check_assist_not_applicable(
convert_if_to_bool_then,
r"
//- minicore:option
fn main() {
if$0 true {
Some(15)
} else if true {
None
} else {
None
}
}
",
);
}
// A `let` pattern condition is rejected.
#[test]
fn convert_if_to_bool_then_pattern_cond() {
check_assist_not_applicable(
convert_if_to_bool_then,
r"
//- minicore:option
fn main() {
if$0 let true = true {
Some(15)
} else {
None
}
}
",
);
}
// Bodies rejected by `is_invalid_body`: a tail that is a call to something other than `Some`,
// and a body containing `return`. The coverage mark must be hit once per fixture (2 total).
#[test]
fn convert_if_to_bool_then_pattern_invalid_body() {
cov_mark::check_count!(convert_if_to_bool_then_pattern_invalid_body, 2);
check_assist_not_applicable(
convert_if_to_bool_then,
r"
//- minicore:option
fn make_me_an_option() -> Option<i32> { None }
fn main() {
if$0 true {
if true {
make_me_an_option()
} else {
Some(15)
}
} else {
None
}
}
",
);
check_assist_not_applicable(
convert_if_to_bool_then,
r"
//- minicore:option
fn main() {
if$0 true {
if true {
return;
}
Some(15)
} else {
None
}
}
",
);
}
// Inapplicable inputs for the reverse assist: a non-`bool` receiver, a non-closure argument,
// and more than one argument.
#[test]
fn convert_bool_then_to_if_inapplicable() {
check_assist_not_applicable(
convert_bool_then_to_if,
r"
//- minicore:bool_impl
fn main() {
0.t$0hen(|| 15);
}
",
);
check_assist_not_applicable(
convert_bool_then_to_if,
r"
//- minicore:bool_impl
fn main() {
true.t$0hen(15);
}
",
);
check_assist_not_applicable(
convert_bool_then_to_if,
r"
//- minicore:bool_impl
fn main() {
true.t$0hen(|| 15, 15);
}
",
);
}
// Expression-bodied and block-bodied closures both produce the same `if`/`else` output.
#[test]
fn convert_bool_then_to_if_simple() {
check_assist(
convert_bool_then_to_if,
r"
//- minicore:bool_impl
fn main() {
true.t$0hen(|| 15)
}
",
r"
fn main() {
if true {
Some(15)
} else {
None
}
}
",
);
check_assist(
convert_bool_then_to_if,
r"
//- minicore:bool_impl
fn main() {
true.t$0hen(|| {
15
})
}
",
r"
fn main() {
if true {
Some(15)
} else {
None
}
}
",
);
}
// Values carried by `break` inside a tail `loop` are each wrapped in `Some(..)`.
#[test]
fn convert_bool_then_to_if_tails() {
check_assist(
convert_bool_then_to_if,
r"
//- minicore:bool_impl
fn main() {
true.t$0hen(|| {
loop {
if false {
break 0;
}
break 15;
}
})
}
",
r"
fn main() {
if true {
loop {
if false {
break Some(0);
}
break Some(15);
}
} else {
None
}
}
",
);
}
}