1use std::fmt;
4
5use base_db::Crate;
6use hir_def::{AdtId, DefWithBodyId, GenericParamId};
7use hir_expand::name::Name;
8use intern::sym;
9use rustc_hash::FxHashSet;
10use rustc_type_ir::{
11 TyVid, TypeFoldable, TypeVisitableExt, UpcastFrom,
12 inherent::{Const as _, GenericArg as _, IntoKind, SliceLike, Ty as _},
13 solve::Certainty,
14};
15use smallvec::SmallVec;
16
17use crate::{
18 db::HirDatabase,
19 next_solver::{
20 AliasTy, Canonical, ClauseKind, Const, DbInterner, ErrorGuaranteed, GenericArg,
21 GenericArgs, Goal, ParamEnv, Predicate, PredicateKind, Region, SolverDefId, Term, TraitRef,
22 Ty, TyKind, TypingMode,
23 fulfill::{FulfillmentCtxt, NextSolverError},
24 infer::{
25 DbInternerInferExt, InferCtxt, InferOk, InferResult,
26 at::{At, ToTrace},
27 snapshot::CombinedSnapshot,
28 traits::{Obligation, ObligationCause, PredicateObligation},
29 },
30 inspect::{InspectConfig, InspectGoal, ProofTreeVisitor},
31 obligation_ctxt::ObligationCtxt,
32 },
33 traits::{
34 FnTrait, NextTraitSolveResult, ParamEnvAndCrate, next_trait_solve_canonical_in_ctxt,
35 next_trait_solve_in_ctxt,
36 },
37};
38
39struct NestedObligationsForSelfTy<'a, 'db> {
40 ctx: &'a InferenceTable<'db>,
41 self_ty: TyVid,
42 root_cause: &'a ObligationCause,
43 obligations_for_self_ty: &'a mut SmallVec<[Obligation<'db, Predicate<'db>>; 4]>,
44}
45
46impl<'a, 'db> ProofTreeVisitor<'db> for NestedObligationsForSelfTy<'a, 'db> {
47 type Result = ();
48
49 fn config(&self) -> InspectConfig {
50 InspectConfig { max_depth: 5 }
54 }
55
56 fn visit_goal(&mut self, inspect_goal: &InspectGoal<'_, 'db>) {
57 if inspect_goal.result() == Ok(Certainty::Yes) {
60 return;
61 }
62
63 let db = self.ctx.interner();
64 let goal = inspect_goal.goal();
65 if self.ctx.predicate_has_self_ty(goal.predicate, self.self_ty) {
66 self.obligations_for_self_ty.push(Obligation::new(
67 db,
68 self.root_cause.clone(),
69 goal.param_env,
70 goal.predicate,
71 ));
72 }
73
74 if let Some(candidate) = inspect_goal.unique_applicable_candidate() {
79 candidate.visit_nested_no_probe(self)
80 }
81 }
82}
83
84pub fn could_unify<'db>(
91 db: &'db dyn HirDatabase,
92 env: ParamEnvAndCrate<'db>,
93 tys: &Canonical<'db, (Ty<'db>, Ty<'db>)>,
94) -> bool {
95 could_unify_impl(db, env, tys, |ctxt| ctxt.try_evaluate_obligations())
96}
97
98pub fn could_unify_deeply<'db>(
103 db: &'db dyn HirDatabase,
104 env: ParamEnvAndCrate<'db>,
105 tys: &Canonical<'db, (Ty<'db>, Ty<'db>)>,
106) -> bool {
107 could_unify_impl(db, env, tys, |ctxt| ctxt.evaluate_obligations_error_on_ambiguity())
108}
109
110fn could_unify_impl<'db>(
111 db: &'db dyn HirDatabase,
112 env: ParamEnvAndCrate<'db>,
113 tys: &Canonical<'db, (Ty<'db>, Ty<'db>)>,
114 select: for<'a> fn(&mut ObligationCtxt<'a, 'db>) -> Vec<NextSolverError<'db>>,
115) -> bool {
116 let interner = DbInterner::new_with(db, env.krate);
117 let infcx = interner.infer_ctxt().build(TypingMode::PostAnalysis);
118 let cause = ObligationCause::dummy();
119 let at = infcx.at(&cause, env.param_env);
120 let ((ty1_with_vars, ty2_with_vars), _) = infcx.instantiate_canonical(tys);
121 let mut ctxt = ObligationCtxt::new(&infcx);
122 let can_unify = at
123 .eq(ty1_with_vars, ty2_with_vars)
124 .map(|infer_ok| ctxt.register_infer_ok_obligations(infer_ok))
125 .is_ok();
126 can_unify && select(&mut ctxt).is_empty()
127}
128
129#[derive(Clone)]
130pub(crate) struct InferenceTable<'db> {
131 pub(crate) db: &'db dyn HirDatabase,
132 pub(crate) param_env: ParamEnv<'db>,
133 pub(crate) infer_ctxt: InferCtxt<'db>,
134 pub(super) fulfillment_cx: FulfillmentCtxt<'db>,
135 pub(super) diverging_type_vars: FxHashSet<Ty<'db>>,
136}
137
138pub(crate) struct InferenceTableSnapshot<'db> {
139 ctxt_snapshot: CombinedSnapshot,
140 obligations: FulfillmentCtxt<'db>,
141}
142
143impl<'db> InferenceTable<'db> {
144 pub(crate) fn new(
147 db: &'db dyn HirDatabase,
148 trait_env: ParamEnv<'db>,
149 krate: Crate,
150 owner: Option<DefWithBodyId>,
151 ) -> Self {
152 let interner = DbInterner::new_with(db, krate);
153 let typing_mode = match owner {
154 Some(owner) => TypingMode::typeck_for_body(interner, owner.into()),
155 None => TypingMode::PostAnalysis,
157 };
158 let infer_ctxt = interner.infer_ctxt().build(typing_mode);
159 InferenceTable {
160 db,
161 param_env: trait_env,
162 fulfillment_cx: FulfillmentCtxt::new(&infer_ctxt),
163 infer_ctxt,
164 diverging_type_vars: FxHashSet::default(),
165 }
166 }
167
168 #[inline]
169 pub(crate) fn interner(&self) -> DbInterner<'db> {
170 self.infer_ctxt.interner
171 }
172
173 pub(crate) fn type_is_copy_modulo_regions(&self, ty: Ty<'db>) -> bool {
174 self.infer_ctxt.type_is_copy_modulo_regions(self.param_env, ty)
175 }
176
177 pub(crate) fn type_var_is_sized(&self, self_ty: TyVid) -> bool {
178 let Some(sized_did) = self.interner().lang_items().Sized else {
179 return true;
180 };
181 self.obligations_for_self_ty(self_ty).into_iter().any(|obligation| {
182 match obligation.predicate.kind().skip_binder() {
183 PredicateKind::Clause(ClauseKind::Trait(data)) => data.def_id().0 == sized_did,
184 _ => false,
185 }
186 })
187 }
188
189 pub(super) fn obligations_for_self_ty(
190 &self,
191 self_ty: TyVid,
192 ) -> SmallVec<[Obligation<'db, Predicate<'db>>; 4]> {
193 let obligations = self.fulfillment_cx.pending_obligations();
194 let mut obligations_for_self_ty = SmallVec::new();
195 for obligation in obligations {
196 let mut visitor = NestedObligationsForSelfTy {
197 ctx: self,
198 self_ty,
199 obligations_for_self_ty: &mut obligations_for_self_ty,
200 root_cause: &obligation.cause,
201 };
202
203 let goal = obligation.as_goal();
204 self.infer_ctxt.visit_proof_tree(goal, &mut visitor);
205 }
206
207 obligations_for_self_ty.retain_mut(|obligation| {
208 obligation.predicate = self.infer_ctxt.resolve_vars_if_possible(obligation.predicate);
209 !obligation.predicate.has_placeholders()
210 });
211 obligations_for_self_ty
212 }
213
214 fn predicate_has_self_ty(&self, predicate: Predicate<'db>, expected_vid: TyVid) -> bool {
215 match predicate.kind().skip_binder() {
216 PredicateKind::Clause(ClauseKind::Trait(data)) => {
217 self.type_matches_expected_vid(expected_vid, data.self_ty())
218 }
219 PredicateKind::Clause(ClauseKind::Projection(data)) => {
220 self.type_matches_expected_vid(expected_vid, data.projection_term.self_ty())
221 }
222 PredicateKind::Clause(ClauseKind::ConstArgHasType(..))
223 | PredicateKind::Subtype(..)
224 | PredicateKind::Coerce(..)
225 | PredicateKind::Clause(ClauseKind::RegionOutlives(..))
226 | PredicateKind::Clause(ClauseKind::TypeOutlives(..))
227 | PredicateKind::Clause(ClauseKind::WellFormed(..))
228 | PredicateKind::DynCompatible(..)
229 | PredicateKind::NormalizesTo(..)
230 | PredicateKind::AliasRelate(..)
231 | PredicateKind::Clause(ClauseKind::ConstEvaluatable(..))
232 | PredicateKind::ConstEquate(..)
233 | PredicateKind::Clause(ClauseKind::HostEffect(..))
234 | PredicateKind::Clause(ClauseKind::UnstableFeature(_))
235 | PredicateKind::Ambiguous => false,
236 }
237 }
238
239 fn type_matches_expected_vid(&self, expected_vid: TyVid, ty: Ty<'db>) -> bool {
240 let ty = self.shallow_resolve(ty);
241
242 match ty.kind() {
243 TyKind::Infer(rustc_type_ir::TyVar(found_vid)) => {
244 self.infer_ctxt.root_var(expected_vid) == self.infer_ctxt.root_var(found_vid)
245 }
246 _ => false,
247 }
248 }
249
250 pub(super) fn set_diverging(&mut self, ty: Ty<'db>) {
251 self.diverging_type_vars.insert(ty);
252 }
253
254 pub(crate) fn canonicalize<T>(&mut self, t: T) -> rustc_type_ir::Canonical<DbInterner<'db>, T>
255 where
256 T: TypeFoldable<DbInterner<'db>>,
257 {
258 self.select_obligations_where_possible();
261 self.infer_ctxt.canonicalize_response(t)
262 }
263
264 pub(crate) fn normalize_associated_types_in<T>(&mut self, ty: T) -> T
267 where
268 T: TypeFoldable<DbInterner<'db>> + Clone,
269 {
270 let ty = self.resolve_vars_with_obligations(ty);
271 self.at(&ObligationCause::new()).deeply_normalize(ty.clone()).unwrap_or(ty)
272 }
273
274 pub(crate) fn normalize_alias_ty(&mut self, alias: Ty<'db>) -> Ty<'db> {
275 self.infer_ctxt
276 .at(&ObligationCause::new(), self.param_env)
277 .structurally_normalize_ty(alias, &mut self.fulfillment_cx)
278 .unwrap_or(alias)
279 }
280
281 pub(crate) fn next_ty_var(&self) -> Ty<'db> {
282 self.infer_ctxt.next_ty_var()
283 }
284
285 pub(crate) fn next_const_var(&self) -> Const<'db> {
286 self.infer_ctxt.next_const_var()
287 }
288
289 pub(crate) fn next_int_var(&self) -> Ty<'db> {
290 self.infer_ctxt.next_int_var()
291 }
292
293 pub(crate) fn next_float_var(&self) -> Ty<'db> {
294 self.infer_ctxt.next_float_var()
295 }
296
297 pub(crate) fn new_maybe_never_var(&mut self) -> Ty<'db> {
298 let var = self.next_ty_var();
299 self.set_diverging(var);
300 var
301 }
302
303 pub(crate) fn next_region_var(&self) -> Region<'db> {
304 self.infer_ctxt.next_region_var()
305 }
306
307 pub(crate) fn next_var_for_param(&self, id: GenericParamId) -> GenericArg<'db> {
308 self.infer_ctxt.next_var_for_param(id)
309 }
310
311 pub(crate) fn resolve_completely<T>(&mut self, value: T) -> T
312 where
313 T: TypeFoldable<DbInterner<'db>>,
314 {
315 let value = self.infer_ctxt.resolve_vars_if_possible(value);
316
317 let mut goals = vec![];
318
319 value.fold_with(&mut resolve_completely::Resolver::new(self, true, &mut goals))
322 }
323
324 pub(crate) fn unify<T: ToTrace<'db>>(&mut self, ty1: T, ty2: T) -> bool {
326 self.try_unify(ty1, ty2).map(|infer_ok| self.register_infer_ok(infer_ok)).is_ok()
327 }
328
329 pub(crate) fn try_unify<T: ToTrace<'db>>(&mut self, t1: T, t2: T) -> InferResult<'db, ()> {
332 self.at(&ObligationCause::new()).eq(t1, t2)
333 }
334
335 pub(crate) fn at<'a>(&'a self, cause: &'a ObligationCause) -> At<'a, 'db> {
336 self.infer_ctxt.at(cause, self.param_env)
337 }
338
339 pub(crate) fn shallow_resolve(&self, ty: Ty<'db>) -> Ty<'db> {
340 self.infer_ctxt.shallow_resolve(ty)
341 }
342
343 pub(crate) fn resolve_vars_with_obligations<T>(&mut self, t: T) -> T
344 where
345 T: rustc_type_ir::TypeFoldable<DbInterner<'db>>,
346 {
347 if !t.has_non_region_infer() {
348 return t;
349 }
350
351 let t = self.infer_ctxt.resolve_vars_if_possible(t);
352
353 if !t.has_non_region_infer() {
354 return t;
355 }
356
357 self.select_obligations_where_possible();
358 self.infer_ctxt.resolve_vars_if_possible(t)
359 }
360
361 pub(crate) fn fresh_args_for_item(&self, def: SolverDefId) -> GenericArgs<'db> {
363 self.infer_ctxt.fresh_args_for_item(def)
364 }
365
366 pub(crate) fn try_structurally_resolve_type(&mut self, ty: Ty<'db>) -> Ty<'db> {
372 if let TyKind::Alias(..) = ty.kind() {
373 let result = self
377 .infer_ctxt
378 .at(&ObligationCause::misc(), self.param_env)
379 .structurally_normalize_ty(ty, &mut self.fulfillment_cx);
380 match result {
381 Ok(normalized_ty) => normalized_ty,
382 Err(_errors) => Ty::new_error(self.interner(), ErrorGuaranteed),
383 }
384 } else {
385 self.resolve_vars_with_obligations(ty)
386 }
387 }
388
389 pub(crate) fn structurally_resolve_type(&mut self, ty: Ty<'db>) -> Ty<'db> {
390 self.try_structurally_resolve_type(ty)
391 }
393
394 pub(crate) fn snapshot(&mut self) -> InferenceTableSnapshot<'db> {
395 let ctxt_snapshot = self.infer_ctxt.start_snapshot();
396 let obligations = self.fulfillment_cx.clone();
397 InferenceTableSnapshot { ctxt_snapshot, obligations }
398 }
399
400 #[tracing::instrument(skip_all)]
401 pub(crate) fn rollback_to(&mut self, snapshot: InferenceTableSnapshot<'db>) {
402 self.infer_ctxt.rollback_to(snapshot.ctxt_snapshot);
403 self.fulfillment_cx = snapshot.obligations;
404 }
405
406 pub(crate) fn commit_if_ok<T, E>(
407 &mut self,
408 f: impl FnOnce(&mut InferenceTable<'db>) -> Result<T, E>,
409 ) -> Result<T, E> {
410 let snapshot = self.snapshot();
411 let result = f(self);
412 match result {
413 Ok(_) => {}
414 Err(_) => {
415 self.rollback_to(snapshot);
416 }
417 }
418 result
419 }
420
421 #[tracing::instrument(level = "debug", skip(self))]
425 pub(crate) fn try_obligation(&mut self, predicate: Predicate<'db>) -> NextTraitSolveResult {
426 let goal = Goal { param_env: self.param_env, predicate };
427 let canonicalized = self.canonicalize(goal);
428
429 next_trait_solve_canonical_in_ctxt(&self.infer_ctxt, canonicalized)
430 }
431
432 pub(crate) fn register_obligation(&mut self, predicate: Predicate<'db>) {
433 let goal = Goal { param_env: self.param_env, predicate };
434 self.register_obligation_in_env(goal)
435 }
436
437 #[tracing::instrument(level = "debug", skip(self))]
438 fn register_obligation_in_env(&mut self, goal: Goal<'db, Predicate<'db>>) {
439 let result = next_trait_solve_in_ctxt(&self.infer_ctxt, goal);
440 tracing::debug!(?result);
441 match result {
442 Ok((_, Certainty::Yes)) => {}
443 Err(rustc_type_ir::solve::NoSolution) => {}
444 Ok((_, Certainty::Maybe { .. })) => {
445 self.fulfillment_cx.register_predicate_obligation(
446 &self.infer_ctxt,
447 Obligation::new(
448 self.interner(),
449 ObligationCause::new(),
450 goal.param_env,
451 goal.predicate,
452 ),
453 );
454 }
455 }
456 }
457
458 pub(crate) fn register_infer_ok<T>(&mut self, infer_ok: InferOk<'db, T>) -> T {
459 let InferOk { value, obligations } = infer_ok;
460 self.register_predicates(obligations);
461 value
462 }
463
464 pub(crate) fn select_obligations_where_possible(&mut self) {
465 self.fulfillment_cx.try_evaluate_obligations(&self.infer_ctxt);
466 }
467
468 pub(super) fn register_predicate(&mut self, obligation: PredicateObligation<'db>) {
469 if obligation.has_escaping_bound_vars() {
470 panic!("escaping bound vars in predicate {:?}", obligation);
471 }
472
473 self.fulfillment_cx.register_predicate_obligation(&self.infer_ctxt, obligation);
474 }
475
476 pub(crate) fn register_predicates<I>(&mut self, obligations: I)
477 where
478 I: IntoIterator<Item = PredicateObligation<'db>>,
479 {
480 obligations.into_iter().for_each(|obligation| {
481 self.register_predicate(obligation);
482 });
483 }
484
485 pub(crate) fn register_wf_obligation(&mut self, term: Term<'db>, cause: ObligationCause) {
487 self.register_predicate(Obligation::new(
488 self.interner(),
489 cause,
490 self.param_env,
491 ClauseKind::WellFormed(term),
492 ));
493 }
494
495 pub(crate) fn add_wf_bounds(&mut self, args: GenericArgs<'db>) {
497 for term in args.iter().filter_map(|it| it.as_term()) {
498 self.register_wf_obligation(term, ObligationCause::new());
499 }
500 }
501
502 pub(crate) fn callable_sig(
503 &mut self,
504 ty: Ty<'db>,
505 num_args: usize,
506 ) -> Option<(Option<FnTrait>, Vec<Ty<'db>>, Ty<'db>)> {
507 match ty.callable_sig(self.interner()) {
508 Some(sig) => {
509 let sig = sig.skip_binder();
510 Some((None, sig.inputs_and_output.inputs().to_vec(), sig.output()))
511 }
512 None => {
513 let (f, args_ty, return_ty) = self.callable_sig_from_fn_trait(ty, num_args)?;
514 Some((Some(f), args_ty, return_ty))
515 }
516 }
517 }
518
519 fn callable_sig_from_fn_trait(
520 &mut self,
521 ty: Ty<'db>,
522 num_args: usize,
523 ) -> Option<(FnTrait, Vec<Ty<'db>>, Ty<'db>)> {
524 let lang_items = self.interner().lang_items();
525 for (fn_trait_name, output_assoc_name, subtraits) in [
526 (FnTrait::FnOnce, sym::Output, &[FnTrait::Fn, FnTrait::FnMut][..]),
527 (FnTrait::AsyncFnMut, sym::CallRefFuture, &[FnTrait::AsyncFn]),
528 (FnTrait::AsyncFnOnce, sym::CallOnceFuture, &[]),
529 ] {
530 let fn_trait = fn_trait_name.get_id(lang_items)?;
531 let trait_data = fn_trait.trait_items(self.db);
532 let output_assoc_type =
533 trait_data.associated_type_by_name(&Name::new_symbol_root(output_assoc_name))?;
534
535 let mut arg_tys = Vec::with_capacity(num_args);
536 let arg_ty = Ty::new_tup_from_iter(
537 self.interner(),
538 std::iter::repeat_with(|| {
539 let ty = self.next_ty_var();
540 arg_tys.push(ty);
541 ty
542 })
543 .take(num_args),
544 );
545 let args = [ty, arg_ty];
546 let trait_ref = TraitRef::new(self.interner(), fn_trait.into(), args);
547
548 let proj_args = self
549 .infer_ctxt
550 .fill_rest_fresh_args(output_assoc_type.into(), args.into_iter().map(Into::into));
551 let projection = Ty::new_alias(
552 self.interner(),
553 rustc_type_ir::AliasTyKind::Projection,
554 AliasTy::new(self.interner(), output_assoc_type.into(), proj_args),
555 );
556
557 let pred = Predicate::upcast_from(trait_ref, self.interner());
558 if !self.try_obligation(pred).no_solution() {
559 self.register_obligation(pred);
560 let return_ty = self.normalize_alias_ty(projection);
561 for &fn_x in subtraits {
562 let fn_x_trait = fn_x.get_id(lang_items)?;
563 let trait_ref = TraitRef::new(self.interner(), fn_x_trait.into(), args);
564 let pred = Predicate::upcast_from(trait_ref, self.interner());
565 if !self.try_obligation(pred).no_solution() {
566 return Some((fn_x, arg_tys, return_ty));
567 }
568 }
569 return Some((fn_trait_name, arg_tys, return_ty));
570 }
571 }
572 None
573 }
574
575 pub(super) fn insert_type_vars<T>(&mut self, ty: T) -> T
576 where
577 T: TypeFoldable<DbInterner<'db>>,
578 {
579 self.infer_ctxt.insert_type_vars(ty)
580 }
581
582 pub(super) fn insert_type_vars_shallow(&mut self, ty: Ty<'db>) -> Ty<'db> {
584 if ty.is_ty_error() { self.next_ty_var() } else { ty }
585 }
586
587 pub(crate) fn process_user_written_ty(&mut self, ty: Ty<'db>) -> Ty<'db> {
589 self.process_remote_user_written_ty(ty)
590 }
592
593 pub(crate) fn process_remote_user_written_ty(&mut self, ty: Ty<'db>) -> Ty<'db> {
596 let ty = self.insert_type_vars(ty);
597 self.try_structurally_resolve_type(ty)
602 }
603
604 pub(super) fn insert_const_vars_shallow(&mut self, c: Const<'db>) -> Const<'db> {
606 if c.is_ct_error() { self.next_const_var() } else { c }
607 }
608
609 pub(crate) fn is_sized(&mut self, ty: Ty<'db>) -> bool {
611 fn short_circuit_trivial_tys(ty: Ty<'_>) -> Option<bool> {
612 match ty.kind() {
613 TyKind::Bool
614 | TyKind::Char
615 | TyKind::Int(_)
616 | TyKind::Uint(_)
617 | TyKind::Float(_)
618 | TyKind::Ref(..)
619 | TyKind::RawPtr(..)
620 | TyKind::Never
621 | TyKind::FnDef(..)
622 | TyKind::Array(..)
623 | TyKind::FnPtr(..) => Some(true),
624 TyKind::Slice(..) | TyKind::Str | TyKind::Dynamic(..) => Some(false),
625 _ => None,
626 }
627 }
628
629 let mut ty = ty;
630 ty = self.try_structurally_resolve_type(ty);
631 if let Some(sized) = short_circuit_trivial_tys(ty) {
632 return sized;
633 }
634
635 {
636 let mut structs = SmallVec::<[_; 8]>::new();
637 while let Some((AdtId::StructId(id), subst)) = ty.as_adt() {
640 let struct_data = id.fields(self.db);
641 if let Some((last_field, _)) = struct_data.fields().iter().next_back() {
642 let last_field_ty = self.db.field_types(id.into())[last_field]
643 .instantiate(self.interner(), subst);
644 if structs.contains(&ty) {
645 return true; }
648 structs.push(ty);
649 ty = last_field_ty;
652 ty = self.try_structurally_resolve_type(ty);
653 if let Some(sized) = short_circuit_trivial_tys(ty) {
654 return sized;
655 }
656 } else {
657 break;
658 };
659 }
660 }
661
662 let Some(sized) = self.interner().lang_items().Sized else {
663 return false;
664 };
665 let sized_pred = Predicate::upcast_from(
666 TraitRef::new(self.interner(), sized.into(), [ty]),
667 self.interner(),
668 );
669 self.try_obligation(sized_pred).certain()
670 }
671}
672
673impl fmt::Debug for InferenceTable<'_> {
674 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
675 f.debug_struct("InferenceTable")
676 .field("name", &self.infer_ctxt.inner.borrow().type_variable_storage)
677 .field("fulfillment_cx", &self.fulfillment_cx)
678 .finish()
679 }
680}
681
682mod resolve_completely {
683 use rustc_type_ir::{DebruijnIndex, Flags, TypeFolder, TypeSuperFoldable};
684
685 use crate::{
686 infer::unify::InferenceTable,
687 next_solver::{
688 Const, DbInterner, Goal, Predicate, Region, Term, Ty,
689 infer::{resolve::ReplaceInferWithError, traits::ObligationCause},
690 normalize::deeply_normalize_with_skipped_universes_and_ambiguous_coroutine_goals,
691 },
692 };
693
694 pub(super) struct Resolver<'a, 'db> {
695 ctx: &'a mut InferenceTable<'db>,
696 should_normalize: bool,
698 nested_goals: &'a mut Vec<Goal<'db, Predicate<'db>>>,
699 }
700
701 impl<'a, 'db> Resolver<'a, 'db> {
702 pub(super) fn new(
703 ctx: &'a mut InferenceTable<'db>,
704 should_normalize: bool,
705 nested_goals: &'a mut Vec<Goal<'db, Predicate<'db>>>,
706 ) -> Resolver<'a, 'db> {
707 Resolver { ctx, nested_goals, should_normalize }
708 }
709
710 fn handle_term<T>(
711 &mut self,
712 value: T,
713 outer_exclusive_binder: impl FnOnce(T) -> DebruijnIndex,
714 ) -> T
715 where
716 T: Into<Term<'db>> + TypeSuperFoldable<DbInterner<'db>> + Copy,
717 {
718 let value = if self.should_normalize {
719 let cause = ObligationCause::new();
720 let at = self.ctx.at(&cause);
721 let universes = vec![None; outer_exclusive_binder(value).as_usize()];
722 match deeply_normalize_with_skipped_universes_and_ambiguous_coroutine_goals(
723 at, value, universes,
724 ) {
725 Ok((value, goals)) => {
726 self.nested_goals.extend(goals);
727 value
728 }
729 Err(_errors) => {
730 value
732 }
733 }
734 } else {
735 value
736 };
737
738 value.fold_with(&mut ReplaceInferWithError::new(self.ctx.interner()))
739 }
740 }
741
742 impl<'cx, 'db> TypeFolder<DbInterner<'db>> for Resolver<'cx, 'db> {
743 fn cx(&self) -> DbInterner<'db> {
744 self.ctx.interner()
745 }
746
747 fn fold_region(&mut self, r: Region<'db>) -> Region<'db> {
748 if r.is_var() { Region::error(self.ctx.interner()) } else { r }
749 }
750
751 fn fold_ty(&mut self, ty: Ty<'db>) -> Ty<'db> {
752 self.handle_term(ty, |it| it.outer_exclusive_binder())
753 }
754
755 fn fold_const(&mut self, ct: Const<'db>) -> Const<'db> {
756 self.handle_term(ct, |it| it.outer_exclusive_binder())
757 }
758
759 fn fold_predicate(&mut self, predicate: Predicate<'db>) -> Predicate<'db> {
760 assert!(
761 !self.should_normalize,
762 "normalizing predicates in writeback is not generally sound"
763 );
764 predicate.super_fold_with(self)
765 }
766 }
767}