hir_ty/infer/closure/
analysis.rs

1//! Post-inference closure analysis: captures and closure kind.
2
3use std::{cmp, convert::Infallible, mem};
4
5use either::Either;
6use hir_def::{
7    DefWithBodyId, FieldId, HasModule, TupleFieldId, TupleId, VariantId,
8    expr_store::path::Path,
9    hir::{
10        Array, AsmOperand, BinaryOp, BindingId, CaptureBy, Expr, ExprId, ExprOrPatId, Pat, PatId,
11        Statement, UnaryOp,
12    },
13    item_tree::FieldsShape,
14    resolver::ValueNs,
15};
16use rustc_ast_ir::Mutability;
17use rustc_hash::{FxHashMap, FxHashSet};
18use rustc_type_ir::inherent::{IntoKind, SliceLike, Ty as _};
19use smallvec::{SmallVec, smallvec};
20use stdx::{format_to, never};
21use syntax::utils::is_raw_identifier;
22
23use crate::{
24    Adjust, Adjustment, BindingMode,
25    db::{HirDatabase, InternedClosure, InternedClosureId},
26    infer::InferenceContext,
27    mir::{BorrowKind, MirSpan, MutBorrowKind, ProjectionElem},
28    next_solver::{DbInterner, EarlyBinder, GenericArgs, Ty, TyKind},
29    traits::FnTrait,
30};
31
32// The below functions handle capture and closure kind (Fn, FnMut, ..)
33
/// A place that a closure may capture: the binding `local` followed by `projections`,
/// applied in order (`HirPlace::ty` folds them over the binding's type).
#[derive(Debug, Clone, PartialEq, Eq, Hash, salsa::Update)]
pub(crate) struct HirPlace<'db> {
    /// The binding the place is rooted at.
    pub(crate) local: BindingId,
    /// Projections applied to `local`, first to last. Within this file only `Deref`,
    /// `Field` and `ClosureField` are expected; the display helpers report any other
    /// kind via `never!`.
    pub(crate) projections: Vec<ProjectionElem<'db, Infallible>>,
}
39
40impl<'db> HirPlace<'db> {
41    fn ty(&self, ctx: &mut InferenceContext<'_, 'db>) -> Ty<'db> {
42        let mut ty = ctx.table.resolve_completely(ctx.result[self.local]);
43        for p in &self.projections {
44            ty = p.projected_ty(
45                &ctx.table.infer_ctxt,
46                ctx.table.param_env,
47                ty,
48                |_, _, _| {
49                    unreachable!("Closure field only happens in MIR");
50                },
51                ctx.owner.module(ctx.db).krate(ctx.db),
52            );
53        }
54        ty
55    }
56
57    fn capture_kind_of_truncated_place(
58        &self,
59        mut current_capture: CaptureKind,
60        len: usize,
61    ) -> CaptureKind {
62        if let CaptureKind::ByRef(BorrowKind::Mut {
63            kind: MutBorrowKind::Default | MutBorrowKind::TwoPhasedBorrow,
64        }) = current_capture
65            && self.projections[len..].contains(&ProjectionElem::Deref)
66        {
67            current_capture =
68                CaptureKind::ByRef(BorrowKind::Mut { kind: MutBorrowKind::ClosureCapture });
69        }
70        current_capture
71    }
72}
73
/// How a closure captures a place: by reference (with the given `BorrowKind`) or by value.
///
/// The derived `Ord` is significant. Variants compare in declaration order, so every
/// `ByRef(_)` is less than `ByValue`, and `ByRef` values compare by `BorrowKind`'s own
/// ordering. `walk_pat` merges the kinds required by match arms with `cmp::max`, so
/// reordering these variants changes which capture kind wins.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum CaptureKind {
    /// Captured through a reference of the given borrow kind.
    ByRef(BorrowKind),
    /// Captured by value (moved into the closure).
    ByValue,
}
79
/// A single closure capture whose type has been computed.
#[derive(Debug, Clone, PartialEq, Eq, salsa::Update)]
pub struct CapturedItem<'db> {
    /// The captured place.
    pub(crate) place: HirPlace<'db>,
    /// Whether the place is captured by value or by (which kind of) reference.
    pub(crate) kind: CaptureKind,
    /// The inner vec is the stacks; the outer vec is for each capture reference.
    ///
    /// Even though we always report only the last span (i.e. the most inclusive span),
    /// we need to keep them all, since when a closure occurs inside a closure, we
    /// copy all captures of the inner closure to the outer closure, and then we may
    /// truncate them, and we want the correct span to be reported.
    span_stacks: SmallVec<[SmallVec<[MirSpan; 3]>; 3]>,
    /// Type of the captured value: the place's type for by-value captures, or a reference
    /// to it for by-ref captures. It is instantiated with the closure's parent generic
    /// arguments by `CapturedItem::ty`.
    #[update(unsafe(with(crate::utils::unsafe_update_eq)))]
    pub(crate) ty: EarlyBinder<'db, Ty<'db>>,
}
94
95impl<'db> CapturedItem<'db> {
96    pub fn local(&self) -> BindingId {
97        self.place.local
98    }
99
100    /// Returns whether this place has any field (aka. non-deref) projections.
101    pub fn has_field_projections(&self) -> bool {
102        self.place.projections.iter().any(|it| !matches!(it, ProjectionElem::Deref))
103    }
104
105    pub fn ty(&self, db: &'db dyn HirDatabase, subst: GenericArgs<'db>) -> Ty<'db> {
106        let interner = DbInterner::new_no_crate(db);
107        self.ty.instantiate(interner, subst.split_closure_args_untupled().parent_args)
108    }
109
110    pub fn kind(&self) -> CaptureKind {
111        self.kind
112    }
113
114    pub fn spans(&self) -> SmallVec<[MirSpan; 3]> {
115        self.span_stacks.iter().map(|stack| *stack.last().expect("empty span stack")).collect()
116    }
117
118    /// Converts the place to a name that can be inserted into source code.
119    pub fn place_to_name(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String {
120        let body = db.body(owner);
121        let mut result = body[self.place.local].name.as_str().to_owned();
122        for proj in &self.place.projections {
123            match proj {
124                ProjectionElem::Deref => {}
125                ProjectionElem::Field(Either::Left(f)) => {
126                    let variant_data = f.parent.fields(db);
127                    match variant_data.shape {
128                        FieldsShape::Record => {
129                            result.push('_');
130                            result.push_str(variant_data.fields()[f.local_id].name.as_str())
131                        }
132                        FieldsShape::Tuple => {
133                            let index =
134                                variant_data.fields().iter().position(|it| it.0 == f.local_id);
135                            if let Some(index) = index {
136                                format_to!(result, "_{index}");
137                            }
138                        }
139                        FieldsShape::Unit => {}
140                    }
141                }
142                ProjectionElem::Field(Either::Right(f)) => format_to!(result, "_{}", f.index),
143                &ProjectionElem::ClosureField(field) => format_to!(result, "_{field}"),
144                ProjectionElem::Index(_)
145                | ProjectionElem::ConstantIndex { .. }
146                | ProjectionElem::Subslice { .. }
147                | ProjectionElem::OpaqueCast(_) => {
148                    never!("Not happen in closure capture");
149                    continue;
150                }
151            }
152        }
153        if is_raw_identifier(&result, owner.module(db).krate(db).data(db).edition) {
154            result.insert_str(0, "r#");
155        }
156        result
157    }
158
159    pub fn display_place_source_code(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String {
160        let body = db.body(owner);
161        let krate = owner.krate(db);
162        let edition = krate.data(db).edition;
163        let mut result = body[self.place.local].name.display(db, edition).to_string();
164        for proj in &self.place.projections {
165            match proj {
166                // In source code autoderef kicks in.
167                ProjectionElem::Deref => {}
168                ProjectionElem::Field(Either::Left(f)) => {
169                    let variant_data = f.parent.fields(db);
170                    match variant_data.shape {
171                        FieldsShape::Record => format_to!(
172                            result,
173                            ".{}",
174                            variant_data.fields()[f.local_id].name.display(db, edition)
175                        ),
176                        FieldsShape::Tuple => format_to!(
177                            result,
178                            ".{}",
179                            variant_data
180                                .fields()
181                                .iter()
182                                .position(|it| it.0 == f.local_id)
183                                .unwrap_or_default()
184                        ),
185                        FieldsShape::Unit => {}
186                    }
187                }
188                ProjectionElem::Field(Either::Right(f)) => {
189                    let field = f.index;
190                    format_to!(result, ".{field}");
191                }
192                &ProjectionElem::ClosureField(field) => {
193                    format_to!(result, ".{field}");
194                }
195                ProjectionElem::Index(_)
196                | ProjectionElem::ConstantIndex { .. }
197                | ProjectionElem::Subslice { .. }
198                | ProjectionElem::OpaqueCast(_) => {
199                    never!("Not happen in closure capture");
200                    continue;
201                }
202            }
203        }
204        let final_derefs_count = self
205            .place
206            .projections
207            .iter()
208            .rev()
209            .take_while(|proj| matches!(proj, ProjectionElem::Deref))
210            .count();
211        result.insert_str(0, &"*".repeat(final_derefs_count));
212        result
213    }
214
215    pub fn display_place(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String {
216        let body = db.body(owner);
217        let krate = owner.krate(db);
218        let edition = krate.data(db).edition;
219        let mut result = body[self.place.local].name.display(db, edition).to_string();
220        let mut field_need_paren = false;
221        for proj in &self.place.projections {
222            match proj {
223                ProjectionElem::Deref => {
224                    result = format!("*{result}");
225                    field_need_paren = true;
226                }
227                ProjectionElem::Field(Either::Left(f)) => {
228                    if field_need_paren {
229                        result = format!("({result})");
230                    }
231                    let variant_data = f.parent.fields(db);
232                    let field = match variant_data.shape {
233                        FieldsShape::Record => {
234                            variant_data.fields()[f.local_id].name.as_str().to_owned()
235                        }
236                        FieldsShape::Tuple => variant_data
237                            .fields()
238                            .iter()
239                            .position(|it| it.0 == f.local_id)
240                            .unwrap_or_default()
241                            .to_string(),
242                        FieldsShape::Unit => "[missing field]".to_owned(),
243                    };
244                    result = format!("{result}.{field}");
245                    field_need_paren = false;
246                }
247                ProjectionElem::Field(Either::Right(f)) => {
248                    let field = f.index;
249                    if field_need_paren {
250                        result = format!("({result})");
251                    }
252                    result = format!("{result}.{field}");
253                    field_need_paren = false;
254                }
255                &ProjectionElem::ClosureField(field) => {
256                    if field_need_paren {
257                        result = format!("({result})");
258                    }
259                    result = format!("{result}.{field}");
260                    field_need_paren = false;
261                }
262                ProjectionElem::Index(_)
263                | ProjectionElem::ConstantIndex { .. }
264                | ProjectionElem::Subslice { .. }
265                | ProjectionElem::OpaqueCast(_) => {
266                    never!("Not happen in closure capture");
267                    continue;
268                }
269            }
270        }
271        result
272    }
273}
274
/// A capture collected while walking a closure body, before its type has been computed
/// (`with_ty` turns it into a `CapturedItem`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct CapturedItemWithoutTy<'db> {
    /// The captured place.
    pub(crate) place: HirPlace<'db>,
    /// Whether the place is captured by value or by (which kind of) reference.
    pub(crate) kind: CaptureKind,
    /// The inner vec is the stacks; the outer vec is for each capture reference.
    pub(crate) span_stacks: SmallVec<[SmallVec<[MirSpan; 3]>; 3]>,
}
282
283impl<'db> CapturedItemWithoutTy<'db> {
284    fn with_ty(self, ctx: &mut InferenceContext<'_, 'db>) -> CapturedItem<'db> {
285        let ty = self.place.ty(ctx);
286        let ty = match &self.kind {
287            CaptureKind::ByValue => ty,
288            CaptureKind::ByRef(bk) => {
289                let m = match bk {
290                    BorrowKind::Mut { .. } => Mutability::Mut,
291                    _ => Mutability::Not,
292                };
293                Ty::new_ref(ctx.interner(), ctx.types.re_error, ty, m)
294            }
295        };
296        CapturedItem {
297            place: self.place,
298            kind: self.kind,
299            span_stacks: self.span_stacks,
300            ty: EarlyBinder::bind(ty),
301        }
302    }
303}
304
305impl<'db> InferenceContext<'_, 'db> {
306    fn place_of_expr(&mut self, tgt_expr: ExprId) -> Option<HirPlace<'db>> {
307        let r = self.place_of_expr_without_adjust(tgt_expr)?;
308        let adjustments =
309            self.result.expr_adjustments.get(&tgt_expr).map(|it| &**it).unwrap_or_default();
310        apply_adjusts_to_place(&mut self.current_capture_span_stack, r, adjustments)
311    }
312
313    /// Pushes the span into `current_capture_span_stack`, *without clearing it first*.
314    fn path_place(&mut self, path: &Path, id: ExprOrPatId) -> Option<HirPlace<'db>> {
315        if path.type_anchor().is_some() {
316            return None;
317        }
318        let hygiene = self.body.expr_or_pat_path_hygiene(id);
319        self.resolver.resolve_path_in_value_ns_fully(self.db, path, hygiene).and_then(|result| {
320            match result {
321                ValueNs::LocalBinding(binding) => {
322                    let mir_span = match id {
323                        ExprOrPatId::ExprId(id) => MirSpan::ExprId(id),
324                        ExprOrPatId::PatId(id) => MirSpan::PatId(id),
325                    };
326                    self.current_capture_span_stack.push(mir_span);
327                    Some(HirPlace { local: binding, projections: Vec::new() })
328                }
329                _ => None,
330            }
331        })
332    }
333
334    /// Changes `current_capture_span_stack` to contain the stack of spans for this expr.
335    fn place_of_expr_without_adjust(&mut self, tgt_expr: ExprId) -> Option<HirPlace<'db>> {
336        self.current_capture_span_stack.clear();
337        match &self.body[tgt_expr] {
338            Expr::Path(p) => {
339                let resolver_guard =
340                    self.resolver.update_to_inner_scope(self.db, self.owner, tgt_expr);
341                let result = self.path_place(p, tgt_expr.into());
342                self.resolver.reset_to_guard(resolver_guard);
343                return result;
344            }
345            Expr::Field { expr, name: _ } => {
346                let mut place = self.place_of_expr(*expr)?;
347                let field = self.result.field_resolution(tgt_expr)?;
348                self.current_capture_span_stack.push(MirSpan::ExprId(tgt_expr));
349                place.projections.push(ProjectionElem::Field(field));
350                return Some(place);
351            }
352            Expr::UnaryOp { expr, op: UnaryOp::Deref } => {
353                let is_builtin_deref = match self.expr_ty(*expr).kind() {
354                    TyKind::Ref(..) | TyKind::RawPtr(..) => true,
355                    TyKind::Adt(adt_def, _) if adt_def.is_box() => true,
356                    _ => false,
357                };
358                if is_builtin_deref {
359                    let mut place = self.place_of_expr(*expr)?;
360                    self.current_capture_span_stack.push(MirSpan::ExprId(tgt_expr));
361                    place.projections.push(ProjectionElem::Deref);
362                    return Some(place);
363                }
364            }
365            _ => (),
366        }
367        None
368    }
369
370    fn push_capture(&mut self, place: HirPlace<'db>, kind: CaptureKind) {
371        self.current_captures.push(CapturedItemWithoutTy {
372            place,
373            kind,
374            span_stacks: smallvec![self.current_capture_span_stack.iter().copied().collect()],
375        });
376    }
377
378    fn truncate_capture_spans(
379        &self,
380        capture: &mut CapturedItemWithoutTy<'db>,
381        mut truncate_to: usize,
382    ) {
383        // The first span is the identifier, and it must always remain.
384        truncate_to += 1;
385        for span_stack in &mut capture.span_stacks {
386            let mut remained = truncate_to;
387            let mut actual_truncate_to = 0;
388            for &span in &*span_stack {
389                actual_truncate_to += 1;
390                if !span.is_ref_span(self.body) {
391                    remained -= 1;
392                    if remained == 0 {
393                        break;
394                    }
395                }
396            }
397            if actual_truncate_to < span_stack.len()
398                && span_stack[actual_truncate_to].is_ref_span(self.body)
399            {
400                // Include the ref operator if there is one, we will fix it later (in `strip_captures_ref_span()`) if it's incorrect.
401                actual_truncate_to += 1;
402            }
403            span_stack.truncate(actual_truncate_to);
404        }
405    }
406
407    fn ref_expr(&mut self, expr: ExprId, place: Option<HirPlace<'db>>) {
408        if let Some(place) = place {
409            self.add_capture(place, CaptureKind::ByRef(BorrowKind::Shared));
410        }
411        self.walk_expr(expr);
412    }
413
414    fn add_capture(&mut self, place: HirPlace<'db>, kind: CaptureKind) {
415        if self.is_upvar(&place) {
416            self.push_capture(place, kind);
417        }
418    }
419
420    fn mutate_path_pat(&mut self, path: &Path, id: PatId) {
421        if let Some(place) = self.path_place(path, id.into()) {
422            self.add_capture(
423                place,
424                CaptureKind::ByRef(BorrowKind::Mut { kind: MutBorrowKind::Default }),
425            );
426            self.current_capture_span_stack.pop(); // Remove the pattern span.
427        }
428    }
429
430    fn mutate_expr(&mut self, expr: ExprId, place: Option<HirPlace<'db>>) {
431        if let Some(place) = place {
432            self.add_capture(
433                place,
434                CaptureKind::ByRef(BorrowKind::Mut { kind: MutBorrowKind::Default }),
435            );
436        }
437        self.walk_expr(expr);
438    }
439
440    fn consume_expr(&mut self, expr: ExprId) {
441        if let Some(place) = self.place_of_expr(expr) {
442            self.consume_place(place);
443        }
444        self.walk_expr(expr);
445    }
446
447    fn consume_place(&mut self, place: HirPlace<'db>) {
448        if self.is_upvar(&place) {
449            let ty = place.ty(self);
450            let kind = if self.is_ty_copy(ty) {
451                CaptureKind::ByRef(BorrowKind::Shared)
452            } else {
453                CaptureKind::ByValue
454            };
455            self.push_capture(place, kind);
456        }
457    }
458
459    fn walk_expr_with_adjust(&mut self, tgt_expr: ExprId, adjustment: &[Adjustment<'db>]) {
460        if let Some((last, rest)) = adjustment.split_last() {
461            match &last.kind {
462                Adjust::NeverToAny | Adjust::Deref(None) | Adjust::Pointer(_) => {
463                    self.walk_expr_with_adjust(tgt_expr, rest)
464                }
465                Adjust::Deref(Some(m)) => match m.0 {
466                    Some(m) => {
467                        self.ref_capture_with_adjusts(m, tgt_expr, rest);
468                    }
469                    None => unreachable!(),
470                },
471                Adjust::Borrow(b) => {
472                    self.ref_capture_with_adjusts(b.mutability(), tgt_expr, rest);
473                }
474            }
475        } else {
476            self.walk_expr_without_adjust(tgt_expr);
477        }
478    }
479
480    fn ref_capture_with_adjusts(
481        &mut self,
482        m: Mutability,
483        tgt_expr: ExprId,
484        rest: &[Adjustment<'db>],
485    ) {
486        let capture_kind = match m {
487            Mutability::Mut => CaptureKind::ByRef(BorrowKind::Mut { kind: MutBorrowKind::Default }),
488            Mutability::Not => CaptureKind::ByRef(BorrowKind::Shared),
489        };
490        if let Some(place) = self.place_of_expr_without_adjust(tgt_expr)
491            && let Some(place) =
492                apply_adjusts_to_place(&mut self.current_capture_span_stack, place, rest)
493        {
494            self.add_capture(place, capture_kind);
495        }
496        self.walk_expr_with_adjust(tgt_expr, rest);
497    }
498
499    fn walk_expr(&mut self, tgt_expr: ExprId) {
500        if let Some(it) = self.result.expr_adjustments.get_mut(&tgt_expr) {
501            // FIXME: this take is completely unneeded, and just is here to make borrow checker
502            // happy. Remove it if you can.
503            let x_taken = mem::take(it);
504            self.walk_expr_with_adjust(tgt_expr, &x_taken);
505            *self.result.expr_adjustments.get_mut(&tgt_expr).unwrap() = x_taken;
506        } else {
507            self.walk_expr_without_adjust(tgt_expr);
508        }
509    }
510
    /// Walks `tgt_expr` without the adjustments recorded on `tgt_expr` itself, recording the
    /// captures that each sub-expression induces. Sub-expressions reached through
    /// `walk_expr`/`consume_expr` do get their own adjustments applied.
    ///
    /// Operands used by value go through `consume_expr` and borrowed operands through
    /// `ref_expr`/`mutate_expr`. Pattern-bound values go through `consume_with_pat`, and the
    /// bases of field/index projections through `select_from_expr`. The last two are not
    /// shown here.
    ///
    /// NOTE(review): captures are pushed onto `current_captures` in visit order, so the order
    /// of the calls below is observable. Later capture processing presumably depends on it;
    /// confirm before reordering arms or statements.
    fn walk_expr_without_adjust(&mut self, tgt_expr: ExprId) {
        match &self.body[tgt_expr] {
            // Nothing to capture.
            Expr::OffsetOf(_) => (),
            // Asm operands are walked without their adjustments. Outputs are only walked,
            // not treated as mutations.
            Expr::InlineAsm(e) => e.operands.iter().for_each(|(_, op)| match op {
                AsmOperand::In { expr, .. }
                | AsmOperand::Out { expr: Some(expr), .. }
                | AsmOperand::InOut { expr, .. } => self.walk_expr_without_adjust(*expr),
                AsmOperand::SplitInOut { in_expr, out_expr, .. } => {
                    self.walk_expr_without_adjust(*in_expr);
                    if let Some(out_expr) = out_expr {
                        self.walk_expr_without_adjust(*out_expr);
                    }
                }
                AsmOperand::Out { expr: None, .. }
                | AsmOperand::Const(_)
                | AsmOperand::Label(_)
                | AsmOperand::Sym(_) => (),
            }),
            // The condition and both branches are consumed.
            Expr::If { condition, then_branch, else_branch } => {
                self.consume_expr(*condition);
                self.consume_expr(*then_branch);
                if let &Some(expr) = else_branch {
                    self.consume_expr(expr);
                }
            }
            // Block-like expressions: statements in order, then the tail expression.
            Expr::Async { statements, tail, .. }
            | Expr::Unsafe { statements, tail, .. }
            | Expr::Block { statements, tail, .. } => {
                for s in statements.iter() {
                    match s {
                        Statement::Let { pat, type_ref: _, initializer, else_branch } => {
                            if let Some(else_branch) = else_branch {
                                self.consume_expr(*else_branch);
                            }
                            if let Some(initializer) = initializer {
                                // A `let`-`else` consumes its initializer. A plain `let` only
                                // walks it and leaves the rest to the pattern below.
                                if else_branch.is_some() {
                                    self.consume_expr(*initializer);
                                } else {
                                    self.walk_expr(*initializer);
                                }
                                if let Some(place) = self.place_of_expr(*initializer) {
                                    self.consume_with_pat(place, *pat);
                                }
                            }
                        }
                        Statement::Expr { expr, has_semi: _ } => {
                            self.consume_expr(*expr);
                        }
                        Statement::Item(_) => (),
                    }
                }
                if let Some(tail) = tail {
                    self.consume_expr(*tail);
                }
            }
            // The callee/receiver and all arguments are consumed. Receiver autoref, if any, is
            // presumably recorded as an adjustment and handled in `walk_expr_with_adjust`.
            Expr::Call { callee, args } => {
                self.consume_expr(*callee);
                self.consume_exprs(args.iter().copied());
            }
            Expr::MethodCall { receiver, args, .. } => {
                self.consume_expr(*receiver);
                self.consume_exprs(args.iter().copied());
            }
            // Arm bodies and guards are walked before the scrutinee. If the scrutinee is an
            // upvar place, it is captured with the strongest kind that any arm pattern needs
            // (`walk_pat` merges with `cmp::max`). If no arm needs anything, no capture of the
            // scrutinee is pushed here.
            Expr::Match { expr, arms } => {
                for arm in arms.iter() {
                    self.consume_expr(arm.expr);
                    if let Some(guard) = arm.guard {
                        self.consume_expr(guard);
                    }
                }
                self.walk_expr(*expr);
                if let Some(discr_place) = self.place_of_expr(*expr)
                    && self.is_upvar(&discr_place)
                {
                    let mut capture_mode = None;
                    for arm in arms.iter() {
                        self.walk_pat(&mut capture_mode, arm.pat);
                    }
                    if let Some(c) = capture_mode {
                        self.push_capture(discr_place, c);
                    }
                }
            }
            // Control flow that carries a value consumes that value.
            Expr::Break { expr, label: _ }
            | Expr::Return { expr }
            | Expr::Yield { expr }
            | Expr::Yeet { expr } => {
                if let &Some(expr) = expr {
                    self.consume_expr(expr);
                }
            }
            &Expr::Become { expr } => {
                self.consume_expr(expr);
            }
            // The spread (`..base`) is consumed before the field initializers.
            Expr::RecordLit { fields, spread, .. } => {
                if let &Some(expr) = spread {
                    self.consume_expr(expr);
                }
                self.consume_exprs(fields.iter().map(|it| it.expr));
            }
            // A field access selects from its base instead of consuming it.
            Expr::Field { expr, name: _ } => self.select_from_expr(*expr),
            Expr::UnaryOp { expr, op: UnaryOp::Deref } => {
                if self.result.method_resolution(tgt_expr).is_some() {
                    // Overloaded deref.
                    // The adjusted operand is a reference. Borrow the operand with that
                    // reference's mutability.
                    match self.expr_ty_after_adjustments(*expr).kind() {
                        TyKind::Ref(_, _, mutability) => {
                            let place = self.place_of_expr(*expr);
                            match mutability {
                                Mutability::Mut => self.mutate_expr(*expr, place),
                                Mutability::Not => self.ref_expr(*expr, place),
                            }
                        }
                        // FIXME: Is this correct wrt. raw pointer derefs?
                        TyKind::RawPtr(..) => self.select_from_expr(*expr),
                        _ => never!("deref adjustments should include taking a mutable reference"),
                    }
                } else {
                    // Builtin deref: select from the operand.
                    self.select_from_expr(*expr);
                }
            }
            // `let` used as an expression (e.g. in a condition): walk the scrutinee, then let
            // the pattern decide how its place is used.
            Expr::Let { pat, expr } => {
                self.walk_expr(*expr);
                if let Some(place) = self.place_of_expr(*expr) {
                    self.consume_with_pat(place, *pat);
                }
            }
            // The remaining single-operand expressions (including non-deref unary ops)
            // consume their operand.
            Expr::UnaryOp { expr, op: _ }
            | Expr::Array(Array::Repeat { initializer: expr, repeat: _ })
            | Expr::Await { expr }
            | Expr::Loop { body: expr, label: _ }
            | Expr::Box { expr }
            | Expr::Cast { expr, type_ref: _ } => {
                self.consume_expr(*expr);
            }
            Expr::Ref { expr, rawness: _, mutability } => {
                // We need to do this before we push the span so the order will be correct.
                let place = self.place_of_expr(*expr);
                self.current_capture_span_stack.push(MirSpan::ExprId(tgt_expr));
                match mutability {
                    hir_def::type_ref::Mutability::Shared => self.ref_expr(*expr, place),
                    hir_def::type_ref::Mutability::Mut => self.mutate_expr(*expr, place),
                }
            }
            Expr::BinaryOp { lhs, rhs, op } => {
                // A missing operator (erroneous code) walks neither operand.
                let Some(op) = op else {
                    return;
                };
                // Assignments (`BinaryOp::Assignment { .. }`, presumably including compound
                // ones such as `+=`) mutably borrow the lhs place and consume the rhs.
                if matches!(op, BinaryOp::Assignment { .. }) {
                    let place = self.place_of_expr(*lhs);
                    self.mutate_expr(*lhs, place);
                    self.consume_expr(*rhs);
                    return;
                }
                self.consume_expr(*lhs);
                self.consume_expr(*rhs);
            }
            Expr::Range { lhs, rhs, range_type: _ } => {
                if let &Some(expr) = lhs {
                    self.consume_expr(expr);
                }
                if let &Some(expr) = rhs {
                    self.consume_expr(expr);
                }
            }
            // Indexing selects from the base and consumes the index.
            Expr::Index { base, index } => {
                self.select_from_expr(*base);
                self.consume_expr(*index);
            }
            // Nested closure: its captures are already known, because closures are processed
            // in sorted order. Inherit those whose places are also upvars of the current
            // closure.
            Expr::Closure { .. } => {
                let ty = self.expr_ty(tgt_expr);
                let TyKind::Closure(id, _) = ty.kind() else {
                    never!("closure type is always closure");
                    return;
                };
                let (captures, _) =
                    self.result.closure_info.get(&id.0).expect(
                        "We sort closures, so we should always have data for inner closures",
                    );
                let mut cc = mem::take(&mut self.current_captures);
                cc.extend(captures.iter().filter(|it| self.is_upvar(&it.place)).map(|it| {
                    CapturedItemWithoutTy {
                        place: it.place.clone(),
                        kind: it.kind,
                        span_stacks: it.span_stacks.clone(),
                    }
                }));
                self.current_captures = cc;
            }
            Expr::Array(Array::ElementList { elements: exprs }) | Expr::Tuple { exprs } => {
                self.consume_exprs(exprs.iter().copied())
            }
            // Destructuring assignment. The resolver is moved to `tgt_expr`'s scope for the
            // duration, presumably so that paths in `target` resolve correctly.
            &Expr::Assignment { target, value } => {
                self.walk_expr(value);
                let resolver_guard =
                    self.resolver.update_to_inner_scope(self.db, self.owner, tgt_expr);
                match self.place_of_expr(value) {
                    // If the rhs is a place, bind it through the target pattern with
                    // `inside_assignment` set (presumably read by `consume_with_pat`).
                    Some(rhs_place) => {
                        self.inside_assignment = true;
                        self.consume_with_pat(rhs_place, target);
                        self.inside_assignment = false;
                    }
                    // Otherwise, mutably borrow every path or expression sub-pattern.
                    None => self.body.walk_pats(target, &mut |pat| match &self.body[pat] {
                        Pat::Path(path) => self.mutate_path_pat(path, pat),
                        &Pat::Expr(expr) => {
                            let place = self.place_of_expr(expr);
                            self.mutate_expr(expr, place);
                        }
                        _ => {}
                    }),
                }
                self.resolver.reset_to_guard(resolver_guard);
            }

            // Leaves: nothing to walk.
            Expr::Missing
            | Expr::Continue { .. }
            | Expr::Path(_)
            | Expr::Literal(_)
            | Expr::Const(_)
            | Expr::Underscore => (),
        }
    }
732
733    fn walk_pat(&mut self, result: &mut Option<CaptureKind>, pat: PatId) {
734        let mut update_result = |ck: CaptureKind| match result {
735            Some(r) => {
736                *r = cmp::max(*r, ck);
737            }
738            None => *result = Some(ck),
739        };
740
741        self.walk_pat_inner(
742            pat,
743            &mut update_result,
744            BorrowKind::Mut { kind: MutBorrowKind::Default },
745        );
746    }
747
    /// Walks `p` and all of its subpatterns, reporting through `update_result` the capture
    /// kind each of them needs.
    ///
    /// `for_mut` is the borrow kind reported for `ref mut` bindings. Once a pattern carries
    /// non-empty pattern adjustments, its subpatterns report
    /// `MutBorrowKind::ClosureCapture` instead.
    fn walk_pat_inner(
        &mut self,
        p: PatId,
        update_result: &mut impl FnMut(CaptureKind),
        mut for_mut: BorrowKind,
    ) {
        match &self.body[p] {
            // These patterns report nothing themselves; their subpatterns (visited below)
            // may still report something.
            Pat::Ref { .. }
            | Pat::Box { .. }
            | Pat::Missing
            | Pat::Wild
            | Pat::Tuple { .. }
            | Pat::Expr(_)
            | Pat::Or(_) => (),
            Pat::TupleStruct { .. } | Pat::Record { .. } => {
                // Matching against an enum that doesn't have exactly one variant requires
                // reading the value, so a shared borrow is reported.
                if let Some(variant) = self.result.variant_resolution_for_pat(p) {
                    let adt = variant.adt_id(self.db);
                    let is_multivariant = match adt {
                        hir_def::AdtId::EnumId(e) => e.enum_variants(self.db).variants.len() != 1,
                        _ => false,
                    };
                    if is_multivariant {
                        update_result(CaptureKind::ByRef(BorrowKind::Shared));
                    }
                }
            }
            // Patterns that compare against the value need to read it.
            Pat::Slice { .. }
            | Pat::ConstBlock(_)
            | Pat::Path(_)
            | Pat::Lit(_)
            | Pat::Range { .. } => {
                update_result(CaptureKind::ByRef(BorrowKind::Shared));
            }
            Pat::Bind { id, .. } => match self.result.binding_modes[p] {
                // By-move bindings of `Copy` types only need to read; otherwise the value
                // is taken.
                crate::BindingMode::Move => {
                    if self.is_ty_copy(self.result.type_of_binding[*id]) {
                        update_result(CaptureKind::ByRef(BorrowKind::Shared));
                    } else {
                        update_result(CaptureKind::ByValue);
                    }
                }
                crate::BindingMode::Ref(r) => match r {
                    Mutability::Mut => update_result(CaptureKind::ByRef(for_mut)),
                    Mutability::Not => update_result(CaptureKind::ByRef(BorrowKind::Shared)),
                },
            },
        }
        // Below a pattern with implicit adjustments, `ref mut` bindings report a
        // closure-capture mutable borrow instead of `for_mut`.
        if self.result.pat_adjustments.get(&p).is_some_and(|it| !it.is_empty()) {
            for_mut = BorrowKind::Mut { kind: MutBorrowKind::ClosureCapture };
        }
        self.body.walk_pats_shallow(p, |p| self.walk_pat_inner(p, update_result, for_mut));
    }
800
801    fn is_upvar(&self, place: &HirPlace<'db>) -> bool {
802        if let Some(c) = self.current_closure {
803            let InternedClosure(_, root) = self.db.lookup_intern_closure(c);
804            return self.body.is_binding_upvar(place.local, root);
805        }
806        false
807    }
808
809    fn is_ty_copy(&mut self, ty: Ty<'db>) -> bool {
810        if let TyKind::Closure(id, _) = ty.kind() {
811            // FIXME: We handle closure as a special case, since chalk consider every closure as copy. We
812            // should probably let chalk know which closures are copy, but I don't know how doing it
813            // without creating query cycles.
814            return self
815                .result
816                .closure_info
817                .get(&id.0)
818                .map(|it| it.1 == FnTrait::Fn)
819                .unwrap_or(true);
820        }
821        let ty = self.table.resolve_completely(ty);
822        self.table.type_is_copy_modulo_regions(ty)
823    }
824
    /// Records the captures needed to select a place out of `expr`.
    ///
    /// This currently just walks `expr` with `walk_expr`.
    fn select_from_expr(&mut self, expr: ExprId) {
        self.walk_expr(expr);
    }
828
    /// Limits how precisely captures may reach through raw pointers and unions.
    ///
    /// If a capture's local itself has a raw-pointer or union type, the capture is cut back
    /// to the bare local. Otherwise its place is cut right after the first projection whose
    /// resulting type is a raw pointer or a union. Every capture that gets cut has its kind
    /// set to a shared borrow, and its span stacks are truncated to match.
    fn restrict_precision_for_unsafe(&mut self) {
        // FIXME: Borrow checker problems without this.
        // (The captures are moved out of `self` so that `self.table` and
        // `truncate_capture_spans` can be used while each capture is mutated; they are put
        // back at the end.)
        let mut current_captures = std::mem::take(&mut self.current_captures);
        for capture in &mut current_captures {
            let mut ty = self.table.resolve_completely(self.result[capture.place.local]);
            if ty.is_raw_ptr() || ty.is_union() {
                capture.kind = CaptureKind::ByRef(BorrowKind::Shared);
                self.truncate_capture_spans(capture, 0);
                capture.place.projections.truncate(0);
                continue;
            }
            // Follow the projections, tracking the type reached after each one.
            for (i, p) in capture.place.projections.iter().enumerate() {
                ty = p.projected_ty(
                    &self.table.infer_ctxt,
                    self.table.param_env,
                    ty,
                    |_, _, _| {
                        unreachable!("Closure field only happens in MIR");
                    },
                    self.owner.module(self.db).krate(self.db),
                );
                if ty.is_raw_ptr() || ty.is_union() {
                    // Keep the projection that produced the pointer/union; drop everything
                    // after it.
                    capture.kind = CaptureKind::ByRef(BorrowKind::Shared);
                    self.truncate_capture_spans(capture, i + 1);
                    capture.place.projections.truncate(i + 1);
                    break;
                }
            }
        }
        self.current_captures = current_captures;
    }
860
861    fn adjust_for_move_closure(&mut self) {
862        // FIXME: Borrow checker won't allow without this.
863        let mut current_captures = std::mem::take(&mut self.current_captures);
864        for capture in &mut current_captures {
865            if let Some(first_deref) =
866                capture.place.projections.iter().position(|proj| *proj == ProjectionElem::Deref)
867            {
868                self.truncate_capture_spans(capture, first_deref);
869                capture.place.projections.truncate(first_deref);
870            }
871            capture.kind = CaptureKind::ByValue;
872        }
873        self.current_captures = current_captures;
874    }
875
876    fn minimize_captures(&mut self) {
877        self.current_captures.sort_unstable_by_key(|it| it.place.projections.len());
878        let mut hash_map = FxHashMap::<HirPlace<'db>, usize>::default();
879        let result = mem::take(&mut self.current_captures);
880        for mut item in result {
881            let mut lookup_place = HirPlace { local: item.place.local, projections: vec![] };
882            let mut it = item.place.projections.iter();
883            let prev_index = loop {
884                if let Some(k) = hash_map.get(&lookup_place) {
885                    break Some(*k);
886                }
887                match it.next() {
888                    Some(it) => {
889                        lookup_place.projections.push(it.clone());
890                    }
891                    None => break None,
892                }
893            };
894            match prev_index {
895                Some(p) => {
896                    let prev_projections_len = self.current_captures[p].place.projections.len();
897                    self.truncate_capture_spans(&mut item, prev_projections_len);
898                    self.current_captures[p].span_stacks.extend(item.span_stacks);
899                    let len = self.current_captures[p].place.projections.len();
900                    let kind_after_truncate =
901                        item.place.capture_kind_of_truncated_place(item.kind, len);
902                    self.current_captures[p].kind =
903                        cmp::max(kind_after_truncate, self.current_captures[p].kind);
904                }
905                None => {
906                    hash_map.insert(item.place.clone(), self.current_captures.len());
907                    self.current_captures.push(item);
908                }
909            }
910        }
911    }
912
    /// Records the captures needed to destructure `place` with the pattern `tgt_pat`.
    ///
    /// Tuple patterns, struct patterns (record and tuple-struct), `&pat` and or-patterns
    /// are walked recursively. Each recursive call extends `place` with the matching field
    /// or deref projection, so only the parts that bindings use are captured.
    ///
    /// Enum/union variant patterns, ranges, slices, literals, const blocks and paths pass
    /// the whole `place` to `consume_place`. `ref`/`ref mut` bindings add a by-ref capture;
    /// by-move bindings pass the place to `consume_place`. `box` patterns record nothing.
    ///
    /// Every span pushed onto `current_capture_span_stack` here is popped again before
    /// returning.
    fn consume_with_pat(&mut self, mut place: HirPlace<'db>, tgt_pat: PatId) {
        // Implicit derefs from the pattern's adjustments become explicit `Deref`
        // projections, each with a span pointing at this pattern.
        let adjustments_count =
            self.result.pat_adjustments.get(&tgt_pat).map(|it| it.len()).unwrap_or_default();
        place.projections.extend((0..adjustments_count).map(|_| ProjectionElem::Deref));
        self.current_capture_span_stack
            .extend((0..adjustments_count).map(|_| MirSpan::PatId(tgt_pat)));
        // Every `break 'reset_span_stack` still runs the span-stack cleanup after this block.
        'reset_span_stack: {
            match &self.body[tgt_pat] {
                Pat::Missing | Pat::Wild => (),
                Pat::Tuple { args, ellipsis } => {
                    // Subpatterns before `..` map to fields 0.., those after it map to the
                    // trailing fields (matched from the end).
                    let (al, ar) = args.split_at(ellipsis.map_or(args.len(), |it| it as usize));
                    let field_count = match self.result[tgt_pat].kind() {
                        TyKind::Tuple(s) => s.len(),
                        _ => break 'reset_span_stack,
                    };
                    let fields = 0..field_count;
                    let it = al.iter().zip(fields.clone()).chain(ar.iter().rev().zip(fields.rev()));
                    for (&arg, i) in it {
                        let mut p = place.clone();
                        self.current_capture_span_stack.push(MirSpan::PatId(arg));
                        p.projections.push(ProjectionElem::Field(Either::Right(TupleFieldId {
                            tuple: TupleId(!0), // dummy this, as its unused anyways
                            index: i as u32,
                        })));
                        self.consume_with_pat(p, arg);
                        self.current_capture_span_stack.pop();
                    }
                }
                Pat::Or(pats) => {
                    for pat in pats.iter() {
                        self.consume_with_pat(place.clone(), *pat);
                    }
                }
                Pat::Record { args, .. } => {
                    let Some(variant) = self.result.variant_resolution_for_pat(tgt_pat) else {
                        break 'reset_span_stack;
                    };
                    match variant {
                        // Enum variants and unions are not split into per-field captures.
                        VariantId::EnumVariantId(_) | VariantId::UnionId(_) => {
                            self.consume_place(place)
                        }
                        VariantId::StructId(s) => {
                            let vd = s.fields(self.db);
                            for field_pat in args.iter() {
                                let arg = field_pat.pat;
                                // Unknown field names are skipped.
                                let Some(local_id) = vd.field(&field_pat.name) else {
                                    continue;
                                };
                                let mut p = place.clone();
                                self.current_capture_span_stack.push(MirSpan::PatId(arg));
                                p.projections.push(ProjectionElem::Field(Either::Left(FieldId {
                                    parent: variant,
                                    local_id,
                                })));
                                self.consume_with_pat(p, arg);
                                self.current_capture_span_stack.pop();
                            }
                        }
                    }
                }
                Pat::Range { .. } | Pat::Slice { .. } | Pat::ConstBlock(_) | Pat::Lit(_) => {
                    self.consume_place(place)
                }
                Pat::Path(path) => {
                    // In a destructuring assignment a path pattern is also an assignee.
                    if self.inside_assignment {
                        self.mutate_path_pat(path, tgt_pat);
                    }
                    self.consume_place(place);
                }
                &Pat::Bind { id, subpat: _ } => {
                    let mode = self.result.binding_modes[tgt_pat];
                    let capture_kind = match mode {
                        BindingMode::Move => {
                            self.consume_place(place);
                            break 'reset_span_stack;
                        }
                        BindingMode::Ref(Mutability::Not) => BorrowKind::Shared,
                        BindingMode::Ref(Mutability::Mut) => {
                            BorrowKind::Mut { kind: MutBorrowKind::Default }
                        }
                    };
                    self.current_capture_span_stack.push(MirSpan::BindingId(id));
                    self.add_capture(place, CaptureKind::ByRef(capture_kind));
                    self.current_capture_span_stack.pop();
                }
                Pat::TupleStruct { path: _, args, ellipsis } => {
                    let Some(variant) = self.result.variant_resolution_for_pat(tgt_pat) else {
                        break 'reset_span_stack;
                    };
                    match variant {
                        VariantId::EnumVariantId(_) | VariantId::UnionId(_) => {
                            self.consume_place(place)
                        }
                        VariantId::StructId(s) => {
                            // Same leading/trailing split around `..` as for tuples, but
                            // over the struct's declared fields.
                            let vd = s.fields(self.db);
                            let (al, ar) =
                                args.split_at(ellipsis.map_or(args.len(), |it| it as usize));
                            let fields = vd.fields().iter();
                            let it = al
                                .iter()
                                .zip(fields.clone())
                                .chain(ar.iter().rev().zip(fields.rev()));
                            for (&arg, (i, _)) in it {
                                let mut p = place.clone();
                                self.current_capture_span_stack.push(MirSpan::PatId(arg));
                                p.projections.push(ProjectionElem::Field(Either::Left(FieldId {
                                    parent: variant,
                                    local_id: i,
                                })));
                                self.consume_with_pat(p, arg);
                                self.current_capture_span_stack.pop();
                            }
                        }
                    }
                }
                Pat::Ref { pat, mutability: _ } => {
                    self.current_capture_span_stack.push(MirSpan::PatId(tgt_pat));
                    place.projections.push(ProjectionElem::Deref);
                    self.consume_with_pat(place, *pat);
                    self.current_capture_span_stack.pop();
                }
                Pat::Box { .. } => (), // not supported
                &Pat::Expr(expr) => {
                    // Assignee expression in a destructuring assignment. The expression is
                    // analyzed with a fresh span stack and `inside_assignment` cleared; both
                    // are restored afterwards.
                    self.consume_place(place);
                    let pat_capture_span_stack = mem::take(&mut self.current_capture_span_stack);
                    let old_inside_assignment = mem::replace(&mut self.inside_assignment, false);
                    let lhs_place = self.place_of_expr(expr);
                    self.mutate_expr(expr, lhs_place);
                    self.inside_assignment = old_inside_assignment;
                    self.current_capture_span_stack = pat_capture_span_stack;
                }
            }
        }
        // Pop the spans pushed for the pattern adjustments at the top.
        self.current_capture_span_stack
            .truncate(self.current_capture_span_stack.len() - adjustments_count);
    }
1049
1050    fn consume_exprs(&mut self, exprs: impl Iterator<Item = ExprId>) {
1051        for expr in exprs {
1052            self.consume_expr(expr);
1053        }
1054    }
1055
1056    fn closure_kind(&self) -> FnTrait {
1057        let mut r = FnTrait::Fn;
1058        for it in &self.current_captures {
1059            r = cmp::min(
1060                r,
1061                match &it.kind {
1062                    CaptureKind::ByRef(BorrowKind::Mut { .. }) => FnTrait::FnMut,
1063                    CaptureKind::ByRef(BorrowKind::Shallow | BorrowKind::Shared) => FnTrait::Fn,
1064                    CaptureKind::ByValue => FnTrait::FnOnce,
1065                },
1066            )
1067        }
1068        r
1069    }
1070
    /// Analyzes the body of `closure`, computing its captures and its closure kind.
    ///
    /// The captures (with their types) and the kind are stored in `self.result.closure_info`,
    /// and the kind is returned. If `closure_info` already held a kind for this closure, that
    /// kind is used instead of the one derived from the captures.
    ///
    /// The order of the post-processing steps matters:
    /// 1. precision is restricted through raw pointers and unions;
    /// 2. the kind is taken *before* `move` turns every capture into a by-value capture;
    /// 3. overlapping captures are merged;
    /// 4. ref spans are stripped from by-value captures.
    fn analyze_closure(&mut self, closure: InternedClosureId) -> FnTrait {
        let InternedClosure(_, root) = self.db.lookup_intern_closure(closure);
        self.current_closure = Some(closure);
        let Expr::Closure { body, capture_by, .. } = &self.body[root] else {
            unreachable!("Closure expression id is always closure");
        };
        // Walk the closure body, collecting captures into `self.current_captures`.
        self.consume_expr(*body);
        for item in &self.current_captures {
            // A mutable borrow of a place that doesn't go through a deref mutates the
            // captured binding itself.
            if matches!(
                item.kind,
                CaptureKind::ByRef(BorrowKind::Mut {
                    kind: MutBorrowKind::Default | MutBorrowKind::TwoPhasedBorrow
                })
            ) && !item.place.projections.contains(&ProjectionElem::Deref)
            {
                // FIXME: remove the `mutated_bindings_in_closure` completely and add proper fake reads in
                // MIR. I didn't do that due duplicate diagnostics.
                self.result.mutated_bindings_in_closure.insert(item.place.local);
            }
        }
        self.restrict_precision_for_unsafe();
        // `closure_kind` should be done before adjust_for_move_closure
        // If there exists pre-deduced kind of a closure, use it instead of one determined by capture, as rustc does.
        // rustc also does diagnostics here if the latter is not a subtype of the former.
        let closure_kind = self
            .result
            .closure_info
            .get(&closure)
            .map_or_else(|| self.closure_kind(), |info| info.1);
        match capture_by {
            CaptureBy::Value => self.adjust_for_move_closure(),
            CaptureBy::Ref => (),
        }
        self.minimize_captures();
        self.strip_captures_ref_span();
        let result = mem::take(&mut self.current_captures);
        let captures = result.into_iter().map(|it| it.with_ty(self)).collect::<Vec<_>>();
        self.result.closure_info.insert(closure, (captures, closure_kind));
        closure_kind
    }
1111
1112    fn strip_captures_ref_span(&mut self) {
1113        // FIXME: Borrow checker won't allow without this.
1114        let mut captures = std::mem::take(&mut self.current_captures);
1115        for capture in &mut captures {
1116            if matches!(capture.kind, CaptureKind::ByValue) {
1117                for span_stack in &mut capture.span_stacks {
1118                    if span_stack[span_stack.len() - 1].is_ref_span(self.body) {
1119                        span_stack.truncate(span_stack.len() - 1);
1120                    }
1121                }
1122            }
1123        }
1124        self.current_captures = captures;
1125    }
1126
1127    pub(crate) fn infer_closures(&mut self) {
1128        let deferred_closures = self.sort_closures();
1129        for (closure, exprs) in deferred_closures.into_iter().rev() {
1130            self.current_captures = vec![];
1131            let kind = self.analyze_closure(closure);
1132
1133            for (derefed_callee, callee_ty, params, expr) in exprs {
1134                if let &Expr::Call { callee, .. } = &self.body[expr] {
1135                    let mut adjustments =
1136                        self.result.expr_adjustments.remove(&callee).unwrap_or_default().into_vec();
1137                    self.write_fn_trait_method_resolution(
1138                        kind,
1139                        derefed_callee,
1140                        &mut adjustments,
1141                        callee_ty,
1142                        &params,
1143                        expr,
1144                    );
1145                    self.result.expr_adjustments.insert(callee, adjustments.into_boxed_slice());
1146                }
1147            }
1148        }
1149    }
1150
    /// We want to analyze some closures before others, to have a correct analysis:
    /// * We should analyze nested closures before the parent, since the parent should capture some of
    ///   the things that its children captures.
    /// * If a closure calls another closure, we need to analyze the callee, to find out how we should
    ///   capture it (e.g. by move for FnOnce)
    ///
    /// These dependencies are collected in the main inference. We do a topological sort in this function. It
    /// will consume the `deferred_closures` field and return its content in a sorted vector.
    ///
    /// `closure_dependencies` maps a closure to the closures it depends on. The sort counts,
    /// for each closure, how many closures depend on it, and starts from those with a count
    /// of zero. The returned vector therefore lists a closure *before* the closures it
    /// depends on, so callers that need dependencies first iterate it in reverse.
    ///
    /// # Panics
    ///
    /// Panics if some deferred closure is never reached, e.g. when `closure_dependencies`
    /// contains a cycle.
    fn sort_closures(
        &mut self,
    ) -> Vec<(InternedClosureId, Vec<(Ty<'db>, Ty<'db>, Vec<Ty<'db>>, ExprId)>)> {
        let mut deferred_closures = mem::take(&mut self.deferred_closures);
        // Number of closures that depend on each closure (its in-degree).
        let mut dependents_count: FxHashMap<InternedClosureId, usize> =
            deferred_closures.keys().map(|it| (*it, 0)).collect();
        for deps in self.closure_dependencies.values() {
            for dep in deps {
                *dependents_count.entry(*dep).or_default() += 1;
            }
        }
        // Used as a stack (LIFO); the pop order determines the order of `result`.
        let mut queue: Vec<_> =
            deferred_closures.keys().copied().filter(|&it| dependents_count[&it] == 0).collect();
        let mut result = vec![];
        while let Some(it) = queue.pop() {
            if let Some(d) = deferred_closures.remove(&it) {
                result.push((it, d));
            }
            for &dep in self.closure_dependencies.get(&it).into_iter().flat_map(|it| it.iter()) {
                // Every dependency got an entry in the counting loop above.
                let cnt = dependents_count.get_mut(&dep).unwrap();
                *cnt -= 1;
                if *cnt == 0 {
                    queue.push(dep);
                }
            }
        }
        assert!(deferred_closures.is_empty(), "we should have analyzed all closures");
        result
    }
1188
1189    pub(crate) fn add_current_closure_dependency(&mut self, dep: InternedClosureId) {
1190        if let Some(c) = self.current_closure
1191            && !dep_creates_cycle(&self.closure_dependencies, &mut FxHashSet::default(), c, dep)
1192        {
1193            self.closure_dependencies.entry(c).or_default().push(dep);
1194        }
1195
1196        fn dep_creates_cycle(
1197            closure_dependencies: &FxHashMap<InternedClosureId, Vec<InternedClosureId>>,
1198            visited: &mut FxHashSet<InternedClosureId>,
1199            from: InternedClosureId,
1200            to: InternedClosureId,
1201        ) -> bool {
1202            if !visited.insert(from) {
1203                return false;
1204            }
1205
1206            if from == to {
1207                return true;
1208            }
1209
1210            if let Some(deps) = closure_dependencies.get(&to) {
1211                for dep in deps {
1212                    if dep_creates_cycle(closure_dependencies, visited, from, *dep) {
1213                        return true;
1214                    }
1215                }
1216            }
1217
1218            false
1219        }
1220    }
1221}
1222
1223/// Call this only when the last span in the stack isn't a split.
1224fn apply_adjusts_to_place<'db>(
1225    current_capture_span_stack: &mut Vec<MirSpan>,
1226    mut r: HirPlace<'db>,
1227    adjustments: &[Adjustment<'db>],
1228) -> Option<HirPlace<'db>> {
1229    let span = *current_capture_span_stack.last().expect("empty capture span stack");
1230    for adj in adjustments {
1231        match &adj.kind {
1232            Adjust::Deref(None) => {
1233                current_capture_span_stack.push(span);
1234                r.projections.push(ProjectionElem::Deref);
1235            }
1236            _ => return None,
1237        }
1238    }
1239    Some(r)
1240}