hir_ty/next_solver/
fold.rs

1//! Fold impls for the next-trait-solver.
2
3use rustc_type_ir::{
4    BoundVarIndexKind, DebruijnIndex, RegionKind, TypeFoldable, TypeFolder, TypeSuperFoldable,
5    TypeVisitableExt, inherent::IntoKind,
6};
7
8use crate::next_solver::{BoundConst, FxIndexMap};
9
10use super::{
11    Binder, BoundRegion, BoundTy, Const, ConstKind, DbInterner, Predicate, Region, Ty, TyKind,
12};
13
14/// A delegate used when instantiating bound vars.
15///
16/// Any implementation must make sure that each bound variable always
17/// gets mapped to the same result. `BoundVarReplacer` caches by using
18/// a `DelayedMap` which does not cache the first few types it encounters.
19pub trait BoundVarReplacerDelegate<'db> {
20    fn replace_region(&mut self, br: BoundRegion) -> Region<'db>;
21    fn replace_ty(&mut self, bt: BoundTy) -> Ty<'db>;
22    fn replace_const(&mut self, bv: BoundConst) -> Const<'db>;
23}
24
25/// A simple delegate taking 3 mutable functions. The used functions must
26/// always return the same result for each bound variable, no matter how
27/// frequently they are called.
28pub struct FnMutDelegate<'db, 'a> {
29    pub regions: &'a mut (dyn FnMut(BoundRegion) -> Region<'db> + 'a),
30    pub types: &'a mut (dyn FnMut(BoundTy) -> Ty<'db> + 'a),
31    pub consts: &'a mut (dyn FnMut(BoundConst) -> Const<'db> + 'a),
32}
33
34impl<'db, 'a> BoundVarReplacerDelegate<'db> for FnMutDelegate<'db, 'a> {
35    fn replace_region(&mut self, br: BoundRegion) -> Region<'db> {
36        (self.regions)(br)
37    }
38    fn replace_ty(&mut self, bt: BoundTy) -> Ty<'db> {
39        (self.types)(bt)
40    }
41    fn replace_const(&mut self, bv: BoundConst) -> Const<'db> {
42        (self.consts)(bv)
43    }
44}
45
46/// Replaces the escaping bound vars (late bound regions or bound types) in a type.
47pub(crate) struct BoundVarReplacer<'db, D> {
48    interner: DbInterner<'db>,
49    /// As with `RegionFolder`, represents the index of a binder *just outside*
50    /// the ones we have visited.
51    current_index: DebruijnIndex,
52
53    delegate: D,
54}
55
56impl<'db, D: BoundVarReplacerDelegate<'db>> BoundVarReplacer<'db, D> {
57    pub(crate) fn new(tcx: DbInterner<'db>, delegate: D) -> Self {
58        BoundVarReplacer { interner: tcx, current_index: DebruijnIndex::ZERO, delegate }
59    }
60}
61
62impl<'db, D> TypeFolder<DbInterner<'db>> for BoundVarReplacer<'db, D>
63where
64    D: BoundVarReplacerDelegate<'db>,
65{
66    fn cx(&self) -> DbInterner<'db> {
67        self.interner
68    }
69
70    fn fold_binder<T: TypeFoldable<DbInterner<'db>>>(
71        &mut self,
72        t: Binder<'db, T>,
73    ) -> Binder<'db, T> {
74        self.current_index.shift_in(1);
75        let t = t.super_fold_with(self);
76        self.current_index.shift_out(1);
77        t
78    }
79
80    fn fold_ty(&mut self, t: Ty<'db>) -> Ty<'db> {
81        match t.kind() {
82            TyKind::Bound(BoundVarIndexKind::Bound(debruijn), bound_ty)
83                if debruijn == self.current_index =>
84            {
85                let ty = self.delegate.replace_ty(bound_ty);
86                debug_assert!(!ty.has_vars_bound_above(DebruijnIndex::ZERO));
87                rustc_type_ir::shift_vars(self.interner, ty, self.current_index.as_u32())
88            }
89            _ => {
90                if !t.has_vars_bound_at_or_above(self.current_index) {
91                    t
92                } else {
93                    t.super_fold_with(self)
94                }
95            }
96        }
97    }
98
99    fn fold_region(&mut self, r: Region<'db>) -> Region<'db> {
100        match r.kind() {
101            RegionKind::ReBound(BoundVarIndexKind::Bound(debruijn), br)
102                if debruijn == self.current_index =>
103            {
104                let region = self.delegate.replace_region(br);
105                if let RegionKind::ReBound(BoundVarIndexKind::Bound(debruijn1), br) = region.kind()
106                {
107                    // If the callback returns a bound region,
108                    // that region should always use the INNERMOST
109                    // debruijn index. Then we adjust it to the
110                    // correct depth.
111                    assert_eq!(debruijn1, DebruijnIndex::ZERO);
112                    Region::new_bound(self.interner, debruijn, br)
113                } else {
114                    region
115                }
116            }
117            _ => r,
118        }
119    }
120
121    fn fold_const(&mut self, ct: Const<'db>) -> Const<'db> {
122        match ct.kind() {
123            ConstKind::Bound(BoundVarIndexKind::Bound(debruijn), bound_const)
124                if debruijn == self.current_index =>
125            {
126                let ct = self.delegate.replace_const(bound_const);
127                debug_assert!(!ct.has_vars_bound_above(DebruijnIndex::ZERO));
128                rustc_type_ir::shift_vars(self.interner, ct, self.current_index.as_u32())
129            }
130            _ => ct.super_fold_with(self),
131        }
132    }
133
134    fn fold_predicate(&mut self, p: Predicate<'db>) -> Predicate<'db> {
135        if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
136    }
137}
138
139pub fn fold_tys<'db, T: TypeFoldable<DbInterner<'db>>>(
140    interner: DbInterner<'db>,
141    t: T,
142    callback: impl FnMut(Ty<'db>) -> Ty<'db>,
143) -> T {
144    struct Folder<'db, F> {
145        interner: DbInterner<'db>,
146        callback: F,
147    }
148    impl<'db, F: FnMut(Ty<'db>) -> Ty<'db>> TypeFolder<DbInterner<'db>> for Folder<'db, F> {
149        fn cx(&self) -> DbInterner<'db> {
150            self.interner
151        }
152
153        fn fold_ty(&mut self, t: Ty<'db>) -> Ty<'db> {
154            let t = t.super_fold_with(self);
155            (self.callback)(t)
156        }
157    }
158
159    t.fold_with(&mut Folder { interner, callback })
160}
161
162impl<'db> DbInterner<'db> {
163    /// Replaces all regions bound by the given `Binder` with the
164    /// results returned by the closure; the closure is expected to
165    /// return a free region (relative to this binder), and hence the
166    /// binder is removed in the return type. The closure is invoked
167    /// once for each unique `BoundRegionKind`; multiple references to the
168    /// same `BoundRegionKind` will reuse the previous result. A map is
169    /// returned at the end with each bound region and the free region
170    /// that replaced it.
171    ///
172    /// # Panics
173    ///
174    /// This method only replaces late bound regions. Any types or
175    /// constants bound by `value` will cause an ICE.
176    pub fn instantiate_bound_regions<T, F>(
177        self,
178        value: Binder<'db, T>,
179        mut fld_r: F,
180    ) -> (T, FxIndexMap<BoundRegion, Region<'db>>)
181    where
182        F: FnMut(BoundRegion) -> Region<'db>,
183        T: TypeFoldable<DbInterner<'db>>,
184    {
185        let mut region_map = FxIndexMap::default();
186        let real_fld_r = |br: BoundRegion| *region_map.entry(br).or_insert_with(|| fld_r(br));
187        let value = self.instantiate_bound_regions_uncached(value, real_fld_r);
188        (value, region_map)
189    }
190
191    pub fn instantiate_bound_regions_uncached<T, F>(
192        self,
193        value: Binder<'db, T>,
194        mut replace_regions: F,
195    ) -> T
196    where
197        F: FnMut(BoundRegion) -> Region<'db>,
198        T: TypeFoldable<DbInterner<'db>>,
199    {
200        let value = value.skip_binder();
201        if !value.has_escaping_bound_vars() {
202            value
203        } else {
204            let delegate = FnMutDelegate {
205                regions: &mut replace_regions,
206                types: &mut |b| panic!("unexpected bound ty in binder: {b:?}"),
207                consts: &mut |b| panic!("unexpected bound ct in binder: {b:?}"),
208            };
209            let mut replacer = BoundVarReplacer::new(self, delegate);
210            value.fold_with(&mut replacer)
211        }
212    }
213
214    /// Replaces any late-bound regions bound in `value` with `'erased`. Useful in codegen but also
215    /// method lookup and a few other places where precise region relationships are not required.
216    pub fn instantiate_bound_regions_with_erased<T>(self, value: Binder<'db, T>) -> T
217    where
218        T: TypeFoldable<DbInterner<'db>>,
219    {
220        self.instantiate_bound_regions(value, |_| Region::new_erased(self)).0
221    }
222}