1use hir_def::TraitId;
2use rustc_type_ir::{TypeFoldable, Upcast, Variance};
3
4use crate::next_solver::{
5 Const, DbInterner, ParamEnv, Term, TraitRef, Ty, TypeError,
6 fulfill::{FulfillmentCtxt, NextSolverError},
7 infer::{
8 InferCtxt, InferOk,
9 at::ToTrace,
10 traits::{Obligation, ObligationCause, PredicateObligation, PredicateObligations},
11 },
12};
13
14pub struct ObligationCtxt<'a, 'db> {
17 pub infcx: &'a InferCtxt<'db>,
18 engine: FulfillmentCtxt<'db>,
19}
20
21impl<'a, 'db> ObligationCtxt<'a, 'db> {
22 pub fn new(infcx: &'a InferCtxt<'db>) -> Self {
23 Self { infcx, engine: FulfillmentCtxt::new(infcx) }
24 }
25}
26
27impl<'a, 'db> ObligationCtxt<'a, 'db> {
28 pub fn register_obligation(&mut self, obligation: PredicateObligation<'db>) {
29 self.engine.register_predicate_obligation(self.infcx, obligation);
30 }
31
32 pub fn register_obligations(
33 &mut self,
34 obligations: impl IntoIterator<Item = PredicateObligation<'db>>,
35 ) {
36 self.engine.register_predicate_obligations(self.infcx, obligations);
37 }
38
39 pub fn register_infer_ok_obligations<T>(&mut self, infer_ok: InferOk<'db, T>) -> T {
40 let InferOk { value, obligations } = infer_ok;
41 self.register_obligations(obligations);
42 value
43 }
44
45 pub fn register_bound(
49 &mut self,
50 cause: ObligationCause,
51 param_env: ParamEnv<'db>,
52 ty: Ty<'db>,
53 def_id: TraitId,
54 ) {
55 let trait_ref = TraitRef::new(self.infcx.interner, def_id.into(), [ty]);
56 self.register_obligation(Obligation {
57 cause,
58 recursion_depth: 0,
59 param_env,
60 predicate: trait_ref.upcast(self.infcx.interner),
61 });
62 }
63
64 pub fn eq<T: ToTrace<'db>>(
65 &mut self,
66 cause: &ObligationCause,
67 param_env: ParamEnv<'db>,
68 expected: T,
69 actual: T,
70 ) -> Result<(), TypeError<'db>> {
71 self.infcx
72 .at(cause, param_env)
73 .eq(expected, actual)
74 .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
75 }
76
77 pub fn sub<T: ToTrace<'db>>(
79 &mut self,
80 cause: &ObligationCause,
81 param_env: ParamEnv<'db>,
82 expected: T,
83 actual: T,
84 ) -> Result<(), TypeError<'db>> {
85 self.infcx
86 .at(cause, param_env)
87 .sub(expected, actual)
88 .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
89 }
90
91 pub fn relate<T: ToTrace<'db>>(
92 &mut self,
93 cause: &ObligationCause,
94 param_env: ParamEnv<'db>,
95 variance: Variance,
96 expected: T,
97 actual: T,
98 ) -> Result<(), TypeError<'db>> {
99 self.infcx
100 .at(cause, param_env)
101 .relate(expected, variance, actual)
102 .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
103 }
104
105 pub fn sup<T: ToTrace<'db>>(
107 &mut self,
108 cause: &ObligationCause,
109 param_env: ParamEnv<'db>,
110 expected: T,
111 actual: T,
112 ) -> Result<(), TypeError<'db>> {
113 self.infcx
114 .at(cause, param_env)
115 .sup(expected, actual)
116 .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
117 }
118
119 pub fn lub<T: ToTrace<'db>>(
121 &mut self,
122 cause: &ObligationCause,
123 param_env: ParamEnv<'db>,
124 expected: T,
125 actual: T,
126 ) -> Result<T, TypeError<'db>> {
127 self.infcx
128 .at(cause, param_env)
129 .lub(expected, actual)
130 .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
131 }
132
133 #[must_use]
134 pub fn try_evaluate_obligations(&mut self) -> Vec<NextSolverError<'db>> {
135 self.engine.try_evaluate_obligations(self.infcx)
136 }
137
138 #[must_use]
139 pub fn evaluate_obligations_error_on_ambiguity(&mut self) -> Vec<NextSolverError<'db>> {
140 self.engine.evaluate_obligations_error_on_ambiguity(self.infcx)
141 }
142
143 #[must_use]
151 pub fn into_pending_obligations(self) -> PredicateObligations<'db> {
152 self.engine.pending_obligations()
153 }
154
155 pub fn deeply_normalize<T: TypeFoldable<DbInterner<'db>>>(
156 &self,
157 cause: &ObligationCause,
158 param_env: ParamEnv<'db>,
159 value: T,
160 ) -> Result<T, Vec<NextSolverError<'db>>> {
161 self.infcx.at(cause, param_env).deeply_normalize(value)
162 }
163
164 pub fn structurally_normalize_ty(
165 &mut self,
166 cause: &ObligationCause,
167 param_env: ParamEnv<'db>,
168 value: Ty<'db>,
169 ) -> Result<Ty<'db>, Vec<NextSolverError<'db>>> {
170 self.infcx.at(cause, param_env).structurally_normalize_ty(value, &mut self.engine)
171 }
172
173 pub fn structurally_normalize_const(
174 &mut self,
175 cause: &ObligationCause,
176 param_env: ParamEnv<'db>,
177 value: Const<'db>,
178 ) -> Result<Const<'db>, Vec<NextSolverError<'db>>> {
179 self.infcx.at(cause, param_env).structurally_normalize_const(value, &mut self.engine)
180 }
181
182 pub fn structurally_normalize_term(
183 &mut self,
184 cause: &ObligationCause,
185 param_env: ParamEnv<'db>,
186 value: Term<'db>,
187 ) -> Result<Term<'db>, Vec<NextSolverError<'db>>> {
188 self.infcx.at(cause, param_env).structurally_normalize_term(value, &mut self.engine)
189 }
190}