hir_ty/next_solver/infer/snapshot/
undo_log.rs1use std::marker::PhantomData;
4
5use ena::snapshot_vec as sv;
6use ena::undo_log::{Rollback, UndoLogs};
7use ena::unify as ut;
8use rustc_type_ir::FloatVid;
9use rustc_type_ir::IntVid;
10use tracing::debug;
11
12use crate::next_solver::OpaqueTypeKey;
13use crate::next_solver::infer::opaque_types::OpaqueHiddenType;
14use crate::next_solver::infer::unify_key::ConstVidKey;
15use crate::next_solver::infer::unify_key::RegionVidKey;
16use crate::next_solver::infer::{InferCtxtInner, region_constraints, type_variable};
17use crate::traits;
18
/// Token for an open snapshot of the inference undo log.
///
/// Created by `InferCtxtUndoLogs::start_snapshot`, which increments the open-snapshot
/// count. It must be handed to `InferCtxtInner::rollback_to` or `InferCtxtInner::commit`,
/// which decrement that count again; dropping it instead leaves the count raised.
#[must_use = "an open snapshot must be passed to `rollback_to` or `commit` to close it"]
pub struct Snapshot {
    /// Length of the undo log when the snapshot was started. Rolling back pops entries
    /// until the log is this long again.
    pub(crate) undo_len: usize,
}

