hir_ty/infer/
mutability.rs

1//! Finds if an expression is an immutable context or a mutable context, which is used in selecting
2//! between `Deref` and `DerefMut` or `Index` and `IndexMut` or similar.
3
4use chalk_ir::{Mutability, cast::Cast};
5use hir_def::{
6    hir::{
7        Array, AsmOperand, BinaryOp, BindingAnnotation, Expr, ExprId, Pat, PatId, Statement,
8        UnaryOp,
9    },
10    lang_item::LangItem,
11};
12use hir_expand::name::Name;
13use intern::sym;
14
15use crate::{
16    Adjust, Adjustment, AutoBorrow, Interner, OverloadedDeref, TyBuilder, TyKind,
17    infer::{Expectation, InferenceContext, expr::ExprIsRead},
18    lower::lower_to_chalk_mutability,
19};
20
21impl InferenceContext<'_> {
22    pub(crate) fn infer_mut_body(&mut self) {
23        self.infer_mut_expr(self.body.body_expr, Mutability::Not);
24    }
25
26    fn infer_mut_expr(&mut self, tgt_expr: ExprId, mut mutability: Mutability) {
27        if let Some(adjustments) = self.result.expr_adjustments.get_mut(&tgt_expr) {
28            for adj in adjustments.iter_mut().rev() {
29                match &mut adj.kind {
30                    Adjust::NeverToAny | Adjust::Deref(None) | Adjust::Pointer(_) => (),
31                    Adjust::Deref(Some(d)) => *d = OverloadedDeref(Some(mutability)),
32                    Adjust::Borrow(b) => match b {
33                        AutoBorrow::Ref(_, m) | AutoBorrow::RawPtr(m) => mutability = *m,
34                    },
35                }
36            }
37        }
38        self.infer_mut_expr_without_adjust(tgt_expr, mutability);
39    }
40
41    fn infer_mut_expr_without_adjust(&mut self, tgt_expr: ExprId, mutability: Mutability) {
42        match &self.body[tgt_expr] {
43            Expr::Missing => (),
44            Expr::InlineAsm(e) => {
45                e.operands.iter().for_each(|(_, op)| match op {
46                    AsmOperand::In { expr, .. }
47                    | AsmOperand::Out { expr: Some(expr), .. }
48                    | AsmOperand::InOut { expr, .. } => {
49                        self.infer_mut_expr_without_adjust(*expr, Mutability::Not)
50                    }
51                    AsmOperand::SplitInOut { in_expr, out_expr, .. } => {
52                        self.infer_mut_expr_without_adjust(*in_expr, Mutability::Not);
53                        if let Some(out_expr) = out_expr {
54                            self.infer_mut_expr_without_adjust(*out_expr, Mutability::Not);
55                        }
56                    }
57                    AsmOperand::Out { expr: None, .. }
58                    | AsmOperand::Label(_)
59                    | AsmOperand::Sym(_)
60                    | AsmOperand::Const(_) => (),
61                });
62            }
63            Expr::OffsetOf(_) => (),
64            &Expr::If { condition, then_branch, else_branch } => {
65                self.infer_mut_expr(condition, Mutability::Not);
66                self.infer_mut_expr(then_branch, Mutability::Not);
67                if let Some(else_branch) = else_branch {
68                    self.infer_mut_expr(else_branch, Mutability::Not);
69                }
70            }
71            Expr::Const(id) => {
72                self.infer_mut_expr(*id, Mutability::Not);
73            }
74            Expr::Let { pat, expr } => self.infer_mut_expr(*expr, self.pat_bound_mutability(*pat)),
75            Expr::Block { id: _, statements, tail, label: _ }
76            | Expr::Async { id: _, statements, tail }
77            | Expr::Unsafe { id: _, statements, tail } => {
78                for st in statements.iter() {
79                    match st {
80                        Statement::Let { pat, type_ref: _, initializer, else_branch } => {
81                            if let Some(i) = initializer {
82                                self.infer_mut_expr(*i, self.pat_bound_mutability(*pat));
83                            }
84                            if let Some(e) = else_branch {
85                                self.infer_mut_expr(*e, Mutability::Not);
86                            }
87                        }
88                        Statement::Expr { expr, has_semi: _ } => {
89                            self.infer_mut_expr(*expr, Mutability::Not);
90                        }
91                        Statement::Item(_) => (),
92                    }
93                }
94                if let Some(tail) = tail {
95                    self.infer_mut_expr(*tail, Mutability::Not);
96                }
97            }
98            Expr::MethodCall { receiver: it, method_name: _, args, generic_args: _ }
99            | Expr::Call { callee: it, args } => {
100                self.infer_mut_not_expr_iter(args.iter().copied().chain(Some(*it)));
101            }
102            Expr::Match { expr, arms } => {
103                let m = self.pat_iter_bound_mutability(arms.iter().map(|it| it.pat));
104                self.infer_mut_expr(*expr, m);
105                for arm in arms.iter() {
106                    self.infer_mut_expr(arm.expr, Mutability::Not);
107                    if let Some(g) = arm.guard {
108                        self.infer_mut_expr(g, Mutability::Not);
109                    }
110                }
111            }
112            Expr::Yield { expr }
113            | Expr::Yeet { expr }
114            | Expr::Return { expr }
115            | Expr::Break { expr, label: _ } => {
116                if let &Some(expr) = expr {
117                    self.infer_mut_expr(expr, Mutability::Not);
118                }
119            }
120            Expr::Become { expr } => {
121                self.infer_mut_expr(*expr, Mutability::Not);
122            }
123            Expr::RecordLit { path: _, fields, spread } => {
124                self.infer_mut_not_expr_iter(fields.iter().map(|it| it.expr).chain(*spread))
125            }
126            &Expr::Index { base, index } => {
127                if mutability == Mutability::Mut {
128                    if let Some((f, _)) = self.result.method_resolutions.get_mut(&tgt_expr) {
129                        if let Some(index_trait) = self
130                            .db
131                            .lang_item(self.table.trait_env.krate, LangItem::IndexMut)
132                            .and_then(|l| l.as_trait())
133                        {
134                            if let Some(index_fn) = self
135                                .db
136                                .trait_items(index_trait)
137                                .method_by_name(&Name::new_symbol_root(sym::index_mut))
138                            {
139                                *f = index_fn;
140                                let mut base_ty = None;
141                                let base_adjustments = self
142                                    .result
143                                    .expr_adjustments
144                                    .get_mut(&base)
145                                    .and_then(|it| it.last_mut());
146                                if let Some(Adjustment {
147                                    kind: Adjust::Borrow(AutoBorrow::Ref(_, mutability)),
148                                    target,
149                                }) = base_adjustments
150                                {
151                                    if let TyKind::Ref(_, _, ty) = target.kind(Interner) {
152                                        base_ty = Some(ty.clone());
153                                    }
154                                    *mutability = Mutability::Mut;
155                                }
156
157                                // Apply `IndexMut` obligation for non-assignee expr
158                                if let Some(base_ty) = base_ty {
159                                    let index_ty =
160                                        if let Some(ty) = self.result.type_of_expr.get(index) {
161                                            ty.clone()
162                                        } else {
163                                            self.infer_expr(
164                                                index,
165                                                &Expectation::none(),
166                                                ExprIsRead::Yes,
167                                            )
168                                        };
169                                    let trait_ref = TyBuilder::trait_ref(self.db, index_trait)
170                                        .push(base_ty)
171                                        .fill(|_| index_ty.clone().cast(Interner))
172                                        .build();
173                                    self.push_obligation(trait_ref.cast(Interner));
174                                }
175                            }
176                        }
177                    }
178                }
179                self.infer_mut_expr(base, mutability);
180                self.infer_mut_expr(index, Mutability::Not);
181            }
182            Expr::UnaryOp { expr, op: UnaryOp::Deref } => {
183                let mut mutability = mutability;
184                if let Some((f, _)) = self.result.method_resolutions.get_mut(&tgt_expr) {
185                    if mutability == Mutability::Mut {
186                        if let Some(deref_trait) = self
187                            .db
188                            .lang_item(self.table.trait_env.krate, LangItem::DerefMut)
189                            .and_then(|l| l.as_trait())
190                        {
191                            let ty = self.result.type_of_expr.get(*expr);
192                            let is_mut_ptr = ty.is_some_and(|ty| {
193                                let ty = self.table.resolve_ty_shallow(ty);
194                                matches!(
195                                    ty.kind(Interner),
196                                    chalk_ir::TyKind::Raw(Mutability::Mut, _)
197                                )
198                            });
199                            if is_mut_ptr {
200                                mutability = Mutability::Not;
201                            } else if let Some(deref_fn) = self
202                                .db
203                                .trait_items(deref_trait)
204                                .method_by_name(&Name::new_symbol_root(sym::deref_mut))
205                            {
206                                *f = deref_fn;
207                            }
208                        }
209                    }
210                }
211                self.infer_mut_expr(*expr, mutability);
212            }
213            Expr::Field { expr, name: _ } => {
214                self.infer_mut_expr(*expr, mutability);
215            }
216            Expr::UnaryOp { expr, op: _ }
217            | Expr::Range { lhs: Some(expr), rhs: None, range_type: _ }
218            | Expr::Range { rhs: Some(expr), lhs: None, range_type: _ }
219            | Expr::Await { expr }
220            | Expr::Box { expr }
221            | Expr::Loop { body: expr, label: _ }
222            | Expr::Cast { expr, type_ref: _ } => {
223                self.infer_mut_expr(*expr, Mutability::Not);
224            }
225            Expr::Ref { expr, rawness: _, mutability } => {
226                let mutability = lower_to_chalk_mutability(*mutability);
227                self.infer_mut_expr(*expr, mutability);
228            }
229            Expr::BinaryOp { lhs, rhs, op: Some(BinaryOp::Assignment { .. }) } => {
230                self.infer_mut_expr(*lhs, Mutability::Mut);
231                self.infer_mut_expr(*rhs, Mutability::Not);
232            }
233            &Expr::Assignment { target, value } => {
234                self.body.walk_pats(target, &mut |pat| match self.body[pat] {
235                    Pat::Expr(expr) => self.infer_mut_expr(expr, Mutability::Mut),
236                    Pat::ConstBlock(block) => self.infer_mut_expr(block, Mutability::Not),
237                    _ => {}
238                });
239                self.infer_mut_expr(value, Mutability::Not);
240            }
241            Expr::Array(Array::Repeat { initializer: lhs, repeat: rhs })
242            | Expr::BinaryOp { lhs, rhs, op: _ }
243            | Expr::Range { lhs: Some(lhs), rhs: Some(rhs), range_type: _ } => {
244                self.infer_mut_expr(*lhs, Mutability::Not);
245                self.infer_mut_expr(*rhs, Mutability::Not);
246            }
247            Expr::Closure { body, .. } => {
248                self.infer_mut_expr(*body, Mutability::Not);
249            }
250            Expr::Tuple { exprs } | Expr::Array(Array::ElementList { elements: exprs }) => {
251                self.infer_mut_not_expr_iter(exprs.iter().copied());
252            }
253            // These don't need any action, as they don't have sub expressions
254            Expr::Range { lhs: None, rhs: None, range_type: _ }
255            | Expr::Literal(_)
256            | Expr::Path(_)
257            | Expr::Continue { .. }
258            | Expr::Underscore => (),
259        }
260    }
261
262    fn infer_mut_not_expr_iter(&mut self, exprs: impl Iterator<Item = ExprId>) {
263        for expr in exprs {
264            self.infer_mut_expr(expr, Mutability::Not);
265        }
266    }
267
268    fn pat_iter_bound_mutability(&self, mut pat: impl Iterator<Item = PatId>) -> Mutability {
269        if pat.any(|p| self.pat_bound_mutability(p) == Mutability::Mut) {
270            Mutability::Mut
271        } else {
272            Mutability::Not
273        }
274    }
275
276    /// Checks if the pat contains a `ref mut` binding. Such paths makes the context of bounded expressions
277    /// mutable. For example in `let (ref mut x0, ref x1) = *it;` we need to use `DerefMut` for `*it` but in
278    /// `let (ref x0, ref x1) = *it;` we should use `Deref`.
279    fn pat_bound_mutability(&self, pat: PatId) -> Mutability {
280        let mut r = Mutability::Not;
281        self.body.walk_bindings_in_pat(pat, |b| {
282            if self.body.bindings[b].mode == BindingAnnotation::RefMut {
283                r = Mutability::Mut;
284            }
285        });
286        r
287    }
288}