hir_ty/next_solver/infer/canonical/
instantiate.rs

//! This module contains code to instantiate new values into a
//! `Canonical<'db, T>`.
//!
//! For an overview of what canonicalization is and how it fits into
//! trait solving, check out the [chapter in the Chalk book][c].
//!
//! [c]: https://rust-lang.github.io/chalk/book/canonical_queries/canonicalization.html
8
9use std::{fmt::Debug, iter};
10
11use crate::next_solver::{
12    BoundConst, BoundRegion, BoundTy, Canonical, CanonicalVarKind, CanonicalVarValues, Clauses,
13    Const, ConstKind, DbInterner, GenericArg, ParamEnv, Predicate, Region, RegionKind, Ty, TyKind,
14    fold::FnMutDelegate,
15    infer::{
16        InferCtxt, InferOk, InferResult,
17        canonical::{QueryRegionConstraints, QueryResponse, canonicalizer::OriginalQueryValues},
18        traits::{ObligationCause, PredicateObligations},
19    },
20};
21use rustc_hash::FxHashMap;
22use rustc_index::{Idx as _, IndexVec};
23use rustc_type_ir::{
24    BoundVar, BoundVarIndexKind, GenericArgKind, TypeFlags, TypeFoldable, TypeFolder,
25    TypeSuperFoldable, TypeVisitableExt, UniverseIndex,
26    inherent::{GenericArg as _, IntoKind, SliceLike},
27};
28use tracing::{debug, instrument};
29
30pub trait CanonicalExt<'db, V> {
31    fn instantiate(&self, tcx: DbInterner<'db>, var_values: &CanonicalVarValues<'db>) -> V
32    where
33        V: TypeFoldable<DbInterner<'db>>;
34    fn instantiate_projected<T>(
35        &self,
36        tcx: DbInterner<'db>,
37        var_values: &CanonicalVarValues<'db>,
38        projection_fn: impl FnOnce(&V) -> T,
39    ) -> T
40    where
41        T: TypeFoldable<DbInterner<'db>>;
42}
43
44/// FIXME(-Znext-solver): This or public because it is shared with the
45/// new trait solver implementation. We should deduplicate canonicalization.
46impl<'db, V> CanonicalExt<'db, V> for Canonical<'db, V> {
47    /// Instantiate the wrapped value, replacing each canonical value
48    /// with the value given in `var_values`.
49    fn instantiate(&self, tcx: DbInterner<'db>, var_values: &CanonicalVarValues<'db>) -> V
50    where
51        V: TypeFoldable<DbInterner<'db>>,
52    {
53        self.instantiate_projected(tcx, var_values, |value| value.clone())
54    }
55
56    /// Allows one to apply a instantiation to some subset of
57    /// `self.value`. Invoke `projection_fn` with `self.value` to get
58    /// a value V that is expressed in terms of the same canonical
59    /// variables bound in `self` (usually this extracts from subset
60    /// of `self`). Apply the instantiation `var_values` to this value
61    /// V, replacing each of the canonical variables.
62    fn instantiate_projected<T>(
63        &self,
64        tcx: DbInterner<'db>,
65        var_values: &CanonicalVarValues<'db>,
66        projection_fn: impl FnOnce(&V) -> T,
67    ) -> T
68    where
69        T: TypeFoldable<DbInterner<'db>>,
70    {
71        assert_eq!(self.variables.len(), var_values.len());
72        let value = projection_fn(&self.value);
73        instantiate_value(tcx, var_values, value)
74    }
75}
76
77/// Instantiate the values from `var_values` into `value`. `var_values`
78/// must be values for the set of canonical variables that appear in
79/// `value`.
80pub(super) fn instantiate_value<'db, T>(
81    tcx: DbInterner<'db>,
82    var_values: &CanonicalVarValues<'db>,
83    value: T,
84) -> T
85where
86    T: TypeFoldable<DbInterner<'db>>,
87{
88    if var_values.var_values.is_empty() {
89        value
90    } else {
91        let delegate = FnMutDelegate {
92            regions: &mut |br: BoundRegion| match var_values[br.var].kind() {
93                GenericArgKind::Lifetime(l) => l,
94                r => panic!("{br:?} is a region but value is {r:?}"),
95            },
96            types: &mut |bound_ty: BoundTy| match var_values[bound_ty.var].kind() {
97                GenericArgKind::Type(ty) => ty,
98                r => panic!("{bound_ty:?} is a type but value is {r:?}"),
99            },
100            consts: &mut |bound_ct: BoundConst| match var_values[bound_ct.var].kind() {
101                GenericArgKind::Const(ct) => ct,
102                c => panic!("{bound_ct:?} is a const but value is {c:?}"),
103            },
104        };
105
106        let value = tcx.replace_escaping_bound_vars_uncached(value, delegate);
107        value.fold_with(&mut CanonicalInstantiator {
108            tcx,
109            var_values: var_values.var_values.as_slice(),
110            cache: Default::default(),
111        })
112    }
113}
114
115/// Replaces the bound vars in a canonical binder with var values.
116struct CanonicalInstantiator<'db, 'a> {
117    tcx: DbInterner<'db>,
118
119    // The values that the bound vars are being instantiated with.
120    var_values: &'a [GenericArg<'db>],
121
122    // Because we use `BoundVarIndexKind::Canonical`, we can cache
123    // based only on the entire ty, not worrying about a `DebruijnIndex`
124    cache: FxHashMap<Ty<'db>, Ty<'db>>,
125}
126
127impl<'db, 'a> TypeFolder<DbInterner<'db>> for CanonicalInstantiator<'db, 'a> {
128    fn cx(&self) -> DbInterner<'db> {
129        self.tcx
130    }
131
132    fn fold_ty(&mut self, t: Ty<'db>) -> Ty<'db> {
133        match t.kind() {
134            TyKind::Bound(BoundVarIndexKind::Canonical, bound_ty) => {
135                self.var_values[bound_ty.var.as_usize()].expect_ty()
136            }
137            _ => {
138                if !t.has_type_flags(TypeFlags::HAS_CANONICAL_BOUND) {
139                    t
140                } else if let Some(&t) = self.cache.get(&t) {
141                    t
142                } else {
143                    let res = t.super_fold_with(self);
144                    assert!(self.cache.insert(t, res).is_none());
145                    res
146                }
147            }
148        }
149    }
150
151    fn fold_region(&mut self, r: Region<'db>) -> Region<'db> {
152        match r.kind() {
153            RegionKind::ReBound(BoundVarIndexKind::Canonical, br) => {
154                self.var_values[br.var.as_usize()].expect_region()
155            }
156            _ => r,
157        }
158    }
159
160    fn fold_const(&mut self, ct: Const<'db>) -> Const<'db> {
161        match ct.kind() {
162            ConstKind::Bound(BoundVarIndexKind::Canonical, bound_const) => {
163                self.var_values[bound_const.var.as_usize()].expect_const()
164            }
165            _ => ct.super_fold_with(self),
166        }
167    }
168
169    fn fold_predicate(&mut self, p: Predicate<'db>) -> Predicate<'db> {
170        if p.has_type_flags(TypeFlags::HAS_CANONICAL_BOUND) { p.super_fold_with(self) } else { p }
171    }
172
173    fn fold_clauses(&mut self, c: Clauses<'db>) -> Clauses<'db> {
174        if !c.has_type_flags(TypeFlags::HAS_CANONICAL_BOUND) {
175            return c;
176        }
177
178        // FIXME: We might need cache here for perf like rustc
179        c.super_fold_with(self)
180    }
181}
182
183impl<'db> InferCtxt<'db> {
184    /// A version of `make_canonicalized_query_response` that does
185    /// not pack in obligations, for contexts that want to drop
186    /// pending obligations instead of treating them as an ambiguity (e.g.
187    /// typeck "probing" contexts).
188    ///
189    /// If you DO want to keep track of pending obligations (which
190    /// include all region obligations, so this includes all cases
191    /// that care about regions) with this function, you have to
192    /// do it yourself, by e.g., having them be a part of the answer.
193    pub fn make_query_response_ignoring_pending_obligations<T>(
194        &self,
195        inference_vars: CanonicalVarValues<'db>,
196        answer: T,
197    ) -> Canonical<'db, QueryResponse<'db, T>>
198    where
199        T: TypeFoldable<DbInterner<'db>>,
200    {
201        // While we ignore region constraints and pending obligations,
202        // we do return constrained opaque types to avoid unconstrained
203        // inference variables in the response. This is important as we want
204        // to check that opaques in deref steps stay unconstrained.
205        //
206        // This doesn't handle the more general case for non-opaques as
207        // ambiguous `Projection` obligations have same the issue.
208        let opaque_types = self
209            .inner
210            .borrow_mut()
211            .opaque_type_storage
212            .iter_opaque_types()
213            .map(|(k, v)| (k, v.ty))
214            .collect();
215
216        self.canonicalize_response(QueryResponse {
217            var_values: inference_vars,
218            region_constraints: QueryRegionConstraints::default(),
219            opaque_types,
220            value: answer,
221        })
222    }
223
224    /// Given the (canonicalized) result to a canonical query,
225    /// instantiates the result so it can be used, plugging in the
226    /// values from the canonical query. (Note that the result may
227    /// have been ambiguous; you should check the certainty level of
228    /// the query before applying this function.)
229    ///
230    /// To get a good understanding of what is happening here, check
231    /// out the [chapter in the rustc dev guide][c].
232    ///
233    /// [c]: https://rust-lang.github.io/chalk/book/canonical_queries/canonicalization.html#processing-the-canonicalized-query-result
234    pub fn instantiate_query_response_and_region_obligations<R>(
235        &self,
236        cause: &ObligationCause,
237        param_env: ParamEnv<'db>,
238        original_values: &OriginalQueryValues<'db>,
239        query_response: &Canonical<'db, QueryResponse<'db, R>>,
240    ) -> InferResult<'db, R>
241    where
242        R: TypeFoldable<DbInterner<'db>>,
243    {
244        let InferOk { value: result_args, obligations } =
245            self.query_response_instantiation(cause, param_env, original_values, query_response)?;
246
247        for predicate in &query_response.value.region_constraints.outlives {
248            let predicate = instantiate_value(self.interner, &result_args, *predicate);
249            self.register_outlives_constraint(predicate);
250        }
251
252        for assumption in &query_response.value.region_constraints.assumptions {
253            let assumption = instantiate_value(self.interner, &result_args, *assumption);
254            self.register_region_assumption(assumption);
255        }
256
257        let user_result: R =
258            query_response
259                .instantiate_projected(self.interner, &result_args, |q_r| q_r.value.clone());
260
261        Ok(InferOk { value: user_result, obligations })
262    }
263
264    /// Given the original values and the (canonicalized) result from
265    /// computing a query, returns an instantiation that can be applied
266    /// to the query result to convert the result back into the
267    /// original namespace.
268    ///
269    /// The instantiation also comes accompanied with subobligations
270    /// that arose from unification; these might occur if (for
271    /// example) we are doing lazy normalization and the value
272    /// assigned to a type variable is unified with an unnormalized
273    /// projection.
274    fn query_response_instantiation<R>(
275        &self,
276        cause: &ObligationCause,
277        param_env: ParamEnv<'db>,
278        original_values: &OriginalQueryValues<'db>,
279        query_response: &Canonical<'db, QueryResponse<'db, R>>,
280    ) -> InferResult<'db, CanonicalVarValues<'db>>
281    where
282        R: Debug + TypeFoldable<DbInterner<'db>>,
283    {
284        debug!(
285            "query_response_instantiation(original_values={:#?}, query_response={:#?})",
286            original_values, query_response,
287        );
288
289        let mut value = self.query_response_instantiation_guess(
290            cause,
291            param_env,
292            original_values,
293            query_response,
294        )?;
295
296        value.obligations.extend(
297            self.unify_query_response_instantiation_guess(
298                cause,
299                param_env,
300                original_values,
301                &value.value,
302                query_response,
303            )?
304            .into_obligations(),
305        );
306
307        Ok(value)
308    }
309
310    /// Given the original values and the (canonicalized) result from
311    /// computing a query, returns a **guess** at an instantiation that
312    /// can be applied to the query result to convert the result back
313    /// into the original namespace. This is called a **guess**
314    /// because it uses a quick heuristic to find the values for each
315    /// canonical variable; if that quick heuristic fails, then we
316    /// will instantiate fresh inference variables for each canonical
317    /// variable instead. Therefore, the result of this method must be
318    /// properly unified
319    #[instrument(level = "debug", skip(self, param_env))]
320    fn query_response_instantiation_guess<R>(
321        &self,
322        cause: &ObligationCause,
323        param_env: ParamEnv<'db>,
324        original_values: &OriginalQueryValues<'db>,
325        query_response: &Canonical<'db, QueryResponse<'db, R>>,
326    ) -> InferResult<'db, CanonicalVarValues<'db>>
327    where
328        R: Debug + TypeFoldable<DbInterner<'db>>,
329    {
330        // For each new universe created in the query result that did
331        // not appear in the original query, create a local
332        // superuniverse.
333        let mut universe_map = original_values.universe_map.clone();
334        let num_universes_in_query = original_values.universe_map.len();
335        let num_universes_in_response = query_response.max_universe.as_usize() + 1;
336        for _ in num_universes_in_query..num_universes_in_response {
337            universe_map.push(self.create_next_universe());
338        }
339        assert!(!universe_map.is_empty()); // always have the root universe
340        assert_eq!(universe_map[UniverseIndex::ROOT.as_usize()], UniverseIndex::ROOT);
341
342        // Every canonical query result includes values for each of
343        // the inputs to the query. Therefore, we begin by unifying
344        // these values with the original inputs that were
345        // canonicalized.
346        let result_values = &query_response.value.var_values;
347        assert_eq!(original_values.var_values.len(), result_values.len());
348
349        // Quickly try to find initial values for the canonical
350        // variables in the result in terms of the query. We do this
351        // by iterating down the values that the query gave to each of
352        // the canonical inputs. If we find that one of those values
353        // is directly equal to one of the canonical variables in the
354        // result, then we can type the corresponding value from the
355        // input. See the example above.
356        let mut opt_values: IndexVec<BoundVar, Option<GenericArg<'db>>> =
357            IndexVec::from_elem_n(None, query_response.variables.len());
358
359        for (original_value, result_value) in iter::zip(&original_values.var_values, result_values)
360        {
361            match result_value.kind() {
362                GenericArgKind::Type(result_value) => {
363                    // We disable the instantiation guess for inference variables
364                    // and only use it for placeholders. We need to handle the
365                    // `sub_root` of type inference variables which would make this
366                    // more involved. They are also a lot rarer than region variables.
367                    if let TyKind::Bound(index_kind, b) = result_value.kind()
368                        && !matches!(
369                            query_response.variables.as_slice()[b.var.as_usize()],
370                            CanonicalVarKind::Ty { .. }
371                        )
372                    {
373                        // We only allow a `Canonical` index in generic parameters.
374                        assert!(matches!(index_kind, BoundVarIndexKind::Canonical));
375                        opt_values[b.var] = Some(*original_value);
376                    }
377                }
378                GenericArgKind::Lifetime(result_value) => {
379                    if let RegionKind::ReBound(index_kind, b) = result_value.kind() {
380                        // We only allow a `Canonical` index in generic parameters.
381                        assert!(matches!(index_kind, BoundVarIndexKind::Canonical));
382                        opt_values[b.var] = Some(*original_value);
383                    }
384                }
385                GenericArgKind::Const(result_value) => {
386                    if let ConstKind::Bound(index_kind, b) = result_value.kind() {
387                        // We only allow a `Canonical` index in generic parameters.
388                        assert!(matches!(index_kind, BoundVarIndexKind::Canonical));
389                        opt_values[b.var] = Some(*original_value);
390                    }
391                }
392            }
393        }
394
395        // Create result arguments: if we found a value for a
396        // given variable in the loop above, use that. Otherwise, use
397        // a fresh inference variable.
398        let interner = self.interner;
399        let variables = query_response.variables;
400        let var_values =
401            CanonicalVarValues::instantiate(interner, variables, |var_values, kind| {
402                if kind.universe() != UniverseIndex::ROOT {
403                    // A variable from inside a binder of the query. While ideally these shouldn't
404                    // exist at all, we have to deal with them for now.
405                    self.instantiate_canonical_var(kind, var_values, |u| universe_map[u.as_usize()])
406                } else if kind.is_existential() {
407                    match opt_values[BoundVar::new(var_values.len())] {
408                        Some(k) => k,
409                        None => self.instantiate_canonical_var(kind, var_values, |u| {
410                            universe_map[u.as_usize()]
411                        }),
412                    }
413                } else {
414                    // For placeholders which were already part of the input, we simply map this
415                    // universal bound variable back the placeholder of the input.
416                    opt_values[BoundVar::new(var_values.len())]
417                        .expect("expected placeholder to be unified with itself during response")
418                }
419            });
420
421        let mut obligations = PredicateObligations::new();
422
423        // Carry all newly resolved opaque types to the caller's scope
424        for &(a, b) in &query_response.value.opaque_types {
425            let a = instantiate_value(self.interner, &var_values, a);
426            let b = instantiate_value(self.interner, &var_values, b);
427            debug!(?a, ?b, "constrain opaque type");
428            // We use equate here instead of, for example, just registering the
429            // opaque type's hidden value directly, because the hidden type may have been an inference
430            // variable that got constrained to the opaque type itself. In that case we want to equate
431            // the generic args of the opaque with the generic params of its hidden type version.
432            obligations.extend(
433                self.at(cause, param_env)
434                    .eq(Ty::new_opaque(self.interner, a.def_id, a.args), b)?
435                    .obligations,
436            );
437        }
438
439        Ok(InferOk { value: var_values, obligations })
440    }
441
442    /// Given a "guess" at the values for the canonical variables in
443    /// the input, try to unify with the *actual* values found in the
444    /// query result. Often, but not always, this is a no-op, because
445    /// we already found the mapping in the "guessing" step.
446    ///
447    /// See also: [`Self::query_response_instantiation_guess`]
448    fn unify_query_response_instantiation_guess<R>(
449        &self,
450        cause: &ObligationCause,
451        param_env: ParamEnv<'db>,
452        original_values: &OriginalQueryValues<'db>,
453        result_args: &CanonicalVarValues<'db>,
454        query_response: &Canonical<'db, QueryResponse<'db, R>>,
455    ) -> InferResult<'db, ()>
456    where
457        R: Debug + TypeFoldable<DbInterner<'db>>,
458    {
459        // A closure that yields the result value for the given
460        // canonical variable; this is taken from
461        // `query_response.var_values` after applying the instantiation
462        // by `result_args`.
463        let instantiated_query_response = |index: BoundVar| -> GenericArg<'db> {
464            query_response
465                .instantiate_projected(self.interner, result_args, |v| v.var_values[index])
466        };
467
468        // Unify the original value for each variable with the value
469        // taken from `query_response` (after applying `result_args`).
470        self.unify_canonical_vars(cause, param_env, original_values, instantiated_query_response)
471    }
472
473    /// Given two sets of values for the same set of canonical variables, unify them.
474    /// The second set is produced lazily by supplying indices from the first set.
475    fn unify_canonical_vars(
476        &self,
477        cause: &ObligationCause,
478        param_env: ParamEnv<'db>,
479        variables1: &OriginalQueryValues<'db>,
480        variables2: impl Fn(BoundVar) -> GenericArg<'db>,
481    ) -> InferResult<'db, ()> {
482        let mut obligations = PredicateObligations::new();
483        for (index, value1) in variables1.var_values.iter().enumerate() {
484            let value2 = variables2(BoundVar::new(index));
485
486            match (value1.kind(), value2.kind()) {
487                (GenericArgKind::Type(v1), GenericArgKind::Type(v2)) => {
488                    obligations.extend(self.at(cause, param_env).eq(v1, v2)?.into_obligations());
489                }
490                (GenericArgKind::Lifetime(re1), GenericArgKind::Lifetime(re2))
491                    if re1.is_erased() && re2.is_erased() =>
492                {
493                    // no action needed
494                }
495                (GenericArgKind::Lifetime(v1), GenericArgKind::Lifetime(v2)) => {
496                    self.inner.borrow_mut().unwrap_region_constraints().make_eqregion(v1, v2);
497                }
498                (GenericArgKind::Const(v1), GenericArgKind::Const(v2)) => {
499                    let ok = self.at(cause, param_env).eq(v1, v2)?;
500                    obligations.extend(ok.into_obligations());
501                }
502                _ => {
503                    panic!("kind mismatch, cannot unify {:?} and {:?}", value1, value2,);
504                }
505            }
506        }
507        Ok(InferOk { value: (), obligations })
508    }
509}