hir_ty/next_solver/
fold.rs

1//! Fold impls for the next-trait-solver.
2
3use rustc_type_ir::{
4    BoundVar, DebruijnIndex, RegionKind, TypeFoldable, TypeFolder, TypeSuperFoldable,
5    TypeVisitableExt,
6    inherent::{IntoKind, Region as _},
7};
8
9use crate::next_solver::BoundConst;
10
11use super::{
12    Binder, BoundRegion, BoundTy, Const, ConstKind, DbInterner, Predicate, Region, Ty, TyKind,
13};
14
15/// A delegate used when instantiating bound vars.
16///
17/// Any implementation must make sure that each bound variable always
18/// gets mapped to the same result. `BoundVarReplacer` caches by using
19/// a `DelayedMap` which does not cache the first few types it encounters.
20pub trait BoundVarReplacerDelegate<'db> {
21    fn replace_region(&mut self, br: BoundRegion) -> Region<'db>;
22    fn replace_ty(&mut self, bt: BoundTy) -> Ty<'db>;
23    fn replace_const(&mut self, bv: BoundConst) -> Const<'db>;
24}
25
26/// A simple delegate taking 3 mutable functions. The used functions must
27/// always return the same result for each bound variable, no matter how
28/// frequently they are called.
29pub struct FnMutDelegate<'db, 'a> {
30    pub regions: &'a mut (dyn FnMut(BoundRegion) -> Region<'db> + 'a),
31    pub types: &'a mut (dyn FnMut(BoundTy) -> Ty<'db> + 'a),
32    pub consts: &'a mut (dyn FnMut(BoundConst) -> Const<'db> + 'a),
33}
34
35impl<'db, 'a> BoundVarReplacerDelegate<'db> for FnMutDelegate<'db, 'a> {
36    fn replace_region(&mut self, br: BoundRegion) -> Region<'db> {
37        (self.regions)(br)
38    }
39    fn replace_ty(&mut self, bt: BoundTy) -> Ty<'db> {
40        (self.types)(bt)
41    }
42    fn replace_const(&mut self, bv: BoundConst) -> Const<'db> {
43        (self.consts)(bv)
44    }
45}
46
47/// Replaces the escaping bound vars (late bound regions or bound types) in a type.
48pub(crate) struct BoundVarReplacer<'db, D> {
49    interner: DbInterner<'db>,
50    /// As with `RegionFolder`, represents the index of a binder *just outside*
51    /// the ones we have visited.
52    current_index: DebruijnIndex,
53
54    delegate: D,
55}
56
57impl<'db, D: BoundVarReplacerDelegate<'db>> BoundVarReplacer<'db, D> {
58    pub fn new(tcx: DbInterner<'db>, delegate: D) -> Self {
59        BoundVarReplacer { interner: tcx, current_index: DebruijnIndex::ZERO, delegate }
60    }
61}
62
63impl<'db, D> TypeFolder<DbInterner<'db>> for BoundVarReplacer<'db, D>
64where
65    D: BoundVarReplacerDelegate<'db>,
66{
67    fn cx(&self) -> DbInterner<'db> {
68        self.interner
69    }
70
71    fn fold_binder<T: TypeFoldable<DbInterner<'db>>>(
72        &mut self,
73        t: Binder<'db, T>,
74    ) -> Binder<'db, T> {
75        self.current_index.shift_in(1);
76        let t = t.super_fold_with(self);
77        self.current_index.shift_out(1);
78        t
79    }
80
81    fn fold_ty(&mut self, t: Ty<'db>) -> Ty<'db> {
82        match t.kind() {
83            TyKind::Bound(debruijn, bound_ty) if debruijn == self.current_index => {
84                let ty = self.delegate.replace_ty(bound_ty);
85                debug_assert!(!ty.has_vars_bound_above(DebruijnIndex::ZERO));
86                rustc_type_ir::shift_vars(self.interner, ty, self.current_index.as_u32())
87            }
88            _ => {
89                if !t.has_vars_bound_at_or_above(self.current_index) {
90                    t
91                } else {
92                    t.super_fold_with(self)
93                }
94            }
95        }
96    }
97
98    fn fold_region(&mut self, r: Region<'db>) -> Region<'db> {
99        match r.kind() {
100            RegionKind::ReBound(debruijn, br) if debruijn == self.current_index => {
101                let region = self.delegate.replace_region(br);
102                if let RegionKind::ReBound(debruijn1, br) = region.kind() {
103                    // If the callback returns a bound region,
104                    // that region should always use the INNERMOST
105                    // debruijn index. Then we adjust it to the
106                    // correct depth.
107                    assert_eq!(debruijn1, DebruijnIndex::ZERO);
108                    Region::new_bound(self.interner, debruijn, br)
109                } else {
110                    region
111                }
112            }
113            _ => r,
114        }
115    }
116
117    fn fold_const(&mut self, ct: Const<'db>) -> Const<'db> {
118        match ct.kind() {
119            ConstKind::Bound(debruijn, bound_const) if debruijn == self.current_index => {
120                let ct = self.delegate.replace_const(bound_const);
121                debug_assert!(!ct.has_vars_bound_above(DebruijnIndex::ZERO));
122                rustc_type_ir::shift_vars(self.interner, ct, self.current_index.as_u32())
123            }
124            _ => ct.super_fold_with(self),
125        }
126    }
127
128    fn fold_predicate(&mut self, p: Predicate<'db>) -> Predicate<'db> {
129        if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
130    }
131}