hir_ty/
traits.rs

//! Trait solving using the next trait solver.
2
3use std::hash::Hash;
4
5use base_db::Crate;
6use hir_def::{
7    AdtId, AssocItemId, HasModule, ImplId, Lookup, TraitId,
8    lang_item::LangItems,
9    nameres::DefMap,
10    signatures::{ConstFlags, EnumFlags, FnFlags, StructFlags, TraitFlags, TypeAliasFlags},
11};
12use hir_expand::name::Name;
13use intern::sym;
14use rustc_next_trait_solver::solve::{HasChanged, SolverDelegateEvalExt};
15use rustc_type_ir::{
16    TypingMode,
17    inherent::{AdtDef, BoundExistentialPredicates, IntoKind, Span as _},
18    solve::Certainty,
19};
20
21use crate::{
22    db::HirDatabase,
23    next_solver::{
24        Canonical, DbInterner, GenericArgs, Goal, ParamEnv, Predicate, SolverContext, Span, Ty,
25        TyKind,
26        infer::{DbInternerInferExt, InferCtxt, traits::ObligationCause},
27        obligation_ctxt::ObligationCtxt,
28    },
29};
30
31/// Type for `hir`, because commonly we want both param env and a crate in an exported API.
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
33pub struct ParamEnvAndCrate<'db> {
34    pub param_env: ParamEnv<'db>,
35    pub krate: Crate,
36}
37
38/// This should be used in `hir` only.
39pub fn structurally_normalize_ty<'db>(
40    infcx: &InferCtxt<'db>,
41    ty: Ty<'db>,
42    env: ParamEnv<'db>,
43) -> Ty<'db> {
44    let TyKind::Alias(..) = ty.kind() else { return ty };
45    let mut ocx = ObligationCtxt::new(infcx);
46    let ty = ocx.structurally_normalize_ty(&ObligationCause::dummy(), env, ty).unwrap_or(ty);
47    ty.replace_infer_with_error(infcx.interner)
48}
49
/// Classified outcome of evaluating a goal with the next trait solver.
#[derive(Clone, Debug, PartialEq)]
pub enum NextTraitSolveResult {
    /// The solver proved the goal (`Certainty::Yes`).
    Certain,
    /// The solver returned an ambiguous result (`Certainty::Maybe`).
    Uncertain,
    /// The solver found no solution for the goal.
    NoSolution,
}

impl NextTraitSolveResult {
    /// Returns `true` exactly for [`NextTraitSolveResult::NoSolution`].
    pub fn no_solution(&self) -> bool {
        *self == Self::NoSolution
    }

    /// Returns `true` exactly for [`NextTraitSolveResult::Certain`].
    pub fn certain(&self) -> bool {
        *self == Self::Certain
    }

