hir_ty/next_solver/infer/snapshot/
fudge.rs1use std::ops::Range;
2
3use ena::{
4 snapshot_vec as sv,
5 unify::{self as ut, UnifyKey},
6};
7use rustc_type_ir::{
8 ConstVid, FloatVid, IntVid, RegionKind, RegionVid, TyVid, TypeFoldable, TypeFolder,
9 TypeSuperFoldable, TypeVisitableExt, inherent::IntoKind,
10};
11
12use crate::next_solver::{
13 Const, ConstKind, DbInterner, Region, Ty, TyKind,
14 infer::{
15 InferCtxt, UnificationTable, iter_idx_range,
16 snapshot::VariableLengths,
17 type_variable::TypeVariableOrigin,
18 unify_key::{ConstVariableOrigin, ConstVariableValue, ConstVidKey},
19 },
20};
21
22fn vars_since_snapshot<'db, T>(
23 table: &UnificationTable<'_, 'db, T>,
24 snapshot_var_len: usize,
25) -> Range<T>
26where
27 T: UnifyKey,
28 super::UndoLog<'db>: From<sv::UndoLog<ut::Delegate<T>>>,
29{
30 T::from_index(snapshot_var_len as u32)..T::from_index(table.len() as u32)
31}
32
33fn const_vars_since_snapshot<'db>(
34 table: &mut UnificationTable<'_, 'db, ConstVidKey<'db>>,
35 snapshot_var_len: usize,
36) -> (Range<ConstVid>, Vec<ConstVariableOrigin>) {
37 let range = vars_since_snapshot(table, snapshot_var_len);
38 let range = range.start.vid..range.end.vid;
39
40 (
41 range.clone(),
42 iter_idx_range(range)
43 .map(|index| match table.probe_value(index) {
44 ConstVariableValue::Known { value: _ } => ConstVariableOrigin {},
45 ConstVariableValue::Unknown { origin, universe: _ } => origin,
46 })
47 .collect(),
48 )
49}
50
51impl<'db> InferCtxt<'db> {
52 pub fn fudge_inference_if_ok<T, E, F>(&self, f: F) -> Result<T, E>
92 where
93 F: FnOnce() -> Result<T, E>,
94 T: TypeFoldable<DbInterner<'db>>,
95 {
96 let variable_lengths = self.variable_lengths();
97 let (snapshot_vars, value) = self.probe(|_| {
98 let value = f()?;
99 let snapshot_vars = SnapshotVarData::new(self, variable_lengths);
105 Ok((snapshot_vars, self.resolve_vars_if_possible(value)))
106 })?;
107
108 Ok(self.fudge_inference(snapshot_vars, value))
113 }
114
115 fn fudge_inference<T: TypeFoldable<DbInterner<'db>>>(
116 &self,
117 snapshot_vars: SnapshotVarData,
118 value: T,
119 ) -> T {
120 if snapshot_vars.is_empty() {
123 value
124 } else {
125 value.fold_with(&mut InferenceFudger { infcx: self, snapshot_vars })
126 }
127 }
128}
129
130struct SnapshotVarData {
131 region_vars: Range<RegionVid>,
132 type_vars: (Range<TyVid>, Vec<TypeVariableOrigin>),
133 int_vars: Range<IntVid>,
134 float_vars: Range<FloatVid>,
135 const_vars: (Range<ConstVid>, Vec<ConstVariableOrigin>),
136}
137
138impl SnapshotVarData {
139 fn new(infcx: &InferCtxt<'_>, vars_pre_snapshot: VariableLengths) -> SnapshotVarData {
140 let mut inner = infcx.inner.borrow_mut();
141 let region_vars = inner
142 .unwrap_region_constraints()
143 .vars_since_snapshot(vars_pre_snapshot.region_constraints_len);
144 let type_vars = inner.type_variables().vars_since_snapshot(vars_pre_snapshot.type_var_len);
145 let int_vars =
146 vars_since_snapshot(&inner.int_unification_table(), vars_pre_snapshot.int_var_len);
147 let float_vars =
148 vars_since_snapshot(&inner.float_unification_table(), vars_pre_snapshot.float_var_len);
149
150 let const_vars = const_vars_since_snapshot(
151 &mut inner.const_unification_table(),
152 vars_pre_snapshot.const_var_len,
153 );
154 SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars }
155 }
156
157 fn is_empty(&self) -> bool {
158 let SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars } = self;
159 region_vars.is_empty()
160 && type_vars.0.is_empty()
161 && int_vars.is_empty()
162 && float_vars.is_empty()
163 && const_vars.0.is_empty()
164 }
165}
166
167struct InferenceFudger<'a, 'db> {
168 infcx: &'a InferCtxt<'db>,
169 snapshot_vars: SnapshotVarData,
170}
171
172impl<'a, 'db> TypeFolder<DbInterner<'db>> for InferenceFudger<'a, 'db> {
173 fn cx(&self) -> DbInterner<'db> {
174 self.infcx.interner
175 }
176
177 fn fold_ty(&mut self, ty: Ty<'db>) -> Ty<'db> {
178 if let TyKind::Infer(infer_ty) = ty.kind() {
179 match infer_ty {
180 rustc_type_ir::TyVar(vid) => {
181 if self.snapshot_vars.type_vars.0.contains(&vid) {
182 let idx = vid.as_usize() - self.snapshot_vars.type_vars.0.start.as_usize();
185 let origin = self.snapshot_vars.type_vars.1[idx];
186 self.infcx.next_ty_var_with_origin(origin)
187 } else {
188 debug_assert!(
194 self.infcx.inner.borrow_mut().type_variables().probe(vid).is_unknown()
195 );
196 ty
197 }
198 }
199 rustc_type_ir::IntVar(vid) => {
200 if self.snapshot_vars.int_vars.contains(&vid) {
201 self.infcx.next_int_var()
202 } else {
203 ty
204 }
205 }
206 rustc_type_ir::FloatVar(vid) => {
207 if self.snapshot_vars.float_vars.contains(&vid) {
208 self.infcx.next_float_var()
209 } else {
210 ty
211 }
212 }
213 rustc_type_ir::FreshTy(_)
214 | rustc_type_ir::FreshIntTy(_)
215 | rustc_type_ir::FreshFloatTy(_) => {
216 unreachable!("unexpected fresh infcx var")
217 }
218 }
219 } else if ty.has_infer() {
220 ty.super_fold_with(self)
221 } else {
222 ty
223 }
224 }
225
226 fn fold_region(&mut self, r: Region<'db>) -> Region<'db> {
227 if let RegionKind::ReVar(vid) = r.kind() {
228 if self.snapshot_vars.region_vars.contains(&vid) {
229 self.infcx.next_region_var()
230 } else {
231 r
232 }
233 } else {
234 r
235 }
236 }
237
238 fn fold_const(&mut self, ct: Const<'db>) -> Const<'db> {
239 if let ConstKind::Infer(infer_ct) = ct.kind() {
240 match infer_ct {
241 rustc_type_ir::InferConst::Var(vid) => {
242 if self.snapshot_vars.const_vars.0.contains(&vid) {
243 let idx = vid.index() - self.snapshot_vars.const_vars.0.start.index();
244 let origin = self.snapshot_vars.const_vars.1[idx];
245 self.infcx.next_const_var_with_origin(origin)
246 } else {
247 ct
248 }
249 }
250 rustc_type_ir::InferConst::Fresh(_) => {
251 unreachable!("unexpected fresh infcx var")
252 }
253 }
254 } else if ct.has_infer() {
255 ct.super_fold_with(self)
256 } else {
257 ct
258 }
259 }
260}