hir_ty/
dyn_compatibility.rs

1//! Compute the dyn-compatibility of a trait
2
3use std::ops::ControlFlow;
4
5use hir_def::hir::generics::LocalTypeOrConstParamId;
6use hir_def::{
7    AssocItemId, ConstId, CrateRootModuleId, FunctionId, GenericDefId, HasModule, TraitId,
8    TypeAliasId, lang_item::LangItem, signatures::TraitFlags,
9};
10use hir_def::{TypeOrConstParamId, TypeParamId};
11use intern::Symbol;
12use rustc_hash::FxHashSet;
13use rustc_type_ir::{
14    AliasTyKind, ClauseKind, PredicatePolarity, TypeSuperVisitable as _, TypeVisitable as _,
15    Upcast, elaborate,
16    inherent::{IntoKind, SliceLike},
17};
18use smallvec::SmallVec;
19
20use crate::{
21    ImplTraitId,
22    db::{HirDatabase, InternedOpaqueTyId},
23    lower_nextsolver::associated_ty_item_bounds,
24    next_solver::{
25        Clause, Clauses, DbInterner, GenericArgs, ParamEnv, SolverDefId, TraitPredicate,
26        TypingMode, infer::DbInternerInferExt, mk_param,
27    },
28    traits::next_trait_solve_in_ctxt,
29};
30
31#[derive(Debug, Clone, PartialEq, Eq, Hash)]
32pub enum DynCompatibilityViolation {
33    SizedSelf,
34    SelfReferential,
35    Method(FunctionId, MethodViolationCode),
36    AssocConst(ConstId),
37    GAT(TypeAliasId),
38    // This doesn't exist in rustc, but added for better visualization
39    HasNonCompatibleSuperTrait(TraitId),
40}
41
/// Why a particular trait method prevents the trait from being dyn-compatible.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum MethodViolationCode {
    /// The function has no `self` parameter.
    StaticMethod,
    /// A parameter after the first one mentions `Self` illegally.
    ReferencesSelfInput,
    /// The return type mentions `Self` illegally.
    ReferencesSelfOutput,
    /// The return type contains a return-position `impl Trait`.
    ReferencesImplTraitInTrait,
    /// The method is an `async fn`.
    AsyncFn,
    /// One of the method's own where-clauses mentions `Self` illegally.
    WhereClauseReferencesSelf,
    /// The method has type or const generic parameters of its own.
    Generic,
    /// The receiver type cannot be dispatched on through `DispatchFromDyn`.
    UndispatchableReceiver,
}
53
54pub fn dyn_compatibility(
55    db: &dyn HirDatabase,
56    trait_: TraitId,
57) -> Option<DynCompatibilityViolation> {
58    let interner = DbInterner::new_with(db, Some(trait_.krate(db)), None);
59    for super_trait in elaborate::supertrait_def_ids(interner, SolverDefId::TraitId(trait_)) {
60        let super_trait = match super_trait {
61            SolverDefId::TraitId(id) => id,
62            _ => unreachable!(),
63        };
64        if let Some(v) = db.dyn_compatibility_of_trait(super_trait) {
65            return if super_trait == trait_ {
66                Some(v)
67            } else {
68                Some(DynCompatibilityViolation::HasNonCompatibleSuperTrait(super_trait))
69            };
70        }
71    }
72
73    None
74}
75
76pub fn dyn_compatibility_with_callback<F>(
77    db: &dyn HirDatabase,
78    trait_: TraitId,
79    cb: &mut F,
80) -> ControlFlow<()>
81where
82    F: FnMut(DynCompatibilityViolation) -> ControlFlow<()>,
83{
84    let interner = DbInterner::new_with(db, Some(trait_.krate(db)), None);
85    for super_trait in elaborate::supertrait_def_ids(interner, SolverDefId::TraitId(trait_)).skip(1)
86    {
87        let super_trait = match super_trait {
88            SolverDefId::TraitId(id) => id,
89            _ => unreachable!(),
90        };
91        if db.dyn_compatibility_of_trait(super_trait).is_some() {
92            cb(DynCompatibilityViolation::HasNonCompatibleSuperTrait(trait_))?;
93        }
94    }
95
96    dyn_compatibility_of_trait_with_callback(db, trait_, cb)
97}
98
99pub fn dyn_compatibility_of_trait_with_callback<F>(
100    db: &dyn HirDatabase,
101    trait_: TraitId,
102    cb: &mut F,
103) -> ControlFlow<()>
104where
105    F: FnMut(DynCompatibilityViolation) -> ControlFlow<()>,
106{
107    // Check whether this has a `Sized` bound
108    if generics_require_sized_self(db, trait_.into()) {
109        cb(DynCompatibilityViolation::SizedSelf)?;
110    }
111
112    // Check if there exist bounds that referencing self
113    if predicates_reference_self(db, trait_) {
114        cb(DynCompatibilityViolation::SelfReferential)?;
115    }
116    if bounds_reference_self(db, trait_) {
117        cb(DynCompatibilityViolation::SelfReferential)?;
118    }
119
120    // rustc checks for non-lifetime binders here, but we don't support HRTB yet
121
122    let trait_data = trait_.trait_items(db);
123    for (_, assoc_item) in &trait_data.items {
124        dyn_compatibility_violation_for_assoc_item(db, trait_, *assoc_item, cb)?;
125    }
126
127    ControlFlow::Continue(())
128}
129
130pub fn dyn_compatibility_of_trait_query(
131    db: &dyn HirDatabase,
132    trait_: TraitId,
133) -> Option<DynCompatibilityViolation> {
134    let mut res = None;
135    _ = dyn_compatibility_of_trait_with_callback(db, trait_, &mut |osv| {
136        res = Some(osv);
137        ControlFlow::Break(())
138    });
139
140    res
141}
142
143pub fn generics_require_sized_self(db: &dyn HirDatabase, def: GenericDefId) -> bool {
144    let krate = def.module(db).krate();
145    let Some(sized) = LangItem::Sized.resolve_trait(db, krate) else {
146        return false;
147    };
148
149    let interner = DbInterner::new_with(db, Some(krate), None);
150    let predicates = db.generic_predicates_ns(def);
151    elaborate::elaborate(interner, predicates.iter().copied()).any(|pred| {
152        match pred.kind().skip_binder() {
153            ClauseKind::Trait(trait_pred) => {
154                if SolverDefId::TraitId(sized) == trait_pred.def_id()
155                    && let rustc_type_ir::TyKind::Param(param_ty) =
156                        trait_pred.trait_ref.self_ty().kind()
157                    && param_ty.index == 0
158                {
159                    true
160                } else {
161                    false
162                }
163            }
164            _ => false,
165        }
166    })
167}
168
169// rustc gathers all the spans that references `Self` for error rendering,
170// but we don't have good way to render such locations.
171// So, just return single boolean value for existence of such `Self` reference
172fn predicates_reference_self(db: &dyn HirDatabase, trait_: TraitId) -> bool {
173    db.generic_predicates_ns(trait_.into())
174        .iter()
175        .any(|pred| predicate_references_self(db, trait_, pred, AllowSelfProjection::No))
176}
177
178// Same as the above, `predicates_reference_self`
179fn bounds_reference_self(db: &dyn HirDatabase, trait_: TraitId) -> bool {
180    let trait_data = trait_.trait_items(db);
181    trait_data
182        .items
183        .iter()
184        .filter_map(|(_, it)| match *it {
185            AssocItemId::TypeAliasId(id) => Some(associated_ty_item_bounds(db, id)),
186            _ => None,
187        })
188        .any(|bounds| {
189            bounds.skip_binder().iter().any(|pred| match pred.skip_binder() {
190                rustc_type_ir::ExistentialPredicate::Trait(it) => it.args.iter().any(|arg| {
191                    contains_illegal_self_type_reference(db, trait_, &arg, AllowSelfProjection::Yes)
192                }),
193                rustc_type_ir::ExistentialPredicate::Projection(it) => it.args.iter().any(|arg| {
194                    contains_illegal_self_type_reference(db, trait_, &arg, AllowSelfProjection::Yes)
195                }),
196                rustc_type_ir::ExistentialPredicate::AutoTrait(_) => false,
197            })
198        })
199}
200
/// Controls whether projections such as `<Self as SuperTrait>::Assoc` are tolerated when
/// searching for illegal `Self` references.
#[derive(Copy, Clone)]
enum AllowSelfProjection {
    /// Projections whose trait is the current trait or one of its supertraits are accepted.
    Yes,
    /// Projections are looked into like any other type.
    No,
}
206
207fn predicate_references_self<'db>(
208    db: &'db dyn HirDatabase,
209    trait_: TraitId,
210    predicate: &Clause<'db>,
211    allow_self_projection: AllowSelfProjection,
212) -> bool {
213    match predicate.kind().skip_binder() {
214        ClauseKind::Trait(trait_pred) => trait_pred.trait_ref.args.iter().skip(1).any(|arg| {
215            contains_illegal_self_type_reference(db, trait_, &arg, allow_self_projection)
216        }),
217        ClauseKind::Projection(proj_pred) => {
218            proj_pred.projection_term.args.iter().skip(1).any(|arg| {
219                contains_illegal_self_type_reference(db, trait_, &arg, allow_self_projection)
220            })
221        }
222        _ => false,
223    }
224}
225
226fn contains_illegal_self_type_reference<'db, T: rustc_type_ir::TypeVisitable<DbInterner<'db>>>(
227    db: &'db dyn HirDatabase,
228    trait_: TraitId,
229    t: &T,
230    allow_self_projection: AllowSelfProjection,
231) -> bool {
232    struct IllegalSelfTypeVisitor<'db> {
233        db: &'db dyn HirDatabase,
234        trait_: TraitId,
235        super_traits: Option<SmallVec<[TraitId; 4]>>,
236        allow_self_projection: AllowSelfProjection,
237    }
238    impl<'db> rustc_type_ir::TypeVisitor<DbInterner<'db>> for IllegalSelfTypeVisitor<'db> {
239        type Result = ControlFlow<()>;
240
241        fn visit_ty(
242            &mut self,
243            ty: <DbInterner<'db> as rustc_type_ir::Interner>::Ty,
244        ) -> Self::Result {
245            let interner = DbInterner::new_with(self.db, None, None);
246            match ty.kind() {
247                rustc_type_ir::TyKind::Param(param) if param.index == 0 => ControlFlow::Break(()),
248                rustc_type_ir::TyKind::Param(_) => ControlFlow::Continue(()),
249                rustc_type_ir::TyKind::Alias(AliasTyKind::Projection, proj) => match self
250                    .allow_self_projection
251                {
252                    AllowSelfProjection::Yes => {
253                        let trait_ = proj.trait_def_id(DbInterner::new_with(self.db, None, None));
254                        let trait_ = match trait_ {
255                            SolverDefId::TraitId(id) => id,
256                            _ => unreachable!(),
257                        };
258                        if self.super_traits.is_none() {
259                            self.super_traits = Some(
260                                elaborate::supertrait_def_ids(
261                                    interner,
262                                    SolverDefId::TraitId(self.trait_),
263                                )
264                                .map(|super_trait| match super_trait {
265                                    SolverDefId::TraitId(id) => id,
266                                    _ => unreachable!(),
267                                })
268                                .collect(),
269                            )
270                        }
271                        if self.super_traits.as_ref().is_some_and(|s| s.contains(&trait_)) {
272                            ControlFlow::Continue(())
273                        } else {
274                            ty.super_visit_with(self)
275                        }
276                    }
277                    AllowSelfProjection::No => ty.super_visit_with(self),
278                },
279                _ => ty.super_visit_with(self),
280            }
281        }
282    }
283
284    let mut visitor =
285        IllegalSelfTypeVisitor { db, trait_, super_traits: None, allow_self_projection };
286    t.visit_with(&mut visitor).is_break()
287}
288
289fn dyn_compatibility_violation_for_assoc_item<F>(
290    db: &dyn HirDatabase,
291    trait_: TraitId,
292    item: AssocItemId,
293    cb: &mut F,
294) -> ControlFlow<()>
295where
296    F: FnMut(DynCompatibilityViolation) -> ControlFlow<()>,
297{
298    // Any item that has a `Self : Sized` requisite is otherwise
299    // exempt from the regulations.
300    if generics_require_sized_self(db, item.into()) {
301        return ControlFlow::Continue(());
302    }
303
304    match item {
305        AssocItemId::ConstId(it) => cb(DynCompatibilityViolation::AssocConst(it)),
306        AssocItemId::FunctionId(it) => {
307            virtual_call_violations_for_method(db, trait_, it, &mut |mvc| {
308                cb(DynCompatibilityViolation::Method(it, mvc))
309            })
310        }
311        AssocItemId::TypeAliasId(it) => {
312            let def_map = CrateRootModuleId::from(trait_.krate(db)).def_map(db);
313            if def_map.is_unstable_feature_enabled(&intern::sym::generic_associated_type_extended) {
314                ControlFlow::Continue(())
315            } else {
316                let generic_params = db.generic_params(item.into());
317                if !generic_params.is_empty() {
318                    cb(DynCompatibilityViolation::GAT(it))
319                } else {
320                    ControlFlow::Continue(())
321                }
322            }
323        }
324    }
325}
326
327fn virtual_call_violations_for_method<F>(
328    db: &dyn HirDatabase,
329    trait_: TraitId,
330    func: FunctionId,
331    cb: &mut F,
332) -> ControlFlow<()>
333where
334    F: FnMut(MethodViolationCode) -> ControlFlow<()>,
335{
336    let func_data = db.function_signature(func);
337    if !func_data.has_self_param() {
338        cb(MethodViolationCode::StaticMethod)?;
339    }
340
341    if func_data.is_async() {
342        cb(MethodViolationCode::AsyncFn)?;
343    }
344
345    let sig = db.callable_item_signature_ns(func.into());
346    if sig
347        .skip_binder()
348        .inputs()
349        .iter()
350        .skip(1)
351        .any(|ty| contains_illegal_self_type_reference(db, trait_, &ty, AllowSelfProjection::Yes))
352    {
353        cb(MethodViolationCode::ReferencesSelfInput)?;
354    }
355
356    if contains_illegal_self_type_reference(
357        db,
358        trait_,
359        &sig.skip_binder().output(),
360        AllowSelfProjection::Yes,
361    ) {
362        cb(MethodViolationCode::ReferencesSelfOutput)?;
363    }
364
365    if !func_data.is_async()
366        && let Some(mvc) = contains_illegal_impl_trait_in_trait(db, &sig)
367    {
368        cb(mvc)?;
369    }
370
371    let generic_params = db.generic_params(func.into());
372    if generic_params.len_type_or_consts() > 0 {
373        cb(MethodViolationCode::Generic)?;
374    }
375
376    if func_data.has_self_param() && !receiver_is_dispatchable(db, trait_, func, &sig) {
377        cb(MethodViolationCode::UndispatchableReceiver)?;
378    }
379
380    let predicates = &*db.generic_predicates_without_parent_ns(func.into());
381    for pred in predicates {
382        let pred = pred.kind().skip_binder();
383
384        if matches!(pred, ClauseKind::TypeOutlives(_)) {
385            continue;
386        }
387
388        // Allow `impl AutoTrait` predicates
389        if let ClauseKind::Trait(TraitPredicate {
390            trait_ref: pred_trait_ref,
391            polarity: PredicatePolarity::Positive,
392        }) = pred
393            && let SolverDefId::TraitId(trait_id) = pred_trait_ref.def_id
394            && let trait_data = db.trait_signature(trait_id)
395            && trait_data.flags.contains(TraitFlags::AUTO)
396            && let rustc_type_ir::TyKind::Param(crate::next_solver::ParamTy { index: 0, .. }) =
397                pred_trait_ref.self_ty().kind()
398        {
399            continue;
400        }
401
402        if contains_illegal_self_type_reference(db, trait_, &pred, AllowSelfProjection::Yes) {
403            cb(MethodViolationCode::WhereClauseReferencesSelf)?;
404            break;
405        }
406    }
407
408    ControlFlow::Continue(())
409}
410
411fn receiver_is_dispatchable<'db>(
412    db: &dyn HirDatabase,
413    trait_: TraitId,
414    func: FunctionId,
415    sig: &crate::next_solver::EarlyBinder<
416        'db,
417        crate::next_solver::Binder<'db, rustc_type_ir::FnSig<DbInterner<'db>>>,
418    >,
419) -> bool {
420    let sig = sig.instantiate_identity();
421
422    let interner: DbInterner<'_> = DbInterner::new_with(db, Some(trait_.krate(db)), None);
423    let self_param_id = TypeParamId::from_unchecked(TypeOrConstParamId {
424        parent: trait_.into(),
425        local_id: LocalTypeOrConstParamId::from_raw(la_arena::RawIdx::from_u32(0)),
426    });
427    let self_param_ty = crate::next_solver::Ty::new(
428        interner,
429        rustc_type_ir::TyKind::Param(crate::next_solver::ParamTy { index: 0, id: self_param_id }),
430    );
431
432    // `self: Self` can't be dispatched on, but this is already considered dyn-compatible
433    // See rustc's comment on https://github.com/rust-lang/rust/blob/3f121b9461cce02a703a0e7e450568849dfaa074/compiler/rustc_trait_selection/src/traits/object_safety.rs#L433-L437
434    if sig.inputs().iter().next().is_some_and(|p| p.skip_binder() == self_param_ty) {
435        return true;
436    }
437
438    let Some(&receiver_ty) = sig.inputs().skip_binder().as_slice().first() else {
439        return false;
440    };
441
442    let krate = func.module(db).krate();
443    let traits = (
444        LangItem::Unsize.resolve_trait(db, krate),
445        LangItem::DispatchFromDyn.resolve_trait(db, krate),
446    );
447    let (Some(unsize_did), Some(dispatch_from_dyn_did)) = traits else {
448        return false;
449    };
450
451    let meta_sized_did = LangItem::MetaSized.resolve_trait(db, krate);
452    let Some(meta_sized_did) = meta_sized_did else {
453        return false;
454    };
455
456    // Type `U`
457    // FIXME: That seems problematic to fake a generic param like that?
458    let unsized_self_ty =
459        crate::next_solver::Ty::new_param(interner, self_param_id, u32::MAX, Symbol::empty());
460    // `Receiver[Self => U]`
461    let unsized_receiver_ty = receiver_for_self_ty(interner, func, receiver_ty, unsized_self_ty);
462
463    let param_env = {
464        let generic_predicates = &*db.generic_predicates_ns(func.into());
465
466        // Self: Unsize<U>
467        let unsize_predicate = crate::next_solver::TraitRef::new(
468            interner,
469            SolverDefId::TraitId(unsize_did),
470            [self_param_ty, unsized_self_ty],
471        );
472
473        // U: Trait<Arg1, ..., ArgN>
474        let trait_def_id = SolverDefId::TraitId(trait_);
475        let args = GenericArgs::for_item(interner, trait_def_id, |name, index, kind, _| {
476            if index == 0 { unsized_self_ty.into() } else { mk_param(interner, index, name, kind) }
477        });
478        let trait_predicate =
479            crate::next_solver::TraitRef::new_from_args(interner, trait_def_id, args);
480
481        let meta_sized_predicate = crate::next_solver::TraitRef::new(
482            interner,
483            SolverDefId::TraitId(meta_sized_did),
484            [unsized_self_ty],
485        );
486
487        ParamEnv {
488            clauses: Clauses::new_from_iter(
489                interner,
490                generic_predicates.iter().copied().chain([
491                    unsize_predicate.upcast(interner),
492                    trait_predicate.upcast(interner),
493                    meta_sized_predicate.upcast(interner),
494                ]),
495            ),
496        }
497    };
498
499    // Receiver: DispatchFromDyn<Receiver[Self => U]>
500    let predicate = crate::next_solver::TraitRef::new(
501        interner,
502        SolverDefId::TraitId(dispatch_from_dyn_did),
503        [receiver_ty, unsized_receiver_ty],
504    );
505    let goal = crate::next_solver::Goal::new(interner, param_env, predicate);
506
507    let infcx = interner.infer_ctxt().build(TypingMode::non_body_analysis());
508    // the receiver is dispatchable iff the obligation holds
509    let res = next_trait_solve_in_ctxt(&infcx, goal);
510    res.map_or(false, |res| matches!(res.1, rustc_type_ir::solve::Certainty::Yes))
511}
512
513fn receiver_for_self_ty<'db>(
514    interner: DbInterner<'db>,
515    func: FunctionId,
516    receiver_ty: crate::next_solver::Ty<'db>,
517    self_ty: crate::next_solver::Ty<'db>,
518) -> crate::next_solver::Ty<'db> {
519    let args = crate::next_solver::GenericArgs::for_item(
520        interner,
521        SolverDefId::FunctionId(func),
522        |name, index, kind, _| {
523            if index == 0 { self_ty.into() } else { mk_param(interner, index, name, kind) }
524        },
525    );
526
527    crate::next_solver::EarlyBinder::bind(receiver_ty).instantiate(interner, args)
528}
529
530fn contains_illegal_impl_trait_in_trait<'db>(
531    db: &'db dyn HirDatabase,
532    sig: &crate::next_solver::EarlyBinder<
533        'db,
534        crate::next_solver::Binder<'db, rustc_type_ir::FnSig<DbInterner<'db>>>,
535    >,
536) -> Option<MethodViolationCode> {
537    struct OpaqueTypeCollector(FxHashSet<InternedOpaqueTyId>);
538
539    impl<'db> rustc_type_ir::TypeVisitor<DbInterner<'db>> for OpaqueTypeCollector {
540        type Result = ControlFlow<()>;
541
542        fn visit_ty(
543            &mut self,
544            ty: <DbInterner<'db> as rustc_type_ir::Interner>::Ty,
545        ) -> Self::Result {
546            if let rustc_type_ir::TyKind::Alias(AliasTyKind::Opaque, op) = ty.kind() {
547                let id = match op.def_id {
548                    SolverDefId::InternedOpaqueTyId(id) => id,
549                    _ => unreachable!(),
550                };
551                self.0.insert(id);
552            }
553            ty.super_visit_with(self)
554        }
555    }
556
557    let ret = sig.skip_binder().output();
558    let mut visitor = OpaqueTypeCollector(FxHashSet::default());
559    _ = ret.visit_with(&mut visitor);
560
561    // Since we haven't implemented RPITIT in proper way like rustc yet,
562    // just check whether `ret` contains RPIT for now
563    for opaque_ty in visitor.0 {
564        let impl_trait_id = db.lookup_intern_impl_trait_id(opaque_ty);
565        if matches!(impl_trait_id, ImplTraitId::ReturnTypeImplTrait(..)) {
566            return Some(MethodViolationCode::ReferencesImplTraitInTrait);
567        }
568    }
569
570    None
571}
572
573#[cfg(test)]
574mod tests;