1use chalk_ir::{Mutability, cast::Cast};
5use hir_def::{
6 hir::{
7 Array, AsmOperand, BinaryOp, BindingAnnotation, Expr, ExprId, Pat, PatId, Statement,
8 UnaryOp,
9 },
10 lang_item::LangItem,
11};
12use hir_expand::name::Name;
13use intern::sym;
14
15use crate::{
16 Adjust, Adjustment, AutoBorrow, Interner, OverloadedDeref, TyBuilder, TyKind,
17 infer::{Expectation, InferenceContext, expr::ExprIsRead},
18 lower::lower_to_chalk_mutability,
19};
20
21impl InferenceContext<'_> {
22 pub(crate) fn infer_mut_body(&mut self) {
23 self.infer_mut_expr(self.body.body_expr, Mutability::Not);
24 }
25
26 fn infer_mut_expr(&mut self, tgt_expr: ExprId, mut mutability: Mutability) {
27 if let Some(adjustments) = self.result.expr_adjustments.get_mut(&tgt_expr) {
28 for adj in adjustments.iter_mut().rev() {
29 match &mut adj.kind {
30 Adjust::NeverToAny | Adjust::Deref(None) | Adjust::Pointer(_) => (),
31 Adjust::Deref(Some(d)) => *d = OverloadedDeref(Some(mutability)),
32 Adjust::Borrow(b) => match b {
33 AutoBorrow::Ref(_, m) | AutoBorrow::RawPtr(m) => mutability = *m,
34 },
35 }
36 }
37 }
38 self.infer_mut_expr_without_adjust(tgt_expr, mutability);
39 }
40
41 fn infer_mut_expr_without_adjust(&mut self, tgt_expr: ExprId, mutability: Mutability) {
42 match &self.body[tgt_expr] {
43 Expr::Missing => (),
44 Expr::InlineAsm(e) => {
45 e.operands.iter().for_each(|(_, op)| match op {
46 AsmOperand::In { expr, .. }
47 | AsmOperand::Out { expr: Some(expr), .. }
48 | AsmOperand::InOut { expr, .. } => {
49 self.infer_mut_expr_without_adjust(*expr, Mutability::Not)
50 }
51 AsmOperand::SplitInOut { in_expr, out_expr, .. } => {
52 self.infer_mut_expr_without_adjust(*in_expr, Mutability::Not);
53 if let Some(out_expr) = out_expr {
54 self.infer_mut_expr_without_adjust(*out_expr, Mutability::Not);
55 }
56 }
57 AsmOperand::Out { expr: None, .. }
58 | AsmOperand::Label(_)
59 | AsmOperand::Sym(_)
60 | AsmOperand::Const(_) => (),
61 });
62 }
63 Expr::OffsetOf(_) => (),
64 &Expr::If { condition, then_branch, else_branch } => {
65 self.infer_mut_expr(condition, Mutability::Not);
66 self.infer_mut_expr(then_branch, Mutability::Not);
67 if let Some(else_branch) = else_branch {
68 self.infer_mut_expr(else_branch, Mutability::Not);
69 }
70 }
71 Expr::Const(id) => {
72 self.infer_mut_expr(*id, Mutability::Not);
73 }
74 Expr::Let { pat, expr } => self.infer_mut_expr(*expr, self.pat_bound_mutability(*pat)),
75 Expr::Block { id: _, statements, tail, label: _ }
76 | Expr::Async { id: _, statements, tail }
77 | Expr::Unsafe { id: _, statements, tail } => {
78 for st in statements.iter() {
79 match st {
80 Statement::Let { pat, type_ref: _, initializer, else_branch } => {
81 if let Some(i) = initializer {
82 self.infer_mut_expr(*i, self.pat_bound_mutability(*pat));
83 }
84 if let Some(e) = else_branch {
85 self.infer_mut_expr(*e, Mutability::Not);
86 }
87 }
88 Statement::Expr { expr, has_semi: _ } => {
89 self.infer_mut_expr(*expr, Mutability::Not);
90 }
91 Statement::Item(_) => (),
92 }
93 }
94 if let Some(tail) = tail {
95 self.infer_mut_expr(*tail, Mutability::Not);
96 }
97 }
98 Expr::MethodCall { receiver: it, method_name: _, args, generic_args: _ }
99 | Expr::Call { callee: it, args } => {
100 self.infer_mut_not_expr_iter(args.iter().copied().chain(Some(*it)));
101 }
102 Expr::Match { expr, arms } => {
103 let m = self.pat_iter_bound_mutability(arms.iter().map(|it| it.pat));
104 self.infer_mut_expr(*expr, m);
105 for arm in arms.iter() {
106 self.infer_mut_expr(arm.expr, Mutability::Not);
107 if let Some(g) = arm.guard {
108 self.infer_mut_expr(g, Mutability::Not);
109 }
110 }
111 }
112 Expr::Yield { expr }
113 | Expr::Yeet { expr }
114 | Expr::Return { expr }
115 | Expr::Break { expr, label: _ } => {
116 if let &Some(expr) = expr {
117 self.infer_mut_expr(expr, Mutability::Not);
118 }
119 }
120 Expr::Become { expr } => {
121 self.infer_mut_expr(*expr, Mutability::Not);
122 }
123 Expr::RecordLit { path: _, fields, spread } => {
124 self.infer_mut_not_expr_iter(fields.iter().map(|it| it.expr).chain(*spread))
125 }
126 &Expr::Index { base, index } => {
127 if mutability == Mutability::Mut {
128 if let Some((f, _)) = self.result.method_resolutions.get_mut(&tgt_expr) {
129 if let Some(index_trait) = self
130 .db
131 .lang_item(self.table.trait_env.krate, LangItem::IndexMut)
132 .and_then(|l| l.as_trait())
133 {
134 if let Some(index_fn) = self
135 .db
136 .trait_items(index_trait)
137 .method_by_name(&Name::new_symbol_root(sym::index_mut))
138 {
139 *f = index_fn;
140 let mut base_ty = None;
141 let base_adjustments = self
142 .result
143 .expr_adjustments
144 .get_mut(&base)
145 .and_then(|it| it.last_mut());
146 if let Some(Adjustment {
147 kind: Adjust::Borrow(AutoBorrow::Ref(_, mutability)),
148 target,
149 }) = base_adjustments
150 {
151 if let TyKind::Ref(_, _, ty) = target.kind(Interner) {
152 base_ty = Some(ty.clone());
153 }
154 *mutability = Mutability::Mut;
155 }
156
157 if let Some(base_ty) = base_ty {
159 let index_ty =
160 if let Some(ty) = self.result.type_of_expr.get(index) {
161 ty.clone()
162 } else {
163 self.infer_expr(
164 index,
165 &Expectation::none(),
166 ExprIsRead::Yes,
167 )
168 };
169 let trait_ref = TyBuilder::trait_ref(self.db, index_trait)
170 .push(base_ty)
171 .fill(|_| index_ty.clone().cast(Interner))
172 .build();
173 self.push_obligation(trait_ref.cast(Interner));
174 }
175 }
176 }
177 }
178 }
179 self.infer_mut_expr(base, mutability);
180 self.infer_mut_expr(index, Mutability::Not);
181 }
182 Expr::UnaryOp { expr, op: UnaryOp::Deref } => {
183 let mut mutability = mutability;
184 if let Some((f, _)) = self.result.method_resolutions.get_mut(&tgt_expr) {
185 if mutability == Mutability::Mut {
186 if let Some(deref_trait) = self
187 .db
188 .lang_item(self.table.trait_env.krate, LangItem::DerefMut)
189 .and_then(|l| l.as_trait())
190 {
191 let ty = self.result.type_of_expr.get(*expr);
192 let is_mut_ptr = ty.is_some_and(|ty| {
193 let ty = self.table.resolve_ty_shallow(ty);
194 matches!(
195 ty.kind(Interner),
196 chalk_ir::TyKind::Raw(Mutability::Mut, _)
197 )
198 });
199 if is_mut_ptr {
200 mutability = Mutability::Not;
201 } else if let Some(deref_fn) = self
202 .db
203 .trait_items(deref_trait)
204 .method_by_name(&Name::new_symbol_root(sym::deref_mut))
205 {
206 *f = deref_fn;
207 }
208 }
209 }
210 }
211 self.infer_mut_expr(*expr, mutability);
212 }
213 Expr::Field { expr, name: _ } => {
214 self.infer_mut_expr(*expr, mutability);
215 }
216 Expr::UnaryOp { expr, op: _ }
217 | Expr::Range { lhs: Some(expr), rhs: None, range_type: _ }
218 | Expr::Range { rhs: Some(expr), lhs: None, range_type: _ }
219 | Expr::Await { expr }
220 | Expr::Box { expr }
221 | Expr::Loop { body: expr, label: _ }
222 | Expr::Cast { expr, type_ref: _ } => {
223 self.infer_mut_expr(*expr, Mutability::Not);
224 }
225 Expr::Ref { expr, rawness: _, mutability } => {
226 let mutability = lower_to_chalk_mutability(*mutability);
227 self.infer_mut_expr(*expr, mutability);
228 }
229 Expr::BinaryOp { lhs, rhs, op: Some(BinaryOp::Assignment { .. }) } => {
230 self.infer_mut_expr(*lhs, Mutability::Mut);
231 self.infer_mut_expr(*rhs, Mutability::Not);
232 }
233 &Expr::Assignment { target, value } => {
234 self.body.walk_pats(target, &mut |pat| match self.body[pat] {
235 Pat::Expr(expr) => self.infer_mut_expr(expr, Mutability::Mut),
236 Pat::ConstBlock(block) => self.infer_mut_expr(block, Mutability::Not),
237 _ => {}
238 });
239 self.infer_mut_expr(value, Mutability::Not);
240 }
241 Expr::Array(Array::Repeat { initializer: lhs, repeat: rhs })
242 | Expr::BinaryOp { lhs, rhs, op: _ }
243 | Expr::Range { lhs: Some(lhs), rhs: Some(rhs), range_type: _ } => {
244 self.infer_mut_expr(*lhs, Mutability::Not);
245 self.infer_mut_expr(*rhs, Mutability::Not);
246 }
247 Expr::Closure { body, .. } => {
248 self.infer_mut_expr(*body, Mutability::Not);
249 }
250 Expr::Tuple { exprs } | Expr::Array(Array::ElementList { elements: exprs }) => {
251 self.infer_mut_not_expr_iter(exprs.iter().copied());
252 }
253 Expr::Range { lhs: None, rhs: None, range_type: _ }
255 | Expr::Literal(_)
256 | Expr::Path(_)
257 | Expr::Continue { .. }
258 | Expr::Underscore => (),
259 }
260 }
261
262 fn infer_mut_not_expr_iter(&mut self, exprs: impl Iterator<Item = ExprId>) {
263 for expr in exprs {
264 self.infer_mut_expr(expr, Mutability::Not);
265 }
266 }
267
268 fn pat_iter_bound_mutability(&self, mut pat: impl Iterator<Item = PatId>) -> Mutability {
269 if pat.any(|p| self.pat_bound_mutability(p) == Mutability::Mut) {
270 Mutability::Mut
271 } else {
272 Mutability::Not
273 }
274 }
275
276 fn pat_bound_mutability(&self, pat: PatId) -> Mutability {
280 let mut r = Mutability::Not;
281 self.body.walk_bindings_in_pat(pat, |b| {
282 if self.body.bindings[b].mode == BindingAnnotation::RefMut {
283 r = Mutability::Mut;
284 }
285 });
286 r
287 }
288}