1use hir::Semantics;
2use ide_db::RootDatabase;
3use stdx::format_to;
4use syntax::ast::{self, AstNode};
5
6use crate::{AssistContext, AssistId, Assists};
7
8pub(crate) fn convert_two_arm_bool_match_to_matches_macro(
27 acc: &mut Assists,
28 ctx: &AssistContext<'_>,
29) -> Option<()> {
30 use ArmBodyExpression::*;
31 let match_expr = ctx.find_node_at_offset::<ast::MatchExpr>()?;
32 let match_arm_list = match_expr.match_arm_list()?;
33 let mut arms = match_arm_list.arms();
34 let first_arm = arms.next()?;
35 let second_arm = arms.next()?;
36 if arms.next().is_some() {
37 cov_mark::hit!(non_two_arm_match);
38 return None;
39 }
40 let first_arm_expr = first_arm.expr()?;
41 let second_arm_expr = second_arm.expr()?;
42 let first_arm_body = is_bool_literal_expr(&ctx.sema, &first_arm_expr)?;
43 let second_arm_body = is_bool_literal_expr(&ctx.sema, &second_arm_expr)?;
44
45 if !matches!(
46 (&first_arm_body, &second_arm_body),
47 (Literal(true), Literal(false))
48 | (Literal(false), Literal(true))
49 | (Expression(_), Literal(false))
50 ) {
51 cov_mark::hit!(non_invert_bool_literal_arms);
52 return None;
53 }
54
55 let target_range = ctx.sema.original_range(match_expr.syntax()).range;
56 let expr = match_expr.expr()?;
57
58 acc.add(
59 AssistId::refactor_rewrite("convert_two_arm_bool_match_to_matches_macro"),
60 "Convert to matches!",
61 target_range,
62 |builder| {
63 let mut arm_str = String::new();
64 if let Some(pat) = &first_arm.pat() {
65 format_to!(arm_str, "{pat}");
66 }
67 if let Some(guard) = &first_arm.guard() {
68 arm_str += &format!(" {guard}");
69 }
70
71 let replace_with = match (first_arm_body, second_arm_body) {
72 (Literal(true), Literal(false)) => {
73 format!("matches!({expr}, {arm_str})")
74 }
75 (Literal(false), Literal(true)) => {
76 format!("!matches!({expr}, {arm_str})")
77 }
78 (Expression(body_expr), Literal(false)) => {
79 arm_str.push_str(match &first_arm.guard() {
80 Some(_) => " && ",
81 _ => " if ",
82 });
83 format!("matches!({expr}, {arm_str}{body_expr})")
84 }
85 _ => {
86 unreachable!()
87 }
88 };
89 builder.replace(target_range, replace_with);
90 },
91 )
92}
93
/// Classification of a match arm's body, as produced by `is_bool_literal_expr`.
enum ArmBodyExpression {
    /// The body is a bare `true` or `false` literal; the payload is that literal's value.
    Literal(bool),
    /// The body is some other expression whose type (the `original` type reported by
    /// semantic analysis) is `bool`; the payload is that expression.
    Expression(ast::Expr),
}
98
99fn is_bool_literal_expr(
100 sema: &Semantics<'_, RootDatabase>,
101 expr: &ast::Expr,
102) -> Option<ArmBodyExpression> {
103 if let ast::Expr::Literal(lit) = expr
104 && let ast::LiteralKind::Bool(b) = lit.kind()
105 {
106 return Some(ArmBodyExpression::Literal(b));
107 }
108
109 if !sema.type_of_expr(expr)?.original.is_bool() {
110 return None;
111 }
112
113 Some(ArmBodyExpression::Expression(expr.clone()))
114}
115
// Unit tests for the `convert_two_arm_bool_match_to_matches_macro` assist. `$0` in a
// fixture marks the cursor position; `check_assist` compares before/after text,
// `check_assist_target` compares the highlighted target range.
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};

    use super::convert_two_arm_bool_match_to_matches_macro;

    // Cursor sits in the whitespace before the `match` keyword, outside the match node.
    #[test]
    fn not_applicable_outside_of_range_left() {
        check_assist_not_applicable(
            convert_two_arm_bool_match_to_matches_macro,
            r#"
fn foo(a: Option<u32>) -> bool {
    $0 match a {
        Some(_val) => true,
        _ => false
    }
}
        "#,
        );
    }

    // More than two arms cannot be collapsed into one `matches!` pattern.
    #[test]
    fn not_applicable_non_two_arm_match() {
        cov_mark::check!(non_two_arm_match);
        check_assist_not_applicable(
            convert_two_arm_bool_match_to_matches_macro,
            r#"
fn foo(a: Option<u32>) -> bool {
    match a$0 {
        Some(3) => true,
        Some(4) => true,
        _ => false
    }
}
        "#,
        );
    }

    // Both arms `false`: not one of the supported arm-body combinations.
    #[test]
    fn not_applicable_both_false_arms() {
        cov_mark::check!(non_invert_bool_literal_arms);
        check_assist_not_applicable(
            convert_two_arm_bool_match_to_matches_macro,
            r#"
fn foo(a: Option<u32>) -> bool {
    match a$0 {
        Some(val) => false,
        _ => false
    }
}
        "#,
        );
    }

    // Both arms `true`: not one of the supported arm-body combinations.
    #[test]
    fn not_applicable_both_true_arms() {
        cov_mark::check!(non_invert_bool_literal_arms);
        check_assist_not_applicable(
            convert_two_arm_bool_match_to_matches_macro,
            r#"
fn foo(a: Option<u32>) -> bool {
    match a$0 {
        Some(val) => true,
        _ => true
    }
}
        "#,
        );
    }

    // `true` / `false` arms become a plain `matches!`.
    #[test]
    fn convert_simple_case() {
        check_assist(
            convert_two_arm_bool_match_to_matches_macro,
            r#"
fn foo(a: Option<u32>) -> bool {
    match a$0 {
        Some(_val) => true,
        _ => false
    }
}
"#,
            r#"
fn foo(a: Option<u32>) -> bool {
    matches!(a, Some(_val))
}
"#,
        );
    }

    // `false` / `true` arms become a negated `matches!`.
    #[test]
    fn convert_simple_invert_case() {
        check_assist(
            convert_two_arm_bool_match_to_matches_macro,
            r#"
fn foo(a: Option<u32>) -> bool {
    match a$0 {
        Some(_val) => false,
        _ => true
    }
}
"#,
            r#"
fn foo(a: Option<u32>) -> bool {
    !matches!(a, Some(_val))
}
"#,
        );
    }

    // A guard on the first arm is carried over into the `matches!` pattern.
    #[test]
    fn convert_with_guard_case() {
        check_assist(
            convert_two_arm_bool_match_to_matches_macro,
            r#"
fn foo(a: Option<u32>) -> bool {
    match a$0 {
        Some(val) if val > 3 => true,
        _ => false
    }
}
"#,
            r#"
fn foo(a: Option<u32>) -> bool {
    matches!(a, Some(val) if val > 3)
}
"#,
        );
    }

    // Path patterns for enum variants are kept as written.
    #[test]
    fn convert_enum_match_cases() {
        check_assist(
            convert_two_arm_bool_match_to_matches_macro,
            r#"
enum X { A, B }

fn foo(a: X) -> bool {
    match a$0 {
        X::A => true,
        _ => false
    }
}
"#,
            r#"
enum X { A, B }

fn foo(a: X) -> bool {
    matches!(a, X::A)
}
"#,
        );
    }

    // The assist target is the whole `match` expression.
    #[test]
    fn convert_target_simple() {
        check_assist_target(
            convert_two_arm_bool_match_to_matches_macro,
            r#"
fn foo(a: Option<u32>) -> bool {
    match a$0 {
        Some(val) => true,
        _ => false
    }
}
"#,
            r#"match a {
        Some(val) => true,
        _ => false
    }"#,
        );
    }

    // Target range with a path scrutinee and trailing comma on the last arm.
    #[test]
    fn convert_target_complex() {
        check_assist_target(
            convert_two_arm_bool_match_to_matches_macro,
            r#"
enum E { X, Y }

fn main() {
    match E::X$0 {
        E::X => true,
        _ => false,
    }
}
"#,
            "match E::X {
        E::X => true,
        _ => false,
    }",
        );
    }

    // A non-literal boolean body becomes a guard: `if body` without an existing guard,
    // `&& body` appended to an existing guard.
    #[test]
    fn convert_non_literal_bool() {
        check_assist(
            convert_two_arm_bool_match_to_matches_macro,
            r#"
fn main() {
    match 0$0 {
        a @ 0..15 => a == 0,
        _ => false,
    }
}
"#,
            r#"
fn main() {
    matches!(0, a @ 0..15 if a == 0)
}
"#,
        );
        check_assist(
            convert_two_arm_bool_match_to_matches_macro,
            r#"
fn main() {
    match 0$0 {
        a @ 0..15 if thing() => a == 0,
        _ => false,
    }
}
"#,
            r#"
fn main() {
    matches!(0, a @ 0..15 if thing() && a == 0)
}
"#,
        );
    }
}