// hir_ty/next_solver/fulfill.rs

1//! Fulfill loop for next-solver.
2
3mod errors;
4
5use std::ops::ControlFlow;
6
7use rustc_hash::FxHashSet;
8use rustc_next_trait_solver::{
9    delegate::SolverDelegate,
10    solve::{GoalEvaluation, GoalStalledOn, HasChanged, SolverDelegateEvalExt},
11};
12use rustc_type_ir::{
13    Interner, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor,
14    inherent::{IntoKind, Span as _},
15    solve::{Certainty, NoSolution},
16};
17
18use crate::next_solver::{
19    DbInterner, SolverContext, SolverDefId, Span, Ty, TyKind, TypingMode,
20    infer::{
21        InferCtxt,
22        traits::{PredicateObligation, PredicateObligations},
23    },
24    inspect::ProofTreeVisitor,
25};
26
27type PendingObligations<'db> =
28    Vec<(PredicateObligation<'db>, Option<GoalStalledOn<DbInterner<'db>>>)>;
29
30/// A trait engine using the new trait solver.
31///
32/// This is mostly identical to how `evaluate_all` works inside of the
33/// solver, except that the requirements are slightly different.
34///
35/// Unlike `evaluate_all` it is possible to add new obligations later on
36/// and we also have to track diagnostics information by using `Obligation`
37/// instead of `Goal`.
38///
39/// It is also likely that we want to use slightly different datastructures
40/// here as this will have to deal with far more root goals than `evaluate_all`.
41#[derive(Debug, Clone)]
42pub struct FulfillmentCtxt<'db> {
43    obligations: ObligationStorage<'db>,
44
45    /// The snapshot in which this context was created. Using the context
46    /// outside of this snapshot leads to subtle bugs if the snapshot
47    /// gets rolled back. Because of this we explicitly check that we only
48    /// use the context in exactly this snapshot.
49    #[expect(unused)]
50    usable_in_snapshot: usize,
51}
52
53#[derive(Default, Debug, Clone)]
54struct ObligationStorage<'db> {
55    /// Obligations which resulted in an overflow in fulfillment itself.
56    ///
57    /// We cannot eagerly return these as error so we instead store them here
58    /// to avoid recomputing them each time `try_evaluate_obligations` is called.
59    /// This also allows us to return the correct `FulfillmentError` for them.
60    overflowed: Vec<PredicateObligation<'db>>,
61    pending: PendingObligations<'db>,
62}
63
64impl<'db> ObligationStorage<'db> {
65    fn register(
66        &mut self,
67        obligation: PredicateObligation<'db>,
68        stalled_on: Option<GoalStalledOn<DbInterner<'db>>>,
69    ) {
70        self.pending.push((obligation, stalled_on));
71    }
72
73    fn clone_pending(&self) -> PredicateObligations<'db> {
74        let mut obligations: PredicateObligations<'db> =
75            self.pending.iter().map(|(o, _)| o.clone()).collect();
76        obligations.extend(self.overflowed.iter().cloned());
77        obligations
78    }
79
80    fn drain_pending<'this, 'cond>(
81        &'this mut self,
82        cond: impl 'cond + Fn(&PredicateObligation<'db>) -> bool,
83    ) -> impl Iterator<Item = (PredicateObligation<'db>, Option<GoalStalledOn<DbInterner<'db>>>)>
84    {
85        self.pending.extract_if(.., move |(o, _)| cond(o))
86    }
87
88    fn on_fulfillment_overflow(&mut self, infcx: &InferCtxt<'db>) {
89        infcx.probe(|_| {
90            // IMPORTANT: we must not use solve any inference variables in the obligations
91            // as this is all happening inside of a probe. We use a probe to make sure
92            // we get all obligations involved in the overflow. We pretty much check: if
93            // we were to do another step of `try_evaluate_obligations`, which goals would
94            // change.
95            // FIXME: <https://github.com/Gankra/thin-vec/pull/66> is merged, this can be removed.
96            self.overflowed.extend(
97                self.pending
98                    .extract_if(.., |(o, stalled_on)| {
99                        let goal = o.as_goal();
100                        let result = <&SolverContext<'db>>::from(infcx).evaluate_root_goal(
101                            goal,
102                            Span::dummy(),
103                            stalled_on.take(),
104                        );
105                        matches!(result, Ok(GoalEvaluation { has_changed: HasChanged::Yes, .. }))
106                    })
107                    .map(|(o, _)| o),
108            );
109        })
110    }
111}
112
113impl<'db> FulfillmentCtxt<'db> {
114    pub fn new(infcx: &InferCtxt<'db>) -> FulfillmentCtxt<'db> {
115        FulfillmentCtxt {
116            obligations: Default::default(),
117            usable_in_snapshot: infcx.num_open_snapshots(),
118        }
119    }
120}
121
122impl<'db> FulfillmentCtxt<'db> {
123    #[tracing::instrument(level = "trace", skip(self, _infcx))]
124    pub(crate) fn register_predicate_obligation(
125        &mut self,
126        _infcx: &InferCtxt<'db>,
127        obligation: PredicateObligation<'db>,
128    ) {
129        // FIXME: See the comment in `try_evaluate_obligations()`.
130        // assert_eq!(self.usable_in_snapshot, infcx.num_open_snapshots());
131        self.obligations.register(obligation, None);
132    }
133
134    pub(crate) fn register_predicate_obligations(
135        &mut self,
136        _infcx: &InferCtxt<'db>,
137        obligations: impl IntoIterator<Item = PredicateObligation<'db>>,
138    ) {
139        // FIXME: See the comment in `try_evaluate_obligations()`.
140        // assert_eq!(self.usable_in_snapshot, infcx.num_open_snapshots());
141        obligations.into_iter().for_each(|obligation| self.obligations.register(obligation, None));
142    }
143
144    pub(crate) fn collect_remaining_errors(
145        &mut self,
146        _infcx: &InferCtxt<'db>,
147    ) -> Vec<NextSolverError<'db>> {
148        self.obligations
149            .pending
150            .drain(..)
151            .map(|(obligation, _)| NextSolverError::Ambiguity(obligation))
152            .chain(self.obligations.overflowed.drain(..).map(NextSolverError::Overflow))
153            .collect()
154    }
155
156    pub(crate) fn try_evaluate_obligations(
157        &mut self,
158        infcx: &InferCtxt<'db>,
159    ) -> Vec<NextSolverError<'db>> {
160        // FIXME(next-solver): We should bring this assertion back. Currently it panics because
161        // there are places which use `InferenceTable` and open a snapshot and register obligations
162        // and select. They should use a different `ObligationCtxt` instead. Then we'll be also able
163        // to not put the obligations queue in `InferenceTable`'s snapshots.
164        // assert_eq!(self.usable_in_snapshot, infcx.num_open_snapshots());
165        let mut errors = Vec::new();
166        let mut obligations = Vec::new();
167        loop {
168            let mut any_changed = false;
169            obligations.extend(self.obligations.drain_pending(|_| true));
170            for (mut obligation, stalled_on) in obligations.drain(..) {
171                if obligation.recursion_depth >= infcx.interner.recursion_limit() {
172                    self.obligations.on_fulfillment_overflow(infcx);
173                    // Only return true errors that we have accumulated while processing.
174                    return errors;
175                }
176
177                let goal = obligation.as_goal();
178                let delegate = <&SolverContext<'db>>::from(infcx);
179                if let Some(certainty) = delegate.compute_goal_fast_path(goal, Span::dummy()) {
180                    match certainty {
181                        Certainty::Yes => {}
182                        Certainty::Maybe { .. } => {
183                            self.obligations.register(obligation, None);
184                        }
185                    }
186                    continue;
187                }
188
189                let result = delegate.evaluate_root_goal(goal, Span::dummy(), stalled_on);
190                let GoalEvaluation { goal: _, certainty, has_changed, stalled_on } = match result {
191                    Ok(result) => result,
192                    Err(NoSolution) => {
193                        errors.push(NextSolverError::TrueError(obligation));
194                        continue;
195                    }
196                };
197
198                if has_changed == HasChanged::Yes {
199                    // We increment the recursion depth here to track the number of times
200                    // this goal has resulted in inference progress. This doesn't precisely
201                    // model the way that we track recursion depth in the old solver due
202                    // to the fact that we only process root obligations, but it is a good
203                    // approximation and should only result in fulfillment overflow in
204                    // pathological cases.
205                    obligation.recursion_depth += 1;
206                    any_changed = true;
207                }
208
209                match certainty {
210                    Certainty::Yes => {}
211                    Certainty::Maybe { .. } => self.obligations.register(obligation, stalled_on),
212                }
213            }
214
215            if !any_changed {
216                break;
217            }
218        }
219
220        errors
221    }
222
223    pub(crate) fn evaluate_obligations_error_on_ambiguity(
224        &mut self,
225        infcx: &InferCtxt<'db>,
226    ) -> Vec<NextSolverError<'db>> {
227        let errors = self.try_evaluate_obligations(infcx);
228        if !errors.is_empty() {
229            return errors;
230        }
231
232        self.collect_remaining_errors(infcx)
233    }
234
235    pub(crate) fn pending_obligations(&self) -> PredicateObligations<'db> {
236        self.obligations.clone_pending()
237    }
238
239    pub(crate) fn drain_stalled_obligations_for_coroutines(
240        &mut self,
241        infcx: &InferCtxt<'db>,
242    ) -> PredicateObligations<'db> {
243        let stalled_coroutines = match infcx.typing_mode() {
244            TypingMode::Analysis { defining_opaque_types_and_generators } => {
245                defining_opaque_types_and_generators
246            }
247            TypingMode::Coherence
248            | TypingMode::Borrowck { defining_opaque_types: _ }
249            | TypingMode::PostBorrowckAnalysis { defined_opaque_types: _ }
250            | TypingMode::PostAnalysis => return Default::default(),
251        };
252        let stalled_coroutines = stalled_coroutines.inner();
253
254        if stalled_coroutines.is_empty() {
255            return Default::default();
256        }
257
258        self.obligations
259            .drain_pending(|obl| {
260                infcx.probe(|_| {
261                    infcx
262                        .visit_proof_tree(
263                            obl.as_goal(),
264                            &mut StalledOnCoroutines {
265                                stalled_coroutines,
266                                cache: Default::default(),
267                            },
268                        )
269                        .is_break()
270                })
271            })
272            .map(|(o, _)| o)
273            .collect()
274    }
275}
276
277/// Detect if a goal is stalled on a coroutine that is owned by the current typeck root.
278///
279/// This function can (erroneously) fail to detect a predicate, i.e. it doesn't need to
280/// be complete. However, this will lead to ambiguity errors, so we want to make it
281/// accurate.
282///
283/// This function can be also return false positives, which will lead to poor diagnostics
284/// so we want to keep this visitor *precise* too.
285pub struct StalledOnCoroutines<'a, 'db> {
286    pub stalled_coroutines: &'a [SolverDefId],
287    pub cache: FxHashSet<Ty<'db>>,
288}
289
290impl<'db> ProofTreeVisitor<'db> for StalledOnCoroutines<'_, 'db> {
291    type Result = ControlFlow<()>;
292
293    fn visit_goal(&mut self, inspect_goal: &super::inspect::InspectGoal<'_, 'db>) -> Self::Result {
294        inspect_goal.goal().predicate.visit_with(self)?;
295
296        if let Some(candidate) = inspect_goal.unique_applicable_candidate() {
297            candidate.visit_nested_no_probe(self)
298        } else {
299            ControlFlow::Continue(())
300        }
301    }
302}
303
304impl<'db> TypeVisitor<DbInterner<'db>> for StalledOnCoroutines<'_, 'db> {
305    type Result = ControlFlow<()>;
306
307    fn visit_ty(&mut self, ty: Ty<'db>) -> Self::Result {
308        if !self.cache.insert(ty) {
309            return ControlFlow::Continue(());
310        }
311
312        if let TyKind::Coroutine(def_id, _) = ty.kind()
313            && self.stalled_coroutines.contains(&def_id.into())
314        {
315            ControlFlow::Break(())
316        } else if ty.has_coroutines() {
317            ty.super_visit_with(self)
318        } else {
319            ControlFlow::Continue(())
320        }
321    }
322}
323
324#[derive(Debug)]
325pub enum NextSolverError<'db> {
326    TrueError(PredicateObligation<'db>),
327    Ambiguity(PredicateObligation<'db>),
328    Overflow(PredicateObligation<'db>),
329}
330
331impl NextSolverError<'_> {
332    #[inline]
333    pub fn is_true_error(&self) -> bool {
334        matches!(self, NextSolverError::TrueError(_))
335    }
336}