    /// Returns `true` exactly for [`NextTraitSolveResult::Uncertain`].
    pub fn uncertain(&self) -> bool {
        *self == Self::Uncertain
    }
}
70
71pub fn next_trait_solve_canonical_in_ctxt<'db>(
72    infer_ctxt: &InferCtxt<'db>,
73    goal: Canonical<'db, Goal<'db, Predicate<'db>>>,
74) -> NextTraitSolveResult {
75    infer_ctxt.probe(|_| {
76        let context = <&SolverContext<'db>>::from(infer_ctxt);
77
78        tracing::info!(?goal);
79
80        let (goal, var_values) = context.instantiate_canonical(&goal);
81        tracing::info!(?var_values);
82
83        let res = context.evaluate_root_goal(goal, Span::dummy(), None);
84
85        let res = res.map(|r| (r.has_changed, r.certainty));
86
87        tracing::debug!("solve_nextsolver({:?}) => {:?}", goal, res);
88
89        match res {
90            Err(_) => NextTraitSolveResult::NoSolution,
91            Ok((_, Certainty::Yes)) => NextTraitSolveResult::Certain,
92            Ok((_, Certainty::Maybe { .. })) => NextTraitSolveResult::Uncertain,
93        }
94    })
95}
96
97/// Solve a trait goal using next trait solver.
98pub fn next_trait_solve_in_ctxt<'db, 'a>(
99    infer_ctxt: &'a InferCtxt<'db>,
100    goal: Goal<'db, Predicate<'db>>,
101) -> Result<(HasChanged, Certainty), rustc_type_ir::solve::NoSolution> {
102    tracing::info!(?goal);
103
104    let context = <&SolverContext<'db>>::from(infer_ctxt);
105
106    let res = context.evaluate_root_goal(goal, Span::dummy(), None);
107
108    let res = res.map(|r| (r.has_changed, r.certainty));
109
110    tracing::debug!("solve_nextsolver({:?}) => {:?}", goal, res);
111
112    res
113}
114
115#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, salsa::Update)]
116pub enum FnTrait {
117    // Warning: Order is important. If something implements `x` it should also implement
118    // `y` if `y <= x`.
119    FnOnce,
120    FnMut,
121    Fn,
122
123    AsyncFnOnce,
124    AsyncFnMut,
125    AsyncFn,
126}
127
128impl FnTrait {
129    pub fn method_name(self) -> Name {
130        match self {
131            FnTrait::FnOnce => Name::new_symbol_root(sym::call_once),
132            FnTrait::FnMut => Name::new_symbol_root(sym::call_mut),
133            FnTrait::Fn => Name::new_symbol_root(sym::call),
134            FnTrait::AsyncFnOnce => Name::new_symbol_root(sym::async_call_once),
135            FnTrait::AsyncFnMut => Name::new_symbol_root(sym::async_call_mut),
136            FnTrait::AsyncFn => Name::new_symbol_root(sym::async_call),
137        }
138    }
139
140    pub fn get_id(self, lang_items: &LangItems) -> Option<TraitId> {
141        match self {
142            FnTrait::FnOnce => lang_items.FnOnce,
143            FnTrait::FnMut => lang_items.FnMut,
144            FnTrait::Fn => lang_items.Fn,
145            FnTrait::AsyncFnOnce => lang_items.AsyncFnOnce,
146            FnTrait::AsyncFnMut => lang_items.AsyncFnMut,
147            FnTrait::AsyncFn => lang_items.AsyncFn,
148        }
149    }
150}
151
152/// This should not be used in `hir-ty`, only in `hir`.
153pub fn implements_trait_unique<'db>(
154    ty: Ty<'db>,
155    db: &'db dyn HirDatabase,
156    env: ParamEnvAndCrate<'db>,
157    trait_: TraitId,
158) -> bool {
159    implements_trait_unique_impl(db, env, trait_, &mut |infcx| {
160        infcx.fill_rest_fresh_args(trait_.into(), [ty.into()])
161    })
162}
163
164/// This should not be used in `hir-ty`, only in `hir`.
165pub fn implements_trait_unique_with_args<'db>(
166    db: &'db dyn HirDatabase,
167    env: ParamEnvAndCrate<'db>,
168    trait_: TraitId,
169    args: GenericArgs<'db>,
170) -> bool {
171    implements_trait_unique_impl(db, env, trait_, &mut |_| args)
172}
173
174fn implements_trait_unique_impl<'db>(
175    db: &'db dyn HirDatabase,
176    env: ParamEnvAndCrate<'db>,
177    trait_: TraitId,
178    create_args: &mut dyn FnMut(&InferCtxt<'db>) -> GenericArgs<'db>,
179) -> bool {
180    let interner = DbInterner::new_with(db, env.krate);
181    // FIXME(next-solver): I believe this should be `PostAnalysis`.
182    let infcx = interner.infer_ctxt().build(TypingMode::non_body_analysis());
183
184    let args = create_args(&infcx);
185    let trait_ref = rustc_type_ir::TraitRef::new_from_args(interner, trait_.into(), args);
186    let goal = Goal::new(interner, env.param_env, trait_ref);
187
188    let result = crate::traits::next_trait_solve_in_ctxt(&infcx, goal);
189    matches!(result, Ok((_, Certainty::Yes)))
190}
191
192pub fn is_inherent_impl_coherent(db: &dyn HirDatabase, def_map: &DefMap, impl_id: ImplId) -> bool {
193    let self_ty = db.impl_self_ty(impl_id).instantiate_identity();
194    let self_ty = self_ty.kind();
195    let impl_allowed = match self_ty {
196        TyKind::Tuple(_)
197        | TyKind::FnDef(_, _)
198        | TyKind::Array(_, _)
199        | TyKind::Never
200        | TyKind::RawPtr(_, _)
201        | TyKind::Ref(_, _, _)
202        | TyKind::Slice(_)
203        | TyKind::Str
204        | TyKind::Bool
205        | TyKind::Char
206        | TyKind::Int(_)
207        | TyKind::Uint(_)
208        | TyKind::Float(_) => def_map.is_rustc_coherence_is_core(),
209
210        TyKind::Adt(adt_def, _) => adt_def.def_id().0.module(db).krate(db) == def_map.krate(),
211        TyKind::Dynamic(it, _) => it
212            .principal_def_id()
213            .is_some_and(|trait_id| trait_id.0.module(db).krate(db) == def_map.krate()),
214
215        _ => true,
216    };
217    impl_allowed || {
218        let rustc_has_incoherent_inherent_impls = match self_ty {
219            TyKind::Tuple(_)
220            | TyKind::FnDef(_, _)
221            | TyKind::Array(_, _)
222            | TyKind::Never
223            | TyKind::RawPtr(_, _)
224            | TyKind::Ref(_, _, _)
225            | TyKind::Slice(_)
226            | TyKind::Str
227            | TyKind::Bool
228            | TyKind::Char
229            | TyKind::Int(_)
230            | TyKind::Uint(_)
231            | TyKind::Float(_) => true,
232
233            TyKind::Adt(adt_def, _) => match adt_def.def_id().0 {
234                hir_def::AdtId::StructId(id) => db
235                    .struct_signature(id)
236                    .flags
237                    .contains(StructFlags::RUSTC_HAS_INCOHERENT_INHERENT_IMPLS),
238                hir_def::AdtId::UnionId(id) => db
239                    .union_signature(id)
240                    .flags
241                    .contains(StructFlags::RUSTC_HAS_INCOHERENT_INHERENT_IMPLS),
242                hir_def::AdtId::EnumId(it) => db
243                    .enum_signature(it)
244                    .flags
245                    .contains(EnumFlags::RUSTC_HAS_INCOHERENT_INHERENT_IMPLS),
246            },
247            TyKind::Dynamic(it, _) => it.principal_def_id().is_some_and(|trait_id| {
248                db.trait_signature(trait_id.0)
249                    .flags
250                    .contains(TraitFlags::RUSTC_HAS_INCOHERENT_INHERENT_IMPLS)
251            }),
252
253            _ => false,
254        };
255        let items = impl_id.impl_items(db);
256        rustc_has_incoherent_inherent_impls
257            && !items.items.is_empty()
258            && items.items.iter().all(|&(_, assoc)| match assoc {
259                AssocItemId::FunctionId(it) => {
260                    db.function_signature(it).flags.contains(FnFlags::RUSTC_ALLOW_INCOHERENT_IMPL)
261                }
262                AssocItemId::ConstId(it) => {
263                    db.const_signature(it).flags.contains(ConstFlags::RUSTC_ALLOW_INCOHERENT_IMPL)
264                }
265                AssocItemId::TypeAliasId(it) => db
266                    .type_alias_signature(it)
267                    .flags
268                    .contains(TypeAliasFlags::RUSTC_ALLOW_INCOHERENT_IMPL),
269            })
270    }
271}
272
273/// Checks whether the impl satisfies the orphan rules.
274///
275/// Given `impl<P1..=Pn> Trait<T1..=Tn> for T0`, an `impl`` is valid only if at least one of the following is true:
276/// - Trait is a local trait
277/// - All of
278///   - At least one of the types `T0..=Tn`` must be a local type. Let `Ti`` be the first such type.
279///   - No uncovered type parameters `P1..=Pn` may appear in `T0..Ti`` (excluding `Ti`)
280pub fn check_orphan_rules<'db>(db: &'db dyn HirDatabase, impl_: ImplId) -> bool {
281    let Some(impl_trait) = db.impl_trait(impl_) else {
282        // not a trait impl
283        return true;
284    };
285
286    let local_crate = impl_.lookup(db).container.krate(db);
287    let is_local = |tgt_crate| tgt_crate == local_crate;
288
289    let trait_ref = impl_trait.instantiate_identity();
290    let trait_id = trait_ref.def_id.0;
291    if is_local(trait_id.module(db).krate(db)) {
292        // trait to be implemented is local
293        return true;
294    }
295
296    let unwrap_fundamental = |mut ty: Ty<'db>| {
297        // Unwrap all layers of fundamental types with a loop.
298        loop {
299            match ty.kind() {
300                TyKind::Ref(_, referenced, _) => ty = referenced,
301                TyKind::Adt(adt_def, subs) => {
302                    let AdtId::StructId(s) = adt_def.def_id().0 else {
303                        break ty;
304                    };
305                    let struct_signature = db.struct_signature(s);
306                    if struct_signature.flags.contains(StructFlags::FUNDAMENTAL) {
307                        let next = subs.types().next();
308                        match next {
309                            Some(it) => ty = it,
310                            None => break ty,
311                        }
312                    } else {
313                        break ty;
314                    }
315                }
316                _ => break ty,
317            }
318        }
319    };
320    //   - At least one of the types `T0..=Tn`` must be a local type. Let `Ti`` be the first such type.
321
322    // FIXME: param coverage
323    //   - No uncovered type parameters `P1..=Pn` may appear in `T0..Ti`` (excluding `Ti`)
324    let is_not_orphan = trait_ref.args.types().any(|ty| match unwrap_fundamental(ty).kind() {
325        TyKind::Adt(adt_def, _) => is_local(adt_def.def_id().0.module(db).krate(db)),
326        TyKind::Error(_) => true,
327        TyKind::Dynamic(it, _) => {
328            it.principal_def_id().is_some_and(|trait_id| is_local(trait_id.0.module(db).krate(db)))
329        }
330        _ => false,
331    });
332    #[allow(clippy::let_and_return)]
333    is_not_orphan
334}