1use rustc_type_ir::{
4 BoundVarIndexKind, DebruijnIndex, RegionKind, TypeFoldable, TypeFolder, TypeSuperFoldable,
5 TypeVisitableExt, inherent::IntoKind,
6};
7
8use crate::next_solver::{BoundConst, FxIndexMap};
9
10use super::{
11 Binder, BoundRegion, BoundTy, Const, ConstKind, DbInterner, Predicate, Region, Ty, TyKind,
12};
13
14pub trait BoundVarReplacerDelegate<'db> {
20 fn replace_region(&mut self, br: BoundRegion) -> Region<'db>;
21 fn replace_ty(&mut self, bt: BoundTy) -> Ty<'db>;
22 fn replace_const(&mut self, bv: BoundConst) -> Const<'db>;
23}
24
25pub struct FnMutDelegate<'db, 'a> {
29 pub regions: &'a mut (dyn FnMut(BoundRegion) -> Region<'db> + 'a),
30 pub types: &'a mut (dyn FnMut(BoundTy) -> Ty<'db> + 'a),
31 pub consts: &'a mut (dyn FnMut(BoundConst) -> Const<'db> + 'a),
32}
33
34impl<'db, 'a> BoundVarReplacerDelegate<'db> for FnMutDelegate<'db, 'a> {
35 fn replace_region(&mut self, br: BoundRegion) -> Region<'db> {
36 (self.regions)(br)
37 }
38 fn replace_ty(&mut self, bt: BoundTy) -> Ty<'db> {
39 (self.types)(bt)
40 }
41 fn replace_const(&mut self, bv: BoundConst) -> Const<'db> {
42 (self.consts)(bv)
43 }
44}
45
46pub(crate) struct BoundVarReplacer<'db, D> {
48 interner: DbInterner<'db>,
49 current_index: DebruijnIndex,
52
53 delegate: D,
54}
55
56impl<'db, D: BoundVarReplacerDelegate<'db>> BoundVarReplacer<'db, D> {
57 pub(crate) fn new(tcx: DbInterner<'db>, delegate: D) -> Self {
58 BoundVarReplacer { interner: tcx, current_index: DebruijnIndex::ZERO, delegate }
59 }
60}
61
62impl<'db, D> TypeFolder<DbInterner<'db>> for BoundVarReplacer<'db, D>
63where
64 D: BoundVarReplacerDelegate<'db>,
65{
66 fn cx(&self) -> DbInterner<'db> {
67 self.interner
68 }
69
70 fn fold_binder<T: TypeFoldable<DbInterner<'db>>>(
71 &mut self,
72 t: Binder<'db, T>,
73 ) -> Binder<'db, T> {
74 self.current_index.shift_in(1);
75 let t = t.super_fold_with(self);
76 self.current_index.shift_out(1);
77 t
78 }
79
80 fn fold_ty(&mut self, t: Ty<'db>) -> Ty<'db> {
81 match t.kind() {
82 TyKind::Bound(BoundVarIndexKind::Bound(debruijn), bound_ty)
83 if debruijn == self.current_index =>
84 {
85 let ty = self.delegate.replace_ty(bound_ty);
86 debug_assert!(!ty.has_vars_bound_above(DebruijnIndex::ZERO));
87 rustc_type_ir::shift_vars(self.interner, ty, self.current_index.as_u32())
88 }
89 _ => {
90 if !t.has_vars_bound_at_or_above(self.current_index) {
91 t
92 } else {
93 t.super_fold_with(self)
94 }
95 }
96 }
97 }
98
99 fn fold_region(&mut self, r: Region<'db>) -> Region<'db> {
100 match r.kind() {
101 RegionKind::ReBound(BoundVarIndexKind::Bound(debruijn), br)
102 if debruijn == self.current_index =>
103 {
104 let region = self.delegate.replace_region(br);
105 if let RegionKind::ReBound(BoundVarIndexKind::Bound(debruijn1), br) = region.kind()
106 {
107 assert_eq!(debruijn1, DebruijnIndex::ZERO);
112 Region::new_bound(self.interner, debruijn, br)
113 } else {
114 region
115 }
116 }
117 _ => r,
118 }
119 }
120
121 fn fold_const(&mut self, ct: Const<'db>) -> Const<'db> {
122 match ct.kind() {
123 ConstKind::Bound(BoundVarIndexKind::Bound(debruijn), bound_const)
124 if debruijn == self.current_index =>
125 {
126 let ct = self.delegate.replace_const(bound_const);
127 debug_assert!(!ct.has_vars_bound_above(DebruijnIndex::ZERO));
128 rustc_type_ir::shift_vars(self.interner, ct, self.current_index.as_u32())
129 }
130 _ => ct.super_fold_with(self),
131 }
132 }
133
134 fn fold_predicate(&mut self, p: Predicate<'db>) -> Predicate<'db> {
135 if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
136 }
137}
138
139pub fn fold_tys<'db, T: TypeFoldable<DbInterner<'db>>>(
140 interner: DbInterner<'db>,
141 t: T,
142 callback: impl FnMut(Ty<'db>) -> Ty<'db>,
143) -> T {
144 struct Folder<'db, F> {
145 interner: DbInterner<'db>,
146 callback: F,
147 }
148 impl<'db, F: FnMut(Ty<'db>) -> Ty<'db>> TypeFolder<DbInterner<'db>> for Folder<'db, F> {
149 fn cx(&self) -> DbInterner<'db> {
150 self.interner
151 }
152
153 fn fold_ty(&mut self, t: Ty<'db>) -> Ty<'db> {
154 let t = t.super_fold_with(self);
155 (self.callback)(t)
156 }
157 }
158
159 t.fold_with(&mut Folder { interner, callback })
160}
161
162impl<'db> DbInterner<'db> {
163 pub fn instantiate_bound_regions<T, F>(
177 self,
178 value: Binder<'db, T>,
179 mut fld_r: F,
180 ) -> (T, FxIndexMap<BoundRegion, Region<'db>>)
181 where
182 F: FnMut(BoundRegion) -> Region<'db>,
183 T: TypeFoldable<DbInterner<'db>>,
184 {
185 let mut region_map = FxIndexMap::default();
186 let real_fld_r = |br: BoundRegion| *region_map.entry(br).or_insert_with(|| fld_r(br));
187 let value = self.instantiate_bound_regions_uncached(value, real_fld_r);
188 (value, region_map)
189 }
190
191 pub fn instantiate_bound_regions_uncached<T, F>(
192 self,
193 value: Binder<'db, T>,
194 mut replace_regions: F,
195 ) -> T
196 where
197 F: FnMut(BoundRegion) -> Region<'db>,
198 T: TypeFoldable<DbInterner<'db>>,
199 {
200 let value = value.skip_binder();
201 if !value.has_escaping_bound_vars() {
202 value
203 } else {
204 let delegate = FnMutDelegate {
205 regions: &mut replace_regions,
206 types: &mut |b| panic!("unexpected bound ty in binder: {b:?}"),
207 consts: &mut |b| panic!("unexpected bound ct in binder: {b:?}"),
208 };
209 let mut replacer = BoundVarReplacer::new(self, delegate);
210 value.fold_with(&mut replacer)
211 }
212 }
213
214 pub fn instantiate_bound_regions_with_erased<T>(self, value: Binder<'db, T>) -> T
217 where
218 T: TypeFoldable<DbInterner<'db>>,
219 {
220 self.instantiate_bound_regions(value, |_| Region::new_erased(self)).0
221 }
222}