hir_ty/next_solver/infer/snapshot/fudge.rs

1use std::ops::Range;
2
3use ena::{
4    snapshot_vec as sv,
5    unify::{self as ut, UnifyKey},
6};
7use rustc_type_ir::{
8    ConstVid, FloatVid, IntVid, RegionKind, RegionVid, TyVid, TypeFoldable, TypeFolder,
9    TypeSuperFoldable, TypeVisitableExt, inherent::IntoKind,
10};
11
12use crate::next_solver::{
13    Const, ConstKind, DbInterner, Region, Ty, TyKind,
14    infer::{
15        InferCtxt, UnificationTable, iter_idx_range,
16        snapshot::VariableLengths,
17        type_variable::TypeVariableOrigin,
18        unify_key::{ConstVariableOrigin, ConstVariableValue, ConstVidKey},
19    },
20};
21
22fn vars_since_snapshot<'db, T>(
23    table: &UnificationTable<'_, 'db, T>,
24    snapshot_var_len: usize,
25) -> Range<T>
26where
27    T: UnifyKey,
28    super::UndoLog<'db>: From<sv::UndoLog<ut::Delegate<T>>>,
29{
30    T::from_index(snapshot_var_len as u32)..T::from_index(table.len() as u32)
31}
32
33fn const_vars_since_snapshot<'db>(
34    table: &mut UnificationTable<'_, 'db, ConstVidKey<'db>>,
35    snapshot_var_len: usize,
36) -> (Range<ConstVid>, Vec<ConstVariableOrigin>) {
37    let range = vars_since_snapshot(table, snapshot_var_len);
38    let range = range.start.vid..range.end.vid;
39
40    (
41        range.clone(),
42        iter_idx_range(range)
43            .map(|index| match table.probe_value(index) {
44                ConstVariableValue::Known { value: _ } => ConstVariableOrigin {},
45                ConstVariableValue::Unknown { origin, universe: _ } => origin,
46            })
47            .collect(),
48    )
49}
50
51impl<'db> InferCtxt<'db> {
52    /// This rather funky routine is used while processing expected
53    /// types. What happens here is that we want to propagate a
54    /// coercion through the return type of a fn to its
55    /// argument. Consider the type of `Option::Some`, which is
56    /// basically `for<T> fn(T) -> Option<T>`. So if we have an
57    /// expression `Some(&[1, 2, 3])`, and that has the expected type
58    /// `Option<&[u32]>`, we would like to type check `&[1, 2, 3]`
59    /// with the expectation of `&[u32]`. This will cause us to coerce
60    /// from `&[u32; 3]` to `&[u32]` and make the users life more
61    /// pleasant.
62    ///
63    /// The way we do this is using `fudge_inference_if_ok`. What the
64    /// routine actually does is to start a snapshot and execute the
65    /// closure `f`. In our example above, what this closure will do
66    /// is to unify the expectation (`Option<&[u32]>`) with the actual
67    /// return type (`Option<?T>`, where `?T` represents the variable
68    /// instantiated for `T`). This will cause `?T` to be unified
69    /// with `&?a [u32]`, where `?a` is a fresh lifetime variable. The
70    /// input type (`?T`) is then returned by `f()`.
71    ///
72    /// At this point, `fudge_inference_if_ok` will normalize all type
73    /// variables, converting `?T` to `&?a [u32]` and end the
74    /// snapshot. The problem is that we can't just return this type
75    /// out, because it references the region variable `?a`, and that
76    /// region variable was popped when we popped the snapshot.
77    ///
78    /// So what we do is to keep a list (`region_vars`, in the code below)
79    /// of region variables created during the snapshot (here, `?a`). We
80    /// fold the return value and replace any such regions with a *new*
81    /// region variable (e.g., `?b`) and return the result (`&?b [u32]`).
82    /// This can then be used as the expectation for the fn argument.
83    ///
84    /// The important point here is that, for soundness purposes, the
85    /// regions in question are not particularly important. We will
86    /// use the expected types to guide coercions, but we will still
87    /// type-check the resulting types from those coercions against
88    /// the actual types (`?T`, `Option<?T>`) -- and remember that
89    /// after the snapshot is popped, the variable `?T` is no longer
90    /// unified.
91    pub fn fudge_inference_if_ok<T, E, F>(&self, f: F) -> Result<T, E>
92    where
93        F: FnOnce() -> Result<T, E>,
94        T: TypeFoldable<DbInterner<'db>>,
95    {
96        let variable_lengths = self.variable_lengths();
97        let (snapshot_vars, value) = self.probe(|_| {
98            let value = f()?;
99            // At this point, `value` could in principle refer
100            // to inference variables that have been created during
101            // the snapshot. Once we exit `probe()`, those are
102            // going to be popped, so we will have to
103            // eliminate any references to them.
104            let snapshot_vars = SnapshotVarData::new(self, variable_lengths);
105            Ok((snapshot_vars, self.resolve_vars_if_possible(value)))
106        })?;
107
108        // At this point, we need to replace any of the now-popped
109        // type/region variables that appear in `value` with a fresh
110        // variable of the appropriate kind. We can't do this during
111        // the probe because they would just get popped then too. =)
112        Ok(self.fudge_inference(snapshot_vars, value))
113    }
114
115    fn fudge_inference<T: TypeFoldable<DbInterner<'db>>>(
116        &self,
117        snapshot_vars: SnapshotVarData,
118        value: T,
119    ) -> T {
120        // Micro-optimization: if no variables have been created, then
121        // `value` can't refer to any of them. =) So we can just return it.
122        if snapshot_vars.is_empty() {
123            value
124        } else {
125            value.fold_with(&mut InferenceFudger { infcx: self, snapshot_vars })
126        }
127    }
128}
129
130struct SnapshotVarData {
131    region_vars: Range<RegionVid>,
132    type_vars: (Range<TyVid>, Vec<TypeVariableOrigin>),
133    int_vars: Range<IntVid>,
134    float_vars: Range<FloatVid>,
135    const_vars: (Range<ConstVid>, Vec<ConstVariableOrigin>),
136}
137
138impl SnapshotVarData {
139    fn new(infcx: &InferCtxt<'_>, vars_pre_snapshot: VariableLengths) -> SnapshotVarData {
140        let mut inner = infcx.inner.borrow_mut();
141        let region_vars = inner
142            .unwrap_region_constraints()
143            .vars_since_snapshot(vars_pre_snapshot.region_constraints_len);
144        let type_vars = inner.type_variables().vars_since_snapshot(vars_pre_snapshot.type_var_len);
145        let int_vars =
146            vars_since_snapshot(&inner.int_unification_table(), vars_pre_snapshot.int_var_len);
147        let float_vars =
148            vars_since_snapshot(&inner.float_unification_table(), vars_pre_snapshot.float_var_len);
149
150        let const_vars = const_vars_since_snapshot(
151            &mut inner.const_unification_table(),
152            vars_pre_snapshot.const_var_len,
153        );
154        SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars }
155    }
156
157    fn is_empty(&self) -> bool {
158        let SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars } = self;
159        region_vars.is_empty()
160            && type_vars.0.is_empty()
161            && int_vars.is_empty()
162            && float_vars.is_empty()
163            && const_vars.0.is_empty()
164    }
165}
166
167struct InferenceFudger<'a, 'db> {
168    infcx: &'a InferCtxt<'db>,
169    snapshot_vars: SnapshotVarData,
170}
171
172impl<'a, 'db> TypeFolder<DbInterner<'db>> for InferenceFudger<'a, 'db> {
173    fn cx(&self) -> DbInterner<'db> {
174        self.infcx.interner
175    }
176
177    fn fold_ty(&mut self, ty: Ty<'db>) -> Ty<'db> {
178        if let TyKind::Infer(infer_ty) = ty.kind() {
179            match infer_ty {
180                rustc_type_ir::TyVar(vid) => {
181                    if self.snapshot_vars.type_vars.0.contains(&vid) {
182                        // This variable was created during the fudging.
183                        // Recreate it with a fresh variable here.
184                        let idx = vid.as_usize() - self.snapshot_vars.type_vars.0.start.as_usize();
185                        let origin = self.snapshot_vars.type_vars.1[idx];
186                        self.infcx.next_ty_var_with_origin(origin)
187                    } else {
188                        // This variable was created before the
189                        // "fudging". Since we refresh all type
190                        // variables to their binding anyhow, we know
191                        // that it is unbound, so we can just return
192                        // it.
193                        debug_assert!(
194                            self.infcx.inner.borrow_mut().type_variables().probe(vid).is_unknown()
195                        );
196                        ty
197                    }
198                }
199                rustc_type_ir::IntVar(vid) => {
200                    if self.snapshot_vars.int_vars.contains(&vid) {
201                        self.infcx.next_int_var()
202                    } else {
203                        ty
204                    }
205                }
206                rustc_type_ir::FloatVar(vid) => {
207                    if self.snapshot_vars.float_vars.contains(&vid) {
208                        self.infcx.next_float_var()
209                    } else {
210                        ty
211                    }
212                }
213                rustc_type_ir::FreshTy(_)
214                | rustc_type_ir::FreshIntTy(_)
215                | rustc_type_ir::FreshFloatTy(_) => {
216                    unreachable!("unexpected fresh infcx var")
217                }
218            }
219        } else if ty.has_infer() {
220            ty.super_fold_with(self)
221        } else {
222            ty
223        }
224    }
225
226    fn fold_region(&mut self, r: Region<'db>) -> Region<'db> {
227        if let RegionKind::ReVar(vid) = r.kind() {
228            if self.snapshot_vars.region_vars.contains(&vid) {
229                self.infcx.next_region_var()
230            } else {
231                r
232            }
233        } else {
234            r
235        }
236    }
237
238    fn fold_const(&mut self, ct: Const<'db>) -> Const<'db> {
239        if let ConstKind::Infer(infer_ct) = ct.kind() {
240            match infer_ct {
241                rustc_type_ir::InferConst::Var(vid) => {
242                    if self.snapshot_vars.const_vars.0.contains(&vid) {
243                        let idx = vid.index() - self.snapshot_vars.const_vars.0.start.index();
244                        let origin = self.snapshot_vars.const_vars.1[idx];
245                        self.infcx.next_const_var_with_origin(origin)
246                    } else {
247                        ct
248                    }
249                }
250                rustc_type_ir::InferConst::Fresh(_) => {
251                    unreachable!("unexpected fresh infcx var")
252                }
253            }
254        } else if ct.has_infer() {
255            ct.super_fold_with(self)
256        } else {
257            ct
258        }
259    }
260}