/// Records the information needed to undo one change to the inference state held by
/// `InferCtxtInner`.
///
/// Entries are appended to `InferCtxtUndoLogs` only while a snapshot is open, and are
/// consumed newest-first by `InferCtxtInner::rollback_to`.
#[derive(Clone)]
pub(crate) enum UndoLog<'db> {
    /// Undone by `opaque_type_storage.pop_duplicate_entry()`.
    DuplicateOpaqueType,
    /// Undone by passing the key and the stored `Option<OpaqueHiddenType>` to
    /// `opaque_type_storage.remove`.
    /// NOTE(review): the option is presumably the hidden type registered before the
    /// change (if any) — confirm against `OpaqueTypeStorage`.
    OpaqueTypes(OpaqueTypeKey<'db>, Option<OpaqueHiddenType<'db>>),
    /// Undone by `type_variable_storage.reverse`.
    TypeVariables(sv::UndoLog<ut::Delegate<type_variable::TyVidEqKey<'db>>>),
    /// Undone by `const_unification_storage.reverse`.
    ConstUnificationTable(sv::UndoLog<ut::Delegate<ConstVidKey<'db>>>),
    /// Undone by `int_unification_storage.reverse`.
    IntUnificationTable(sv::UndoLog<ut::Delegate<IntVid>>),
    /// Undone by `float_unification_storage.reverse`.
    FloatUnificationTable(sv::UndoLog<ut::Delegate<FloatVid>>),
    /// Undone by the region constraint storage's own `reverse`.
    RegionConstraintCollector(region_constraints::UndoLog<'db>),
    /// Undone by `reverse` on the unification table inside the region constraint storage.
    RegionUnificationTable(sv::UndoLog<ut::Delegate<RegionVidKey<'db>>>),
    /// Undone by popping the last entry of `region_obligations`.
    PushRegionObligation,
}

/// Generates one `impl From<$src> for UndoLog<'db>` per `Variant(SourceType),` entry.
///
/// Each impl wraps the (converted) source value in the `UndoLog` variant named by
/// `$variant`. Every entry, including the last, must be followed by a comma.
macro_rules! impl_from {
    ($($variant:ident ($src:ty),)*) => {
        $(
            impl<'db> From<$src> for UndoLog<'db> {
                fn from(entry: $src) -> Self {
                    Self::$variant(entry.into())
                }
            }
        )*
    };
}

// Upcast each per-table undo record into its matching `UndoLog` variant. These `From`
// impls satisfy the `UndoLog<'db>: From<T>` bound on the `UndoLogs<T>` impl, so each
// table can push its own undo type straight into `InferCtxtUndoLogs`.
// `DuplicateOpaqueType`, `OpaqueTypes` and `PushRegionObligation` get no `From` impl
// here, so they have to be built directly.
impl_from! {
    RegionConstraintCollector(region_constraints::UndoLog<'db>),

    TypeVariables(sv::UndoLog<ut::Delegate<type_variable::TyVidEqKey<'db>>>),
    IntUnificationTable(sv::UndoLog<ut::Delegate<IntVid>>),
    FloatUnificationTable(sv::UndoLog<ut::Delegate<FloatVid>>),

    ConstUnificationTable(sv::UndoLog<ut::Delegate<ConstVidKey<'db>>>),

    RegionUnificationTable(sv::UndoLog<ut::Delegate<RegionVidKey<'db>>>),
}

62impl<'db> Rollback<UndoLog<'db>> for InferCtxtInner<'db> {
64 fn reverse(&mut self, undo: UndoLog<'db>) {
65 match undo {
66 UndoLog::DuplicateOpaqueType => self.opaque_type_storage.pop_duplicate_entry(),
67 UndoLog::OpaqueTypes(key, idx) => self.opaque_type_storage.remove(key, idx),
68 UndoLog::TypeVariables(undo) => self.type_variable_storage.reverse(undo),
69 UndoLog::ConstUnificationTable(undo) => self.const_unification_storage.reverse(undo),
70 UndoLog::IntUnificationTable(undo) => self.int_unification_storage.reverse(undo),
71 UndoLog::FloatUnificationTable(undo) => self.float_unification_storage.reverse(undo),
72 UndoLog::RegionConstraintCollector(undo) => {
73 self.region_constraint_storage.as_mut().unwrap().reverse(undo)
74 }
75 UndoLog::RegionUnificationTable(undo) => {
76 self.region_constraint_storage.as_mut().unwrap().unification_table.reverse(undo)
77 }
78 UndoLog::PushRegionObligation => {
79 self.region_obligations.pop();
80 }
81 }
82 }
83}
84
85#[derive(Clone, Default)]
88pub(crate) struct InferCtxtUndoLogs<'db> {
89 logs: Vec<UndoLog<'db>>,
90 num_open_snapshots: usize,
91}
92
93impl<'db, T> UndoLogs<T> for InferCtxtUndoLogs<'db>
96where
97 UndoLog<'db>: From<T>,
98{
99 #[inline]
100 fn num_open_snapshots(&self) -> usize {
101 self.num_open_snapshots
102 }
103
104 #[inline]
105 fn push(&mut self, undo: T) {
106 if self.in_snapshot() {
107 self.logs.push(undo.into())
108 }
109 }
110
111 fn clear(&mut self) {
112 self.logs.clear();
113 self.num_open_snapshots = 0;
114 }
115
116 fn extend<J>(&mut self, undos: J)
117 where
118 Self: Sized,
119 J: IntoIterator<Item = T>,
120 {
121 if self.in_snapshot() {
122 self.logs.extend(undos.into_iter().map(UndoLog::from))
123 }
124 }
125}
126
127impl<'db> InferCtxtInner<'db> {
128 pub fn rollback_to(&mut self, snapshot: Snapshot) {
129 debug!("rollback_to({})", snapshot.undo_len);
130 self.undo_log.assert_open_snapshot(&snapshot);
131
132 while self.undo_log.logs.len() > snapshot.undo_len {
133 let undo = self.undo_log.logs.pop().unwrap();
134 self.reverse(undo);
135 }
136
137 self.type_variable_storage.finalize_rollback();
138
139 if self.undo_log.num_open_snapshots == 1 {
140 assert!(snapshot.undo_len == 0);
142 assert!(self.undo_log.logs.is_empty());
143 }
144
145 self.undo_log.num_open_snapshots -= 1;
146 }
147
148 pub fn commit(&mut self, snapshot: Snapshot) {
149 debug!("commit({})", snapshot.undo_len);
150
151 if self.undo_log.num_open_snapshots == 1 {
152 assert!(snapshot.undo_len == 0);
156 self.undo_log.logs.clear();
157 }
158
159 self.undo_log.num_open_snapshots -= 1;
160 }
161}
162
163impl<'db> InferCtxtUndoLogs<'db> {
164 pub(crate) fn start_snapshot(&mut self) -> Snapshot {
165 self.num_open_snapshots += 1;
166 Snapshot { undo_len: self.logs.len() }
167 }
168
169 pub(crate) fn region_constraints_in_snapshot(
170 &self,
171 s: &Snapshot,
172 ) -> impl Iterator<Item = &'_ region_constraints::UndoLog<'db>> + Clone {
173 self.logs[s.undo_len..].iter().filter_map(|log| match log {
174 UndoLog::RegionConstraintCollector(log) => Some(log),
175 _ => None,
176 })
177 }
178
179 fn assert_open_snapshot(&self, snapshot: &Snapshot) {
180 assert!(self.logs.len() >= snapshot.undo_len);
182 assert!(self.num_open_snapshots > 0);
183 }
184}
185
186impl<'db> std::ops::Index<usize> for InferCtxtUndoLogs<'db> {
187 type Output = UndoLog<'db>;
188
189 fn index(&self, key: usize) -> &Self::Output {
190 &self.logs[key]
191 }
192}
193
194impl<'db> std::ops::IndexMut<usize> for InferCtxtUndoLogs<'db> {
195 fn index_mut(&mut self, key: usize) -> &mut Self::Output {
196 &mut self.logs[key]
197 }
198}