1use ide_db::assists::{AssistId, GroupLabel};
2use syntax::{
3 AstNode,
4 ast::{self, ArithOp, BinaryOp},
5};
6
7use crate::{
8 assist_context::{AssistContext, Assists},
9 utils::wrap_paren,
10};
11
12pub(crate) fn replace_arith_with_checked(
28 acc: &mut Assists,
29 ctx: &AssistContext<'_, '_>,
30) -> Option<()> {
31 replace_arith(acc, ctx, ArithKind::Checked)
32}
33
34pub(crate) fn replace_arith_with_saturating(
50 acc: &mut Assists,
51 ctx: &AssistContext<'_, '_>,
52) -> Option<()> {
53 replace_arith(acc, ctx, ArithKind::Saturating)
54}
55
56pub(crate) fn replace_arith_with_wrapping(
72 acc: &mut Assists,
73 ctx: &AssistContext<'_, '_>,
74) -> Option<()> {
75 replace_arith(acc, ctx, ArithKind::Wrapping)
76}
77
78fn replace_arith(acc: &mut Assists, ctx: &AssistContext<'_, '_>, kind: ArithKind) -> Option<()> {
79 let (lhs, op, is_assign, rhs) = parse_binary_op(ctx)?;
80 let op_expr = lhs.syntax().parent()?;
81
82 if !is_primitive_int(ctx, &lhs) || !is_primitive_int(ctx, &rhs) {
83 return None;
84 }
85
86 acc.add_group(
87 &GroupLabel("Replace arithmetic...".into()),
88 kind.assist_id(),
89 kind.label(),
90 op_expr.text_range(),
91 |builder| {
92 let editor = builder.make_editor(rhs.syntax());
93 let make = editor.make();
94 let method_name = kind.method_name(op);
95
96 let receiver = wrap_paren(lhs.clone(), make, ast::prec::ExprPrecedence::Postfix);
97 let mut arith_expr = make
98 .expr_method_call(receiver, make.name_ref(&method_name), make.arg_list([rhs]))
99 .into();
100 if is_assign {
101 arith_expr = make.expr_assignment(lhs, arith_expr).into();
102 }
103 editor.replace(op_expr, arith_expr.syntax());
104 builder.add_file_edits(ctx.vfs_file_id(), editor);
105 },
106 )
107}
108
109fn is_primitive_int(ctx: &AssistContext<'_, '_>, expr: &ast::Expr) -> bool {
110 match ctx.sema.type_of_expr(expr) {
111 Some(ty) => ty.adjusted().is_int_or_uint(),
112 _ => false,
113 }
114}
115
116fn parse_binary_op(ctx: &AssistContext<'_, '_>) -> Option<(ast::Expr, ArithOp, bool, ast::Expr)> {
118 if !ctx.has_empty_selection() {
119 return None;
120 }
121 let expr = ctx.find_node_at_offset::<ast::BinExpr>()?;
122
123 let (op, is_assign) = match expr.op_kind()? {
124 BinaryOp::ArithOp(arith_op) => (arith_op, false),
125 BinaryOp::Assignment { op: Some(op) } => (op, true),
126 _ => return None,
127 };
128 if !matches!(op, ArithOp::Add | ArithOp::Sub | ArithOp::Mul | ArithOp::Div) {
129 return None;
130 }
131
132 let lhs = expr.lhs()?;
133 let rhs = expr.rhs()?;
134
135 Some((lhs, op, is_assign, rhs))
136}
137
/// The flavour of overflow-aware integer method that an arithmetic operator
/// is rewritten to.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum ArithKind {
    /// Rewrite to the `saturating_*` method.
    Saturating,
    /// Rewrite to the `wrapping_*` method.
    Wrapping,
    /// Rewrite to the `checked_*` method.
    Checked,
}
143
144impl ArithKind {
145 fn assist_id(&self) -> AssistId {
146 let s = match self {
147 ArithKind::Saturating => "replace_arith_with_saturating",
148 ArithKind::Checked => "replace_arith_with_checked",
149 ArithKind::Wrapping => "replace_arith_with_wrapping",
150 };
151
152 AssistId::refactor_rewrite(s)
153 }
154
155 fn label(&self) -> &'static str {
156 match self {
157 ArithKind::Saturating => "Replace arithmetic with call to saturating_*",
158 ArithKind::Checked => "Replace arithmetic with call to checked_*",
159 ArithKind::Wrapping => "Replace arithmetic with call to wrapping_*",
160 }
161 }
162
163 fn method_name(&self, op: ArithOp) -> String {
164 let prefix = match self {
165 ArithKind::Checked => "checked_",
166 ArithKind::Wrapping => "wrapping_",
167 ArithKind::Saturating => "saturating_",
168 };
169
170 let suffix = match op {
171 ArithOp::Add => "add",
172 ArithOp::Sub => "sub",
173 ArithOp::Mul => "mul",
174 ArithOp::Div => "div",
175 _ => unreachable!("this function should only be called with +, -, / or *"),
176 };
177 format!("{prefix}{suffix}")
178 }
179}
180
#[cfg(test)]
mod tests {
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    #[test]
    fn arith_kind_method_name() {
        assert_eq!(ArithKind::Saturating.method_name(ArithOp::Add), "saturating_add");
        assert_eq!(ArithKind::Checked.method_name(ArithOp::Sub), "checked_sub");
        assert_eq!(ArithKind::Wrapping.method_name(ArithOp::Mul), "wrapping_mul");
        assert_eq!(ArithKind::Checked.method_name(ArithOp::Div), "checked_div");
    }

