1use std::{fmt::Debug, iter};
10
11use crate::next_solver::{
12 BoundConst, BoundRegion, BoundTy, Canonical, CanonicalVarKind, CanonicalVarValues, Clauses,
13 Const, ConstKind, DbInterner, GenericArg, ParamEnv, Predicate, Region, RegionKind, Ty, TyKind,
14 fold::FnMutDelegate,
15 infer::{
16 InferCtxt, InferOk, InferResult,
17 canonical::{QueryRegionConstraints, QueryResponse, canonicalizer::OriginalQueryValues},
18 traits::{ObligationCause, PredicateObligations},
19 },
20};
21use rustc_hash::FxHashMap;
22use rustc_index::{Idx as _, IndexVec};
23use rustc_type_ir::{
24 BoundVar, BoundVarIndexKind, GenericArgKind, TypeFlags, TypeFoldable, TypeFolder,
25 TypeSuperFoldable, TypeVisitableExt, UniverseIndex,
26 inherent::{GenericArg as _, IntoKind, SliceLike},
27};
28use tracing::{debug, instrument};
29
30pub trait CanonicalExt<'db, V> {
31 fn instantiate(&self, tcx: DbInterner<'db>, var_values: &CanonicalVarValues<'db>) -> V
32 where
33 V: TypeFoldable<DbInterner<'db>>;
34 fn instantiate_projected<T>(
35 &self,
36 tcx: DbInterner<'db>,
37 var_values: &CanonicalVarValues<'db>,
38 projection_fn: impl FnOnce(&V) -> T,
39 ) -> T
40 where
41 T: TypeFoldable<DbInterner<'db>>;
42}
43
44impl<'db, V> CanonicalExt<'db, V> for Canonical<'db, V> {
47 fn instantiate(&self, tcx: DbInterner<'db>, var_values: &CanonicalVarValues<'db>) -> V
50 where
51 V: TypeFoldable<DbInterner<'db>>,
52 {
53 self.instantiate_projected(tcx, var_values, |value| value.clone())
54 }
55
56 fn instantiate_projected<T>(
63 &self,
64 tcx: DbInterner<'db>,
65 var_values: &CanonicalVarValues<'db>,
66 projection_fn: impl FnOnce(&V) -> T,
67 ) -> T
68 where
69 T: TypeFoldable<DbInterner<'db>>,
70 {
71 assert_eq!(self.variables.len(), var_values.len());
72 let value = projection_fn(&self.value);
73 instantiate_value(tcx, var_values, value)
74 }
75}
76
77pub(super) fn instantiate_value<'db, T>(
81 tcx: DbInterner<'db>,
82 var_values: &CanonicalVarValues<'db>,
83 value: T,
84) -> T
85where
86 T: TypeFoldable<DbInterner<'db>>,
87{
88 if var_values.var_values.is_empty() {
89 value
90 } else {
91 let delegate = FnMutDelegate {
92 regions: &mut |br: BoundRegion| match var_values[br.var].kind() {
93 GenericArgKind::Lifetime(l) => l,
94 r => panic!("{br:?} is a region but value is {r:?}"),
95 },
96 types: &mut |bound_ty: BoundTy| match var_values[bound_ty.var].kind() {
97 GenericArgKind::Type(ty) => ty,
98 r => panic!("{bound_ty:?} is a type but value is {r:?}"),
99 },
100 consts: &mut |bound_ct: BoundConst| match var_values[bound_ct.var].kind() {
101 GenericArgKind::Const(ct) => ct,
102 c => panic!("{bound_ct:?} is a const but value is {c:?}"),
103 },
104 };
105
106 let value = tcx.replace_escaping_bound_vars_uncached(value, delegate);
107 value.fold_with(&mut CanonicalInstantiator {
108 tcx,
109 var_values: var_values.var_values.as_slice(),
110 cache: Default::default(),
111 })
112 }
113}
114
115struct CanonicalInstantiator<'db, 'a> {
117 tcx: DbInterner<'db>,
118
119 var_values: &'a [GenericArg<'db>],
121
122 cache: FxHashMap<Ty<'db>, Ty<'db>>,
125}
126
127impl<'db, 'a> TypeFolder<DbInterner<'db>> for CanonicalInstantiator<'db, 'a> {
128 fn cx(&self) -> DbInterner<'db> {
129 self.tcx
130 }
131
132 fn fold_ty(&mut self, t: Ty<'db>) -> Ty<'db> {
133 match t.kind() {
134 TyKind::Bound(BoundVarIndexKind::Canonical, bound_ty) => {
135 self.var_values[bound_ty.var.as_usize()].expect_ty()
136 }
137 _ => {
138 if !t.has_type_flags(TypeFlags::HAS_CANONICAL_BOUND) {
139 t
140 } else if let Some(&t) = self.cache.get(&t) {
141 t
142 } else {
143 let res = t.super_fold_with(self);
144 assert!(self.cache.insert(t, res).is_none());
145 res
146 }
147 }
148 }
149 }
150
151 fn fold_region(&mut self, r: Region<'db>) -> Region<'db> {
152 match r.kind() {
153 RegionKind::ReBound(BoundVarIndexKind::Canonical, br) => {
154 self.var_values[br.var.as_usize()].expect_region()
155 }
156 _ => r,
157 }
158 }
159
160 fn fold_const(&mut self, ct: Const<'db>) -> Const<'db> {
161 match ct.kind() {
162 ConstKind::Bound(BoundVarIndexKind::Canonical, bound_const) => {
163 self.var_values[bound_const.var.as_usize()].expect_const()
164 }
165 _ => ct.super_fold_with(self),
166 }
167 }
168
169 fn fold_predicate(&mut self, p: Predicate<'db>) -> Predicate<'db> {
170 if p.has_type_flags(TypeFlags::HAS_CANONICAL_BOUND) { p.super_fold_with(self) } else { p }
171 }
172
173 fn fold_clauses(&mut self, c: Clauses<'db>) -> Clauses<'db> {
174 if !c.has_type_flags(TypeFlags::HAS_CANONICAL_BOUND) {
175 return c;
176 }
177
178 c.super_fold_with(self)
180 }
181}
182
impl<'db> InferCtxt<'db> {
    /// Canonicalizes `answer` into a query response without looking at any
    /// pending obligations.
    ///
    /// The response uses `inference_vars` as its `var_values` and carries
    /// default (empty) region constraints. It also records every opaque type
    /// currently in this context's opaque type storage, paired with that
    /// storage entry's `ty`.
    pub fn make_query_response_ignoring_pending_obligations<T>(
        &self,
        inference_vars: CanonicalVarValues<'db>,
        answer: T,
    ) -> Canonical<'db, QueryResponse<'db, T>>
    where
        T: TypeFoldable<DbInterner<'db>>,
    {
        // Snapshot the stored opaque types as `(key, ty)` pairs.
        let opaque_types = self
            .inner
            .borrow_mut()
            .opaque_type_storage
            .iter_opaque_types()
            .map(|(k, v)| (k, v.ty))
            .collect();

        self.canonicalize_response(QueryResponse {
            var_values: inference_vars,
            region_constraints: QueryRegionConstraints::default(),
            opaque_types,
            value: answer,
        })
    }

    /// Instantiates `query_response` in this inference context and registers
    /// its region constraints.
    ///
    /// The instantiation is computed by `query_response_instantiation`, which
    /// also unifies the response with `original_values`. Each outlives
    /// predicate in the response is instantiated and passed to
    /// `register_outlives_constraint`. Each assumption is instantiated and
    /// passed to `register_region_assumption`.
    ///
    /// Returns the instantiated response value together with the obligations
    /// produced while unifying.
    ///
    /// # Errors
    /// Propagates any error returned by `query_response_instantiation`.
    pub fn instantiate_query_response_and_region_obligations<R>(
        &self,
        cause: &ObligationCause,
        param_env: ParamEnv<'db>,
        original_values: &OriginalQueryValues<'db>,
        query_response: &Canonical<'db, QueryResponse<'db, R>>,
    ) -> InferResult<'db, R>
    where
        R: TypeFoldable<DbInterner<'db>>,
    {
        let InferOk { value: result_args, obligations } =
            self.query_response_instantiation(cause, param_env, original_values, query_response)?;

        for predicate in &query_response.value.region_constraints.outlives {
            let predicate = instantiate_value(self.interner, &result_args, *predicate);
            self.register_outlives_constraint(predicate);
        }

        for assumption in &query_response.value.region_constraints.assumptions {
            let assumption = instantiate_value(self.interner, &result_args, *assumption);
            self.register_region_assumption(assumption);
        }

        let user_result: R =
            query_response
                .instantiate_projected(self.interner, &result_args, |q_r| q_r.value.clone());

        Ok(InferOk { value: user_result, obligations })
    }

    /// Computes the values that instantiate the canonical variables of
    /// `query_response` in this context.
    ///
    /// First a guess is built with `query_response_instantiation_guess`.
    /// `unify_query_response_instantiation_guess` then equates
    /// `original_values` with the response's `var_values` under that guess.
    /// The obligations from both steps are returned together.
    ///
    /// # Errors
    /// Propagates the first error from either step.
    fn query_response_instantiation<R>(
        &self,
        cause: &ObligationCause,
        param_env: ParamEnv<'db>,
        original_values: &OriginalQueryValues<'db>,
        query_response: &Canonical<'db, QueryResponse<'db, R>>,
    ) -> InferResult<'db, CanonicalVarValues<'db>>
    where
        R: Debug + TypeFoldable<DbInterner<'db>>,
    {
        debug!(
            "query_response_instantiation(original_values={:#?}, query_response={:#?})",
            original_values, query_response,
        );

        let mut value = self.query_response_instantiation_guess(
            cause,
            param_env,
            original_values,
            query_response,
        )?;

        // Unify the original values with the guessed instantiation and keep
        // the resulting obligations alongside the guess's own.
        value.obligations.extend(
            self.unify_query_response_instantiation_guess(
                cause,
                param_env,
                original_values,
                &value.value,
                query_response,
            )?
            .into_obligations(),
        );

        Ok(value)
    }

    /// Builds an initial instantiation for the canonical variables of
    /// `query_response`.
    ///
    /// 1. Extends `original_values.universe_map` with freshly created
    ///    universes until it covers `query_response.max_universe`.
    /// 2. For every original value whose response value is a bare canonical
    ///    bound variable, records the original value as that variable's guess.
    ///    Types are only guessed when the variable's kind is not
    ///    `CanonicalVarKind::Ty`.
    /// 3. Instantiates every canonical variable:
    ///    - non-root-universe variables are instantiated fresh;
    ///    - root existentials use their guess when one exists, otherwise they
    ///      are instantiated fresh;
    ///    - other root variables must have a guess.
    /// 4. Equates each opaque type recorded in the response with its type,
    ///    both instantiated with the result, collecting the obligations.
    ///
    /// # Panics
    /// Panics if the root universe does not map to itself, if the original
    /// and response value counts differ, if a non-`Canonical` bound variable
    /// index is found, or if a non-existential root variable has no guess.
    ///
    /// # Errors
    /// Propagates a failed equation of an opaque type with its type.
    #[instrument(level = "debug", skip(self, param_env))]
    fn query_response_instantiation_guess<R>(
        &self,
        cause: &ObligationCause,
        param_env: ParamEnv<'db>,
        original_values: &OriginalQueryValues<'db>,
        query_response: &Canonical<'db, QueryResponse<'db, R>>,
    ) -> InferResult<'db, CanonicalVarValues<'db>>
    where
        R: Debug + TypeFoldable<DbInterner<'db>>,
    {
        // Universes known to the query keep their mapping. Each further
        // universe mentioned by the response gets a newly created universe.
        let mut universe_map = original_values.universe_map.clone();
        let num_universes_in_query = original_values.universe_map.len();
        let num_universes_in_response = query_response.max_universe.as_usize() + 1;
        for _ in num_universes_in_query..num_universes_in_response {
            universe_map.push(self.create_next_universe());
        }
        // The root universe must be present and mapped to itself.
        assert!(!universe_map.is_empty());
        assert_eq!(universe_map[UniverseIndex::ROOT.as_usize()], UniverseIndex::ROOT);

        // The response must provide one value per original query variable.
        let result_values = &query_response.value.var_values;
        assert_eq!(original_values.var_values.len(), result_values.len());

        // Guessed value for each canonical variable of the response. `None`
        // means no guess was found.
        let mut opt_values: IndexVec<BoundVar, Option<GenericArg<'db>>> =
            IndexVec::from_elem_n(None, query_response.variables.len());

        for (original_value, result_value) in iter::zip(&original_values.var_values, result_values)
        {
            match result_value.kind() {
                GenericArgKind::Type(result_value) => {
                    // Only guess types whose canonical variable kind is not
                    // `CanonicalVarKind::Ty`.
                    if let TyKind::Bound(index_kind, b) = result_value.kind()
                        && !matches!(
                            query_response.variables.as_slice()[b.var.as_usize()],
                            CanonicalVarKind::Ty { .. }
                        )
                    {
                        // Only `Canonical`-indexed bound variables are expected here.
                        assert!(matches!(index_kind, BoundVarIndexKind::Canonical));
                        opt_values[b.var] = Some(*original_value);
                    }
                }
                GenericArgKind::Lifetime(result_value) => {
                    if let RegionKind::ReBound(index_kind, b) = result_value.kind() {
                        // Only `Canonical`-indexed bound variables are expected here.
                        assert!(matches!(index_kind, BoundVarIndexKind::Canonical));
                        opt_values[b.var] = Some(*original_value);
                    }
                }
                GenericArgKind::Const(result_value) => {
                    if let ConstKind::Bound(index_kind, b) = result_value.kind() {
                        // Only `Canonical`-indexed bound variables are expected here.
                        assert!(matches!(index_kind, BoundVarIndexKind::Canonical));
                        opt_values[b.var] = Some(*original_value);
                    }
                }
            }
        }

        let interner = self.interner;
        let variables = query_response.variables;
        // `var_values` in the closure holds the values built so far, so its
        // length is presumably the index of the variable being instantiated.
        let var_values =
            CanonicalVarValues::instantiate(interner, variables, |var_values, kind| {
                if kind.universe() != UniverseIndex::ROOT {
                    // Non-root variables: instantiate fresh, translating the
                    // universe through `universe_map`.
                    self.instantiate_canonical_var(kind, var_values, |u| universe_map[u.as_usize()])
                } else if kind.is_existential() {
                    // Root existentials: reuse the guess if there is one.
                    match opt_values[BoundVar::new(var_values.len())] {
                        Some(k) => k,
                        None => self.instantiate_canonical_var(kind, var_values, |u| {
                            universe_map[u.as_usize()]
                        }),
                    }
                } else {
                    // Remaining root variables must have been mapped back to
                    // an original value in the loop above.
                    opt_values[BoundVar::new(var_values.len())]
                        .expect("expected placeholder to be unified with itself during response")
                }
            });

        let mut obligations = PredicateObligations::new();

        // Equate each recorded opaque type with its recorded type, both
        // instantiated with the values computed above.
        for &(a, b) in &query_response.value.opaque_types {
            let a = instantiate_value(self.interner, &var_values, a);
            let b = instantiate_value(self.interner, &var_values, b);
            debug!(?a, ?b, "constrain opaque type");
            obligations.extend(
                self.at(cause, param_env)
                    .eq(Ty::new_opaque(self.interner, a.def_id, a.args), b)?
                    .obligations,
            );
        }

        Ok(InferOk { value: var_values, obligations })
    }

    /// Equates each value in `original_values` with the corresponding entry
    /// of the response's `var_values`, after substituting `result_args`.
    ///
    /// # Errors
    /// Propagates the first failing equation from `unify_canonical_vars`.
    fn unify_query_response_instantiation_guess<R>(
        &self,
        cause: &ObligationCause,
        param_env: ParamEnv<'db>,
        original_values: &OriginalQueryValues<'db>,
        result_args: &CanonicalVarValues<'db>,
        query_response: &Canonical<'db, QueryResponse<'db, R>>,
    ) -> InferResult<'db, ()>
    where
        R: Debug + TypeFoldable<DbInterner<'db>>,
    {
        // Yields the response's value for canonical variable `index`, with
        // `result_args` substituted in.
        let instantiated_query_response = |index: BoundVar| -> GenericArg<'db> {
            query_response
                .instantiate_projected(self.interner, result_args, |v| v.var_values[index])
        };

        self.unify_canonical_vars(cause, param_env, original_values, instantiated_query_response)
    }

    /// Equates, index by index, each value in `variables1` with
    /// `variables2(index)`.
    ///
    /// - Types and consts are equated via `self.at(cause, param_env).eq`,
    ///   collecting the resulting obligations.
    /// - Two erased regions need no action.
    /// - Any other pair of regions is recorded as equal in the region
    ///   constraint storage.
    ///
    /// # Errors
    /// Returns the first failing type/const equation.
    ///
    /// # Panics
    /// Panics if a pair of values has mismatched kinds.
    fn unify_canonical_vars(
        &self,
        cause: &ObligationCause,
        param_env: ParamEnv<'db>,
        variables1: &OriginalQueryValues<'db>,
        variables2: impl Fn(BoundVar) -> GenericArg<'db>,
    ) -> InferResult<'db, ()> {
        let mut obligations = PredicateObligations::new();
        for (index, value1) in variables1.var_values.iter().enumerate() {
            let value2 = variables2(BoundVar::new(index));

            match (value1.kind(), value2.kind()) {
                (GenericArgKind::Type(v1), GenericArgKind::Type(v2)) => {
                    obligations.extend(self.at(cause, param_env).eq(v1, v2)?.into_obligations());
                }
                (GenericArgKind::Lifetime(re1), GenericArgKind::Lifetime(re2))
                    if re1.is_erased() && re2.is_erased() =>
                {
                    // Both erased: nothing to relate.
                }
                (GenericArgKind::Lifetime(v1), GenericArgKind::Lifetime(v2)) => {
                    self.inner.borrow_mut().unwrap_region_constraints().make_eqregion(v1, v2);
                }
                (GenericArgKind::Const(v1), GenericArgKind::Const(v2)) => {
                    let ok = self.at(cause, param_env).eq(v1, v2)?;
                    obligations.extend(ok.into_obligations());
                }
                _ => {
                    panic!("kind mismatch, cannot unify {:?} and {:?}", value1, value2,);
                }
            }
        }
        Ok(InferOk { value: (), obligations })
    }
}