hir_ty/next_solver/
fold.rs1use rustc_type_ir::{
4 BoundVar, DebruijnIndex, RegionKind, TypeFoldable, TypeFolder, TypeSuperFoldable,
5 TypeVisitableExt,
6 inherent::{IntoKind, Region as _},
7};
8
9use crate::next_solver::BoundConst;
10
11use super::{
12 Binder, BoundRegion, BoundTy, Const, ConstKind, DbInterner, Predicate, Region, Ty, TyKind,
13};
14
15pub trait BoundVarReplacerDelegate<'db> {
21 fn replace_region(&mut self, br: BoundRegion) -> Region<'db>;
22 fn replace_ty(&mut self, bt: BoundTy) -> Ty<'db>;
23 fn replace_const(&mut self, bv: BoundConst) -> Const<'db>;
24}
25
26pub struct FnMutDelegate<'db, 'a> {
30 pub regions: &'a mut (dyn FnMut(BoundRegion) -> Region<'db> + 'a),
31 pub types: &'a mut (dyn FnMut(BoundTy) -> Ty<'db> + 'a),
32 pub consts: &'a mut (dyn FnMut(BoundConst) -> Const<'db> + 'a),
33}
34
35impl<'db, 'a> BoundVarReplacerDelegate<'db> for FnMutDelegate<'db, 'a> {
36 fn replace_region(&mut self, br: BoundRegion) -> Region<'db> {
37 (self.regions)(br)
38 }
39 fn replace_ty(&mut self, bt: BoundTy) -> Ty<'db> {
40 (self.types)(bt)
41 }
42 fn replace_const(&mut self, bv: BoundConst) -> Const<'db> {
43 (self.consts)(bv)
44 }
45}
46
47pub(crate) struct BoundVarReplacer<'db, D> {
49 interner: DbInterner<'db>,
50 current_index: DebruijnIndex,
53
54 delegate: D,
55}
56
57impl<'db, D: BoundVarReplacerDelegate<'db>> BoundVarReplacer<'db, D> {
58 pub fn new(tcx: DbInterner<'db>, delegate: D) -> Self {
59 BoundVarReplacer { interner: tcx, current_index: DebruijnIndex::ZERO, delegate }
60 }
61}
62
63impl<'db, D> TypeFolder<DbInterner<'db>> for BoundVarReplacer<'db, D>
64where
65 D: BoundVarReplacerDelegate<'db>,
66{
67 fn cx(&self) -> DbInterner<'db> {
68 self.interner
69 }
70
71 fn fold_binder<T: TypeFoldable<DbInterner<'db>>>(
72 &mut self,
73 t: Binder<'db, T>,
74 ) -> Binder<'db, T> {
75 self.current_index.shift_in(1);
76 let t = t.super_fold_with(self);
77 self.current_index.shift_out(1);
78 t
79 }
80
81 fn fold_ty(&mut self, t: Ty<'db>) -> Ty<'db> {
82 match t.kind() {
83 TyKind::Bound(debruijn, bound_ty) if debruijn == self.current_index => {
84 let ty = self.delegate.replace_ty(bound_ty);
85 debug_assert!(!ty.has_vars_bound_above(DebruijnIndex::ZERO));
86 rustc_type_ir::shift_vars(self.interner, ty, self.current_index.as_u32())
87 }
88 _ => {
89 if !t.has_vars_bound_at_or_above(self.current_index) {
90 t
91 } else {
92 t.super_fold_with(self)
93 }
94 }
95 }
96 }
97
98 fn fold_region(&mut self, r: Region<'db>) -> Region<'db> {
99 match r.kind() {
100 RegionKind::ReBound(debruijn, br) if debruijn == self.current_index => {
101 let region = self.delegate.replace_region(br);
102 if let RegionKind::ReBound(debruijn1, br) = region.kind() {
103 assert_eq!(debruijn1, DebruijnIndex::ZERO);
108 Region::new_bound(self.interner, debruijn, br)
109 } else {
110 region
111 }
112 }
113 _ => r,
114 }
115 }
116
117 fn fold_const(&mut self, ct: Const<'db>) -> Const<'db> {
118 match ct.kind() {
119 ConstKind::Bound(debruijn, bound_const) if debruijn == self.current_index => {
120 let ct = self.delegate.replace_const(bound_const);
121 debug_assert!(!ct.has_vars_bound_above(DebruijnIndex::ZERO));
122 rustc_type_ir::shift_vars(self.interner, ct, self.current_index.as_u32())
123 }
124 _ => ct.super_fold_with(self),
125 }
126 }
127
128 fn fold_predicate(&mut self, p: Predicate<'db>) -> Predicate<'db> {
129 if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
130 }
131}