use std::ops::ControlFlow;
use chalk_ir::{
cast::Cast,
visit::{TypeSuperVisitable, TypeVisitable, TypeVisitor},
DebruijnIndex,
};
use chalk_solve::rust_ir::InlineBound;
use hir_def::{
data::TraitFlags, lang_item::LangItem, AssocItemId, ConstId, FunctionId, GenericDefId,
HasModule, TraitId, TypeAliasId,
};
use rustc_hash::FxHashSet;
use smallvec::SmallVec;
use crate::{
all_super_traits,
db::HirDatabase,
from_assoc_type_id, from_chalk_trait_id,
generics::{generics, trait_self_param_idx},
lower::callable_item_sig,
to_assoc_type_id, to_chalk_trait_id,
utils::elaborate_clause_supertraits,
AliasEq, AliasTy, Binders, BoundVar, CallableSig, GoalData, ImplTraitId, Interner, OpaqueTyId,
ProjectionTyExt, Solution, Substitution, TraitRef, Ty, TyKind, WhereClause,
};
/// A reason why a trait cannot be used as a `dyn Trait` object type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DynCompatibilityViolation {
    /// The trait's own generics (elaborated through supertraits) require `Self: Sized`.
    SizedSelf,
    /// The trait's where-clauses, or the bounds of one of its associated types, mention `Self`
    /// in a position other than the self type.
    SelfReferential,
    /// The given method is not dispatchable through a trait object, for the given reason.
    Method(FunctionId, MethodViolationCode),
    /// The trait declares an associated const.
    AssocConst(ConstId),
    /// The trait declares an associated type with generic parameters (a GAT), and the
    /// `generic_associated_type_extended` feature is not enabled in the trait's crate.
    GAT(TypeAliasId),
    /// Some supertrait is not dyn-compatible. `dyn_compatibility` stores the offending
    /// supertrait here.
    HasNonCompatibleSuperTrait(TraitId),
}
/// Why a particular trait method makes its trait non-dyn-compatible.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MethodViolationCode {
    /// The method has no `self` parameter.
    StaticMethod,
    /// A parameter other than the receiver mentions `Self`.
    ReferencesSelfInput,
    /// The return type mentions `Self`.
    ReferencesSelfOutput,
    /// The return type of a non-`async` method contains a return-position `impl Trait`.
    ReferencesImplTraitInTrait,
    /// The method is `async`.
    AsyncFn,
    /// One of the method's own where-clauses mentions `Self` (outlives clauses and auto-trait
    /// bounds directly on `Self` are ignored).
    WhereClauseReferencesSelf,
    /// The method has type or const generic parameters.
    Generic,
    /// The method has a `self` parameter whose type is not dispatchable through `dyn Trait`.
    UndispatchableReceiver,
}
/// Returns the first reason why `trait_` cannot be used as `dyn Trait`, or `None` when it is
/// dyn-compatible.
///
/// Supertraits are consulted before the trait itself. The first entry of `all_super_traits` is
/// skipped (presumably `trait_` itself), and the rest are walked in reverse. The first supertrait
/// with any violation is reported as [`DynCompatibilityViolation::HasNonCompatibleSuperTrait`].
pub fn dyn_compatibility(
    db: &dyn HirDatabase,
    trait_: TraitId,
) -> Option<DynCompatibilityViolation> {
    let incompatible_super = all_super_traits(db.upcast(), trait_)
        .into_iter()
        .skip(1)
        .rev()
        .find(|&super_trait| db.dyn_compatibility_of_trait(super_trait).is_some());

    match incompatible_super {
        Some(super_trait) => {
            Some(DynCompatibilityViolation::HasNonCompatibleSuperTrait(super_trait))
        }
        None => db.dyn_compatibility_of_trait(trait_),
    }
}
/// Like [`dyn_compatibility`], but reports every violation through `cb` instead of only the
/// first one.
///
/// Each supertrait that is not dyn-compatible is reported as
/// [`DynCompatibilityViolation::HasNonCompatibleSuperTrait`], carrying that supertrait's id.
/// The trait's own violations are then reported through
/// [`dyn_compatibility_of_trait_with_callback`]. Reporting stops as soon as `cb` returns
/// [`ControlFlow::Break`], and that break is returned.
pub fn dyn_compatibility_with_callback<F>(
    db: &dyn HirDatabase,
    trait_: TraitId,
    cb: &mut F,
) -> ControlFlow<()>
where
    F: FnMut(DynCompatibilityViolation) -> ControlFlow<()>,
{
    // `skip(1)` drops the first entry of `all_super_traits` (presumably `trait_` itself), which
    // is checked separately below.
    for super_trait in all_super_traits(db.upcast(), trait_).into_iter().skip(1).rev() {
        if db.dyn_compatibility_of_trait(super_trait).is_some() {
            // Report the offending supertrait, as `dyn_compatibility` does. Reporting `trait_`
            // here would not tell the caller which supertrait is at fault.
            cb(DynCompatibilityViolation::HasNonCompatibleSuperTrait(super_trait))?;
        }
    }

    dyn_compatibility_of_trait_with_callback(db, trait_, cb)
}
pub fn dyn_compatibility_of_trait_with_callback<F>(
db: &dyn HirDatabase,
trait_: TraitId,
cb: &mut F,
) -> ControlFlow<()>
where
F: FnMut(DynCompatibilityViolation) -> ControlFlow<()>,
{
if generics_require_sized_self(db, trait_.into()) {
cb(DynCompatibilityViolation::SizedSelf)?;
}
if predicates_reference_self(db, trait_) {
cb(DynCompatibilityViolation::SelfReferential)?;
}
if bounds_reference_self(db, trait_) {
cb(DynCompatibilityViolation::SelfReferential)?;
}
let trait_data = db.trait_data(trait_);
for (_, assoc_item) in &trait_data.items {
dyn_compatibility_violation_for_assoc_item(db, trait_, *assoc_item, cb)?;
}
ControlFlow::Continue(())
}
/// Query implementation: the first violation that originates in `trait_` itself, or `None`.
pub fn dyn_compatibility_of_trait_query(
    db: &dyn HirDatabase,
    trait_: TraitId,
) -> Option<DynCompatibilityViolation> {
    let mut first_violation = None;
    // Break on the first report; only that one is kept.
    let _ = dyn_compatibility_of_trait_with_callback(db, trait_, &mut |violation| {
        first_violation = Some(violation);
        ControlFlow::Break(())
    });
    first_violation
}
/// Returns `true` when the where-clauses of `def`, elaborated through supertraits, contain an
/// `Implemented` clause for the `Sized` lang-item trait whose self type is the trait's `Self`
/// parameter.
///
/// Returns `false` if the `Sized` lang item is unavailable or `def` has no trait `Self` parameter.
fn generics_require_sized_self(db: &dyn HirDatabase, def: GenericDefId) -> bool {
    let krate = def.module(db.upcast()).krate();
    let sized = match db.lang_item(krate, LangItem::Sized).and_then(|l| l.as_trait()) {
        Some(sized) => sized,
        None => return false,
    };
    let self_idx = match trait_self_param_idx(db.upcast(), def) {
        Some(idx) => idx,
        None => return false,
    };

    let predicates = &*db.generic_predicates(def);
    let clauses = predicates.iter().map(|p| p.skip_binders().skip_binders().clone());
    elaborate_clause_supertraits(db, clauses).any(|clause| {
        let WhereClause::Implemented(trait_ref) = clause else {
            return false;
        };
        if from_chalk_trait_id(trait_ref.trait_id) != sized {
            return false;
        }
        match *trait_ref.self_type_parameter(Interner).kind(Interner) {
            // Both binder levels were skipped above, so `Self` is bound at depth one.
            TyKind::BoundVar(bound) => {
                bound.index_if_bound_at(DebruijnIndex::ONE) == Some(self_idx)
            }
            _ => false,
        }
    })
}
fn predicates_reference_self(db: &dyn HirDatabase, trait_: TraitId) -> bool {
db.generic_predicates(trait_.into())
.iter()
.any(|pred| predicate_references_self(db, trait_, pred, AllowSelfProjection::No))
}
fn bounds_reference_self(db: &dyn HirDatabase, trait_: TraitId) -> bool {
let trait_data = db.trait_data(trait_);
trait_data
.items
.iter()
.filter_map(|(_, it)| match *it {
AssocItemId::TypeAliasId(id) => {
let assoc_ty_id = to_assoc_type_id(id);
let assoc_ty_data = db.associated_ty_data(assoc_ty_id);
Some(assoc_ty_data)
}
_ => None,
})
.any(|assoc_ty_data| {
assoc_ty_data.binders.skip_binders().bounds.iter().any(|bound| {
let def = from_assoc_type_id(assoc_ty_data.id).into();
match bound.skip_binders() {
InlineBound::TraitBound(it) => it.args_no_self.iter().any(|arg| {
contains_illegal_self_type_reference(
db,
def,
trait_,
arg,
DebruijnIndex::ONE,
AllowSelfProjection::Yes,
)
}),
InlineBound::AliasEqBound(it) => it.parameters.iter().any(|arg| {
contains_illegal_self_type_reference(
db,
def,
trait_,
arg,
DebruijnIndex::ONE,
AllowSelfProjection::Yes,
)
}),
}
})
})
}
/// Controls how `contains_illegal_self_type_reference` treats associated-type projections.
#[derive(Clone, Copy)]
enum AllowSelfProjection {
    /// A projection whose trait is among `all_super_traits(trait_)` is skipped entirely, so any
    /// `Self` inside it is not reported.
    Yes,
    /// Projections are traversed like any other type.
    No,
}
/// Returns `true` if `predicate` mentions `trait_`'s `Self` parameter in a generic argument
/// other than the first (self-type) slot.
///
/// Only `Implemented` clauses and `AliasEq` clauses on projections are inspected; any other
/// clause yields `false`.
fn predicate_references_self(
    db: &dyn HirDatabase,
    trait_: TraitId,
    predicate: &Binders<Binders<WhereClause>>,
    allow_self_projection: AllowSelfProjection,
) -> bool {
    let substitution = match predicate.skip_binders().skip_binders() {
        WhereClause::Implemented(trait_ref) => &trait_ref.substitution,
        WhereClause::AliasEq(AliasEq { alias: AliasTy::Projection(proj), .. }) => {
            &proj.substitution
        }
        _ => return false,
    };

    // Slot 0 is the self type, which is allowed to be `Self`.
    substitution.iter(Interner).skip(1).any(|arg| {
        contains_illegal_self_type_reference(
            db,
            trait_.into(),
            trait_,
            arg,
            DebruijnIndex::ONE,
            allow_self_projection,
        )
    })
}
/// Returns `true` if `t` mentions the trait `Self` parameter of `def`.
///
/// A mention is a bound variable at `outer_binder` whose index is `trait_self_param_idx(def)`.
/// With `AllowSelfProjection::Yes`, projections onto any trait in `all_super_traits(trait_)` are
/// not descended into. For constants, only the sub-structure of the constant's type is visited.
///
/// Returns `false` when `def` has no trait `Self` parameter.
fn contains_illegal_self_type_reference<T: TypeVisitable<Interner>>(
    db: &dyn HirDatabase,
    def: GenericDefId,
    trait_: TraitId,
    t: &T,
    outer_binder: DebruijnIndex,
    allow_self_projection: AllowSelfProjection,
) -> bool {
    /// Visitor that breaks on the first occurrence of the trait's `Self` parameter.
    struct SelfFinder<'a> {
        db: &'a dyn HirDatabase,
        owner: TraitId,
        /// `all_super_traits(owner)`, computed the first time a projection needs it.
        owner_supertraits: Option<SmallVec<[TraitId; 4]>>,
        self_idx: usize,
        allow_self_projection: AllowSelfProjection,
    }

    impl SelfFinder<'_> {
        /// Whether `candidate` is one of `all_super_traits(owner)`.
        fn is_owner_supertrait(&mut self, candidate: TraitId) -> bool {
            let (db, owner) = (self.db, self.owner);
            self.owner_supertraits
                .get_or_insert_with(|| all_super_traits(db.upcast(), owner))
                .contains(&candidate)
        }
    }

    impl TypeVisitor<Interner> for SelfFinder<'_> {
        type BreakTy = ();

        fn as_dyn(&mut self) -> &mut dyn TypeVisitor<Interner, BreakTy = Self::BreakTy> {
            self
        }

        fn interner(&self) -> Interner {
            Interner
        }

        fn visit_ty(&mut self, ty: &Ty, outer_binder: DebruijnIndex) -> ControlFlow<Self::BreakTy> {
            // `prune == true` means: do not look inside this type.
            let prune = match ty.kind(Interner) {
                TyKind::BoundVar(BoundVar { debruijn, index }) => {
                    if *debruijn == outer_binder && *index == self.self_idx {
                        return ControlFlow::Break(());
                    }
                    false
                }
                TyKind::Alias(AliasTy::Projection(proj))
                    if matches!(self.allow_self_projection, AllowSelfProjection::Yes) =>
                {
                    let projected_trait = proj.trait_(self.db);
                    self.is_owner_supertrait(projected_trait)
                }
                _ => false,
            };

            if prune {
                ControlFlow::Continue(())
            } else {
                ty.super_visit_with(self.as_dyn(), outer_binder)
            }
        }

        fn visit_const(
            &mut self,
            constant: &chalk_ir::Const<Interner>,
            outer_binder: DebruijnIndex,
        ) -> std::ops::ControlFlow<Self::BreakTy> {
            constant.data(Interner).ty.super_visit_with(self.as_dyn(), outer_binder)
        }
    }

    let Some(self_idx) = trait_self_param_idx(db.upcast(), def) else {
        return false;
    };
    let mut finder = SelfFinder {
        db,
        owner: trait_,
        owner_supertraits: None,
        self_idx,
        allow_self_projection,
    };
    t.visit_with(finder.as_dyn(), outer_binder).is_break()
}
/// Reports the dyn-compatibility violations caused by one associated item of `trait_`.
///
/// Behaviour by item kind:
/// - Items whose own generics require `Self: Sized` are exempt.
/// - Associated consts are always reported.
/// - Methods are checked by `virtual_call_violations_for_method`.
/// - Associated types with generic parameters are reported as GATs, unless the crate of `trait_`
///   enables the `generic_associated_type_extended` feature.
fn dyn_compatibility_violation_for_assoc_item<F>(
    db: &dyn HirDatabase,
    trait_: TraitId,
    item: AssocItemId,
    cb: &mut F,
) -> ControlFlow<()>
where
    F: FnMut(DynCompatibilityViolation) -> ControlFlow<()>,
{
    if generics_require_sized_self(db, item.into()) {
        return ControlFlow::Continue(());
    }

    match item {
        AssocItemId::ConstId(konst) => cb(DynCompatibilityViolation::AssocConst(konst)),
        AssocItemId::FunctionId(func) => {
            virtual_call_violations_for_method(db, trait_, func, &mut |code| {
                cb(DynCompatibilityViolation::Method(func, code))
            })
        }
        AssocItemId::TypeAliasId(alias) => {
            let def_map = db.crate_def_map(trait_.krate(db.upcast()));
            let gats_extended = def_map
                .is_unstable_feature_enabled(&intern::sym::generic_associated_type_extended);
            // Generic params are only looked up when the feature is off.
            if gats_extended || db.generic_params(item.into()).is_empty() {
                ControlFlow::Continue(())
            } else {
                cb(DynCompatibilityViolation::GAT(alias))
            }
        }
    }
}
fn virtual_call_violations_for_method<F>(
db: &dyn HirDatabase,
trait_: TraitId,
func: FunctionId,
cb: &mut F,
) -> ControlFlow<()>
where
F: FnMut(MethodViolationCode) -> ControlFlow<()>,
{
let func_data = db.function_data(func);
if !func_data.has_self_param() {
cb(MethodViolationCode::StaticMethod)?;
}
if func_data.is_async() {
cb(MethodViolationCode::AsyncFn)?;
}
let sig = callable_item_sig(db, func.into());
if sig.skip_binders().params().iter().skip(1).any(|ty| {
contains_illegal_self_type_reference(
db,
func.into(),
trait_,
ty,
DebruijnIndex::INNERMOST,
AllowSelfProjection::Yes,
)
}) {
cb(MethodViolationCode::ReferencesSelfInput)?;
}
if contains_illegal_self_type_reference(
db,
func.into(),
trait_,
sig.skip_binders().ret(),
DebruijnIndex::INNERMOST,
AllowSelfProjection::Yes,
) {
cb(MethodViolationCode::ReferencesSelfOutput)?;
}
if !func_data.is_async() {
if let Some(mvc) = contains_illegal_impl_trait_in_trait(db, &sig) {
cb(mvc)?;
}
}
let generic_params = db.generic_params(func.into());
if generic_params.len_type_or_consts() > 0 {
cb(MethodViolationCode::Generic)?;
}
if func_data.has_self_param() && !receiver_is_dispatchable(db, trait_, func, &sig) {
cb(MethodViolationCode::UndispatchableReceiver)?;
}
let predicates = &*db.generic_predicates_without_parent(func.into());
let trait_self_idx = trait_self_param_idx(db.upcast(), func.into());
for pred in predicates {
let pred = pred.skip_binders().skip_binders();
if matches!(pred, WhereClause::TypeOutlives(_)) {
continue;
}
if let WhereClause::Implemented(TraitRef { trait_id, substitution }) = pred {
let trait_data = db.trait_data(from_chalk_trait_id(*trait_id));
if trait_data.flags.contains(TraitFlags::IS_AUTO)
&& substitution
.as_slice(Interner)
.first()
.and_then(|arg| arg.ty(Interner))
.and_then(|ty| ty.bound_var(Interner))
.is_some_and(|b| {
b.debruijn == DebruijnIndex::ONE && Some(b.index) == trait_self_idx
})
{
continue;
}
}
if contains_illegal_self_type_reference(
db,
func.into(),
trait_,
pred,
DebruijnIndex::ONE,
AllowSelfProjection::Yes,
) {
cb(MethodViolationCode::WhereClauseReferencesSelf)?;
break;
}
}
ControlFlow::Continue(())
}
/// Checks whether the receiver (first parameter) of `func`, a method of `trait_`, can be
/// dispatched through a trait object.
///
/// Returns `true` immediately when the receiver type is directly the trait's `Self` parameter.
/// Otherwise, with all generics of `func` replaced by placeholders, it asks the trait solver
/// whether `Receiver: DispatchFromDyn<Receiver[Self := U]>` holds. The environment assumes:
/// - `Self: Unsize<U>`,
/// - `U: Trait<..rest of the placeholders..>`,
/// - the predicates of `func`.
///
/// The concrete type `u32` is used as the stand-in for `U`. Only a `Solution::Unique` counts as
/// dispatchable.
///
/// Returns `false` in any of these cases:
/// - `func` has no trait `Self` parameter;
/// - `func` has no parameters;
/// - the `Unsize`/`DispatchFromDyn` lang items are unavailable;
/// - `receiver_for_self_ty` yields nothing.
fn receiver_is_dispatchable(
    db: &dyn HirDatabase,
    trait_: TraitId,
    func: FunctionId,
    sig: &Binders<CallableSig>,
) -> bool {
    let Some(trait_self_idx) = trait_self_param_idx(db.upcast(), func.into()) else {
        return false;
    };
    // Fast path: a receiver whose type is exactly the bound `Self` variable is dispatchable.
    if sig
        .skip_binders()
        .params()
        .first()
        .and_then(|receiver| receiver.bound_var(Interner))
        .is_some_and(|b| {
            b == BoundVar { debruijn: DebruijnIndex::INNERMOST, index: trait_self_idx }
        })
    {
        return true;
    }
    let placeholder_subst = generics(db.upcast(), func.into()).placeholder_subst(db);
    let substituted_sig = sig.clone().substitute(Interner, &placeholder_subst);
    let Some(receiver_ty) = substituted_sig.params().first() else {
        return false;
    };
    let krate = func.module(db.upcast()).krate();
    let traits = (
        db.lang_item(krate, LangItem::Unsize).and_then(|it| it.as_trait()),
        db.lang_item(krate, LangItem::DispatchFromDyn).and_then(|it| it.as_trait()),
    );
    let (Some(unsize_did), Some(dispatch_from_dyn_did)) = traits else {
        return false;
    };
    // Stand-in type `U` for the unsized `Self`.
    let unsized_self_ty =
        TyKind::Scalar(chalk_ir::Scalar::Uint(chalk_ir::UintTy::U32)).intern(Interner);
    // The receiver type with `Self` replaced by `U`.
    let Some(unsized_receiver_ty) = receiver_for_self_ty(db, func, unsized_self_ty.clone()) else {
        return false;
    };
    let self_ty = placeholder_subst.as_slice(Interner)[trait_self_idx].assert_ty_ref(Interner);
    // Assumption: `Self: Unsize<U>`.
    let unsized_predicate = WhereClause::Implemented(TraitRef {
        trait_id: to_chalk_trait_id(unsize_did),
        substitution: Substitution::from_iter(Interner, [self_ty.clone(), unsized_self_ty.clone()]),
    });
    // Assumption: `U: Trait<..>`, reusing the placeholders for the trait's other parameters.
    let trait_predicate = WhereClause::Implemented(TraitRef {
        trait_id: to_chalk_trait_id(trait_),
        substitution: Substitution::from_iter(
            Interner,
            std::iter::once(unsized_self_ty.clone().cast(Interner))
                .chain(placeholder_subst.iter(Interner).skip(1).cloned()),
        ),
    });
    let generic_predicates = &*db.generic_predicates(func.into());
    // All assumptions, plus the method's own predicates, become from-env clauses.
    let clauses = std::iter::once(unsized_predicate)
        .chain(std::iter::once(trait_predicate))
        .chain(generic_predicates.iter().map(|pred| {
            pred.clone().substitute(Interner, &placeholder_subst).into_value_and_skipped_binders().0
        }))
        .map(|pred| {
            pred.cast::<chalk_ir::ProgramClause<Interner>>(Interner).into_from_env_clause(Interner)
        });
    let env = chalk_ir::Environment::new(Interner).add_clauses(Interner, clauses);
    // Goal: `Receiver: DispatchFromDyn<Receiver[Self := U]>`.
    let obligation = WhereClause::Implemented(TraitRef {
        trait_id: to_chalk_trait_id(dispatch_from_dyn_did),
        substitution: Substitution::from_iter(Interner, [receiver_ty.clone(), unsized_receiver_ty]),
    });
    let goal = GoalData::DomainGoal(chalk_ir::DomainGoal::Holds(obligation)).intern(Interner);
    let in_env = chalk_ir::InEnvironment::new(&env, goal);
    let mut table = chalk_solve::infer::InferenceTable::<Interner>::new();
    let canonicalized = table.canonicalize(Interner, in_env);
    let solution = db.trait_solve(krate, None, canonicalized.quantified);
    matches!(solution, Some(Solution::Unique(_)))
}
fn receiver_for_self_ty(db: &dyn HirDatabase, func: FunctionId, ty: Ty) -> Option<Ty> {
let generics = generics(db.upcast(), func.into());
let trait_self_idx = trait_self_param_idx(db.upcast(), func.into())?;
let subst = generics.placeholder_subst(db);
let subst = Substitution::from_iter(
Interner,
subst.iter(Interner).enumerate().map(|(idx, arg)| {
if idx == trait_self_idx {
ty.clone().cast(Interner)
} else {
arg.clone()
}
}),
);
let sig = callable_item_sig(db, func.into());
let sig = sig.substitute(Interner, &subst);
sig.params_and_return.first().cloned()
}
/// Returns [`MethodViolationCode::ReferencesImplTraitInTrait`] if the return type of `sig`
/// contains an opaque type that interns to a return-position `impl Trait`, otherwise `None`.
fn contains_illegal_impl_trait_in_trait(
    db: &dyn HirDatabase,
    sig: &Binders<CallableSig>,
) -> Option<MethodViolationCode> {
    /// Visitor that records every opaque type id it encounters and never breaks.
    struct OpaqueTypeCollector {
        seen: FxHashSet<OpaqueTyId>,
    }

    impl TypeVisitor<Interner> for OpaqueTypeCollector {
        type BreakTy = ();

        fn as_dyn(&mut self) -> &mut dyn TypeVisitor<Interner, BreakTy = Self::BreakTy> {
            self
        }

        fn interner(&self) -> Interner {
            Interner
        }

        fn visit_ty(&mut self, ty: &Ty, outer_binder: DebruijnIndex) -> ControlFlow<Self::BreakTy> {
            if let TyKind::OpaqueType(id, _) = ty.kind(Interner) {
                self.seen.insert(*id);
            }
            ty.super_visit_with(self.as_dyn(), outer_binder)
        }
    }

    let mut collector = OpaqueTypeCollector { seen: FxHashSet::default() };
    // The collector never breaks, so the traversal result carries no information.
    let _ = sig.skip_binders().ret().visit_with(collector.as_dyn(), DebruijnIndex::INNERMOST);

    let has_return_position_impl_trait = collector.seen.into_iter().any(|id| {
        matches!(db.lookup_intern_impl_trait_id(id.into()), ImplTraitId::ReturnTypeImplTrait(..))
    });
    has_return_position_impl_trait.then_some(MethodViolationCode::ReferencesImplTraitInTrait)
}
#[cfg(test)]
mod tests;