    #[test]
    fn replace_arith_with_checked_add() {
        check_assist(
            replace_arith_with_checked,
            r#"
fn main() {
    let x = 1 $0+ 2;
}
"#,
            r#"
fn main() {
    let x = 1.checked_add(2);
}
"#,
        )
    }

    #[test]
    fn replace_arith_with_saturating_add() {
        check_assist(
            replace_arith_with_saturating,
            r#"
fn main() {
    let x = 1 $0+ 2;
}
"#,
            r#"
fn main() {
    let x = 1.saturating_add(2);
}
"#,
        )
    }

    #[test]
    fn replace_arith_with_wrapping_add() {
        check_assist(
            replace_arith_with_wrapping,
            r#"
fn main() {
    let x = 1 $0+ 2;
}
"#,
            r#"
fn main() {
    let x = 1.wrapping_add(2);
}
"#,
        )
    }

    #[test]
    fn replace_arith_with_wrapping_add_add_parenthesis() {
        check_assist(
            replace_arith_with_wrapping,
            r#"
fn main() {
    let x = 1*3 $0+ 2;
}
"#,
            r#"
fn main() {
    let x = (1*3).wrapping_add(2);
}
"#,
        )
    }

    #[test]
    fn replace_arith_with_wrapping_add_assign() {
        check_assist(
            replace_arith_with_wrapping,
            r#"
fn main() {
    let mut x = 1;
    x $0+= 2;
}
"#,
            r#"
fn main() {
    let mut x = 1;
    x = x.wrapping_add(2);
}
"#,
        )
    }

    #[test]
    fn replace_arith_with_saturating_sub_assign() {
        check_assist(
            replace_arith_with_saturating,
            r#"
fn main() {
    let mut x = 1;
    x $0-= 2;
}
"#,
            r#"
fn main() {
    let mut x = 1;
    x = x.saturating_sub(2);
}
"#,
        )
    }

    #[test]
    fn replace_arith_with_checked_div() {
        check_assist(
            replace_arith_with_checked,
            r#"
fn main() {
    let x = 4 $0/ 2;
}
"#,
            r#"
fn main() {
    let x = 4.checked_div(2);
}
"#,
        )
    }

    #[test]
    fn replace_arith_not_applicable_with_non_empty_selection() {
        check_assist_not_applicable(
            replace_arith_with_checked,
            r#"
fn main() {
    let x = 1 $0+$0 2;
}
"#,
        )
    }

    #[test]
    fn replace_arith_not_applicable_to_floats() {
        check_assist_not_applicable(
            replace_arith_with_wrapping,
            r#"
fn main() {
    let x = 1.0 $0+ 2.0;
}
"#,
        )
    }

    #[test]
    fn replace_arith_not_applicable_to_remainder() {
        check_assist_not_applicable(
            replace_arith_with_wrapping,
            r#"
fn main() {
    let x = 1 $0% 2;
}
"#,
        )
    }
}