use std::cell::LazyCell;
use std::fmt;
use hir_def::{DefWithBodyId, EnumId, EnumVariantId, HasModule, LocalFieldId, ModuleId, VariantId};
use intern::sym;
use rustc_pattern_analysis::{
constructor::{Constructor, ConstructorSet, VariantVisibility},
usefulness::{compute_match_usefulness, PlaceValidity, UsefulnessReport},
Captures, IndexVec, PatCx, PrivateUninhabitedField,
};
use smallvec::{smallvec, SmallVec};
use stdx::never;
use crate::{
db::HirDatabase,
infer::normalize,
inhabitedness::{is_enum_variant_uninhabited_from, is_ty_uninhabited_from},
AdtId, Interner, Scalar, Ty, TyExt, TyKind,
};
use super::{is_box, FieldPat, Pat, PatKind};
use Constructor::*;
/// A pattern in the deconstructed form consumed by `rustc_pattern_analysis`, specialized to
/// [`MatchCheckCtx`]. Produced by `MatchCheckCtx::lower_pat`.
pub(crate) type DeconstructedPat<'db> =
    rustc_pattern_analysis::pat::DeconstructedPat<MatchCheckCtx<'db>>;
/// A single `match` arm as handed to `rustc_pattern_analysis`.
pub(crate) type MatchArm<'db> =
    rustc_pattern_analysis::MatchArm<'db, MatchCheckCtx<'db>>;
/// A pattern produced by the analysis itself; converted back with
/// `MatchCheckCtx::hoist_witness_pat`.
pub(crate) type WitnessPat<'db> =
    rustc_pattern_analysis::pat::WitnessPat<MatchCheckCtx<'db>>;
/// An uninhabited type: no value of `Void` can ever be constructed.
///
/// It is used as the string-literal type of the match-checking context, so string-literal
/// constructors can be eliminated with an empty `match`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum Void {}
/// Position of a variant inside its enum's variant list.
///
/// This is the variant index type handed to `rustc_pattern_analysis`; it converts to and from
/// HIR `EnumVariantId`s via `from_enum_variant_id` / `to_enum_variant_id`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(crate) struct EnumVariantContiguousIndex(usize);
impl EnumVariantContiguousIndex {
    /// Builds the index of `target_evid` from the `index` recorded in its location.
    ///
    /// NOTE(review): this is expected to equal the variant's position in
    /// `db.enum_data(..).variants`, which `to_enum_variant_id` indexes into.
    fn from_enum_variant_id(db: &dyn HirDatabase, target_evid: EnumVariantId) -> Self {
        use hir_def::Lookup;
        let variant_loc = target_evid.lookup(db.upcast());
        Self(variant_loc.index as usize)
    }

    /// Returns the variant of enum `eid` found at this position of its variant list.
    ///
    /// Panics if the index is out of range for `eid`'s variants.
    fn to_enum_variant_id(self, db: &dyn HirDatabase, eid: EnumId) -> EnumVariantId {
        let enum_data = db.enum_data(eid);
        let (variant_id, _) = &enum_data.variants[self.0];
        *variant_id
    }
}
impl rustc_pattern_analysis::Idx for EnumVariantContiguousIndex {
    /// Wraps a raw index handed out by `rustc_pattern_analysis`.
    fn new(idx: usize) -> Self {
        Self(idx)
    }

    /// Unwraps the raw position stored in this index.
    fn index(self) -> usize {
        let Self(raw) = self;
        raw
    }
}
/// Context used to run `rustc_pattern_analysis` over rust-analyzer's HIR types.
#[derive(Clone)]
pub(crate) struct MatchCheckCtx<'db> {
    /// Database handle used for all HIR and type queries.
    pub(crate) db: &'db dyn HirDatabase,
    /// Module the checked body belongs to; visibility and inhabitedness are judged from here.
    module: ModuleId,
    /// Body containing the match; its trait environment is used to normalize field types.
    body: DefWithBodyId,
    /// Whether the crate enables the unstable `exhaustive_patterns` feature (read once in `new`).
    exhaustive_patterns: bool,
}
impl<'db> MatchCheckCtx<'db> {
    /// Creates a context for checking the matches in `body`, which is defined in `module`.
    ///
    /// Whether the unstable `exhaustive_patterns` feature is enabled is read once, from the
    /// def map of `module`'s crate.
    pub(crate) fn new(module: ModuleId, body: DefWithBodyId, db: &'db dyn HirDatabase) -> Self {
        let def_map = db.crate_def_map(module.krate());
        let exhaustive_patterns = def_map.is_unstable_feature_enabled(&sym::exhaustive_patterns);
        Self { module, body, db, exhaustive_patterns }
    }

    /// Runs `rustc_pattern_analysis` usefulness checking over `arms` for a scrutinee of type
    /// `scrut_ty`.
    ///
    /// `known_valid_scrutinee` selects the place validity passed to the analysis; `None` is
    /// treated as `Some(true)`.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` without running the analysis if `scrut_ty` or the type of any arm
    /// pattern contains an unknown type. It also propagates `Err(())` from the analysis itself,
    /// e.g. when `complexity_exceeded` fires.
    pub(crate) fn compute_match_usefulness(
        &self,
        arms: &[MatchArm<'db>],
        scrut_ty: Ty,
        known_valid_scrutinee: Option<bool>,
    ) -> Result<UsefulnessReport<'db, Self>, ()> {
        // Unknown types cannot be analysed meaningfully; bail out instead of reporting.
        if scrut_ty.contains_unknown() || arms.iter().any(|arm| arm.pat.ty().contains_unknown()) {
            return Err(());
        }

        let place_validity = PlaceValidity::from_bool(known_valid_scrutinee.unwrap_or(true));
        // Upper bound on the work the analysis may do before giving up.
        let complexity_limit = Some(500000);
        compute_match_usefulness(self, arms, scrut_ty, place_validity, complexity_limit)
    }

    /// Whether `ty` is uninhabited as seen from `self.module`.
    fn is_uninhabited(&self, ty: &Ty) -> bool {
        is_ty_uninhabited_from(self.db, ty, self.module)
    }

    /// Whether `adt` comes from a crate other than the one being checked and carries the
    /// `#[non_exhaustive]` attribute.
    fn is_foreign_non_exhaustive(&self, adt: hir_def::AdtId) -> bool {
        let is_local = adt.krate(self.db.upcast()) == self.module.krate();
        !is_local && self.db.attrs(adt.into()).by_key(&sym::non_exhaustive).exists()
    }

    /// Maps a constructor applied to `adt` to the HIR variant it denotes.
    ///
    /// Returns `None` for a `Struct`/`UnionField` constructor applied to an enum.
    ///
    /// # Panics
    ///
    /// Panics if `ctor` is `Variant` but `adt` is not an enum, or if `ctor` is not one of
    /// `Variant`, `Struct` or `UnionField`.
    fn variant_id_for_adt(
        db: &'db dyn HirDatabase,
        ctor: &Constructor<Self>,
        adt: hir_def::AdtId,
    ) -> Option<VariantId> {
        match ctor {
            Variant(id) => {
                let hir_def::AdtId::EnumId(eid) = adt else {
                    panic!("bad constructor {ctor:?} for adt {adt:?}")
                };
                Some(id.to_enum_variant_id(db, eid).into())
            }
            Struct | UnionField => match adt {
                hir_def::AdtId::EnumId(_) => None,
                hir_def::AdtId::StructId(id) => Some(id.into()),
                hir_def::AdtId::UnionId(id) => Some(id.into()),
            },
            _ => panic!("bad constructor {ctor:?} for adt {adt:?}"),
        }
    }

    /// Yields each field of `variant` with its type. `ty`'s generic arguments are substituted
    /// and the result is normalized in the trait environment of `self.body`.
    ///
    /// Panics if `ty` is not an ADT type.
    fn list_variant_fields<'a>(
        &'a self,
        ty: &'a Ty,
        variant: VariantId,
    ) -> impl Iterator<Item = (LocalFieldId, Ty)> + Captures<'a> + Captures<'db> {
        let (_, substs) = ty.as_adt().unwrap();

        let field_tys = self.db.field_types(variant);
        let fields_len = variant.variant_data(self.db.upcast()).fields().len() as u32;
        // The trait environment is the same for every field, so fetch it once up front.
        let trait_env = self.db.trait_environment_for_body(self.body);

        (0..fields_len).map(|idx| LocalFieldId::from_raw(idx.into())).map(move |fid| {
            let ty = field_tys[fid].clone().substitute(Interner, substs);
            let ty = normalize(self.db, trait_env.clone(), ty);
            (fid, ty)
        })
    }

    /// Lowers a match-check `Pat` into a `DeconstructedPat` for `rustc_pattern_analysis`.
    ///
    /// The cases are:
    /// - A binding with a subpattern is lowered as that subpattern.
    /// - A binding without a subpattern, or `_`, becomes a wildcard.
    /// - A pattern whose type does not fit its kind triggers `never!` and is lowered as a
    ///   field-less wildcard. Its arity is 0, matching `ctor_arity` for `Wildcard`.
    pub(crate) fn lower_pat(&self, pat: &Pat) -> DeconstructedPat<'db> {
        // Wraps one lowered subpattern as field 0.
        let singleton = |pat: DeconstructedPat<'db>| vec![pat.at_index(0)];
        let ctor;
        let mut fields: Vec<_>;
        let arity;
        match pat.kind.as_ref() {
            PatKind::Binding { subpattern: Some(subpat), .. } => return self.lower_pat(subpat),
            PatKind::Binding { subpattern: None, .. } | PatKind::Wild => {
                ctor = Wildcard;
                fields = Vec::new();
                arity = 0;
            }
            PatKind::Deref { subpattern } => match pat.ty.kind(Interner) {
                // A box pattern: `Box` is modelled as a struct with a single field.
                TyKind::Adt(adt, _) if is_box(self.db, adt.0) => {
                    ctor = Struct;
                    fields = singleton(self.lower_pat(subpattern));
                    arity = 1;
                }
                TyKind::Ref(..) => {
                    ctor = Ref;
                    fields = singleton(self.lower_pat(subpattern));
                    arity = 1;
                }
                _ => {
                    never!("pattern has unexpected type: pat: {:?}, ty: {:?}", pat, &pat.ty);
                    // A wildcard has no fields; previously the subpattern was still attached
                    // with arity 1, contradicting the wildcard's arity of 0.
                    ctor = Wildcard;
                    fields = Vec::new();
                    arity = 0;
                }
            },
            PatKind::Leaf { subpatterns } | PatKind::Variant { subpatterns, .. } => {
                fields = subpatterns
                    .iter()
                    .map(|pat| {
                        let idx: u32 = pat.field.into_raw().into();
                        self.lower_pat(&pat.pattern).at_index(idx as usize)
                    })
                    .collect();
                match pat.ty.kind(Interner) {
                    TyKind::Tuple(_, substs) => {
                        ctor = Struct;
                        arity = substs.len(Interner);
                    }
                    TyKind::Adt(adt, _) if is_box(self.db, adt.0) => {
                        // `Box` has arity 1: keep only the subpattern for field 0.
                        fields.retain(|ipat| ipat.idx == 0);
                        ctor = Struct;
                        arity = 1;
                    }
                    &TyKind::Adt(AdtId(adt), _) => {
                        ctor = match pat.kind.as_ref() {
                            PatKind::Variant { enum_variant, .. } => {
                                Variant(EnumVariantContiguousIndex::from_enum_variant_id(
                                    self.db,
                                    *enum_variant,
                                ))
                            }
                            // The enclosing arm only admits `Leaf` and `Variant`, so this is a
                            // struct-like `Leaf` pattern.
                            _ if matches!(adt, hir_def::AdtId::UnionId(_)) => UnionField,
                            _ => Struct,
                        };
                        let variant = Self::variant_id_for_adt(self.db, &ctor, adt).unwrap();
                        arity = variant.variant_data(self.db.upcast()).fields().len();
                    }
                    _ => {
                        never!("pattern has unexpected type: pat: {:?}, ty: {:?}", pat, &pat.ty);
                        ctor = Wildcard;
                        fields.clear();
                        arity = 0;
                    }
                }
            }
            &PatKind::LiteralBool { value } => {
                ctor = Bool(value);
                fields = Vec::new();
                arity = 0;
            }
            PatKind::Never => {
                ctor = Never;
                fields = Vec::new();
                arity = 0;
            }
            PatKind::Or { pats } => {
                ctor = Or;
                fields = pats
                    .iter()
                    .enumerate()
                    .map(|(i, pat)| self.lower_pat(pat).at_index(i))
                    .collect();
                arity = pats.len();
            }
        }
        DeconstructedPat::new(ctor, fields, arity, pat.ty.clone(), ())
    }

    /// Converts a witness pattern produced by the analysis back into a `Pat`.
    ///
    /// # Panics
    ///
    /// Panics (`unimplemented!`) on `IntRange` and `Slice` constructors. Constructors that
    /// cannot be expressed as a pattern trigger `never!` and become `_`.
    pub(crate) fn hoist_witness_pat(&self, pat: &WitnessPat<'db>) -> Pat {
        let mut subpatterns = pat.iter_fields().map(|p| self.hoist_witness_pat(p));
        let kind = match pat.ctor() {
            &Bool(value) => PatKind::LiteralBool { value },
            IntRange(_) => unimplemented!(),
            Struct | Variant(_) | UnionField => match pat.ty().kind(Interner) {
                TyKind::Tuple(..) => PatKind::Leaf {
                    subpatterns: subpatterns
                        .zip(0u32..)
                        .map(|(p, i)| FieldPat {
                            field: LocalFieldId::from_raw(i.into()),
                            pattern: p,
                        })
                        .collect(),
                },
                TyKind::Adt(adt, _) if is_box(self.db, adt.0) => {
                    // `Box` has exactly one field; render it as a deref (box) pattern.
                    PatKind::Deref { subpattern: subpatterns.next().unwrap() }
                }
                TyKind::Adt(adt, substs) => {
                    let variant = Self::variant_id_for_adt(self.db, pat.ctor(), adt.0).unwrap();
                    let subpatterns = self
                        .list_variant_fields(pat.ty(), variant)
                        .zip(subpatterns)
                        .map(|((field, _ty), pattern)| FieldPat { field, pattern })
                        .collect();
                    if let VariantId::EnumVariantId(enum_variant) = variant {
                        PatKind::Variant { substs: substs.clone(), enum_variant, subpatterns }
                    } else {
                        PatKind::Leaf { subpatterns }
                    }
                }
                _ => {
                    never!("unexpected ctor for type {:?} {:?}", pat.ctor(), pat.ty());
                    PatKind::Wild
                }
            },
            Ref => PatKind::Deref { subpattern: subpatterns.next().unwrap() },
            Slice(_) => unimplemented!(),
            &Str(void) => match void {},
            Wildcard | NonExhaustive | Hidden | PrivateUninhabited => PatKind::Wild,
            Never => PatKind::Never,
            Missing | F16Range(..) | F32Range(..) | F64Range(..) | F128Range(..) | Opaque(..)
            | Or => {
                never!("can't convert to pattern: {:?}", pat.ctor());
                PatKind::Wild
            }
        };
        Pat { ty: pat.ty().clone(), kind: Box::new(kind) }
    }
}
impl<'db> PatCx for MatchCheckCtx<'db> {
    type Error = ();
    type Ty = Ty;
    type VariantIdx = EnumVariantContiguousIndex;
    // `lower_pat` never produces string-literal constructors, so this type is uninhabited.
    type StrLit = Void;
    type ArmData = ();
    type PatData = ();

    fn is_exhaustive_patterns_feature_on(&self) -> bool {
        self.exhaustive_patterns
    }

    /// Returns the number of fields of `ctor` applied to a value of type `ty`.
    ///
    /// This must equal the number of entries `ctor_sub_tys` yields for the same `(ctor, ty)`.
    fn ctor_arity(
        &self,
        ctor: &rustc_pattern_analysis::constructor::Constructor<Self>,
        ty: &Self::Ty,
    ) -> usize {
        match ctor {
            Struct | Variant(_) | UnionField => match *ty.kind(Interner) {
                TyKind::Tuple(arity, ..) => arity,
                // `ctor_sub_tys` yields the referent as the single field here; agree with it.
                TyKind::Ref(..) => 1,
                TyKind::Adt(AdtId(adt), ..) => {
                    if is_box(self.db, adt) {
                        // `Box` is modelled as having exactly one field.
                        1
                    } else {
                        let variant = Self::variant_id_for_adt(self.db, ctor, adt).unwrap();
                        variant.variant_data(self.db.upcast()).fields().len()
                    }
                }
                _ => {
                    never!("Unexpected type for `Single` constructor: {:?}", ty);
                    0
                }
            },
            Ref => 1,
            Slice(..) => unimplemented!(),
            Never | Bool(..) | IntRange(..) | F16Range(..) | F32Range(..) | F64Range(..)
            | F128Range(..) | Str(..) | Opaque(..) | NonExhaustive | PrivateUninhabited
            | Hidden | Missing | Wildcard => 0,
            Or => {
                never!("The `Or` constructor doesn't have a fixed arity");
                0
            }
        }
    }

    /// Returns the field types of `ctor` applied to `ty`. Each type is paired with a flag
    /// telling whether the field is uninhabited and also not visible from `self.module`.
    ///
    /// Yields exactly `ctor_arity(ctor, ty)` entries, including in the `never!` fallbacks.
    fn ctor_sub_tys<'a>(
        &'a self,
        ctor: &'a rustc_pattern_analysis::constructor::Constructor<Self>,
        ty: &'a Self::Ty,
    ) -> impl ExactSizeIterator<Item = (Self::Ty, PrivateUninhabitedField)> + Captures<'a> {
        let single = |ty| smallvec![(ty, PrivateUninhabitedField(false))];
        let tys: SmallVec<[_; 2]> = match ctor {
            Struct | Variant(_) | UnionField => match ty.kind(Interner) {
                TyKind::Tuple(_, substs) => {
                    let tys = substs.iter(Interner).map(|ty| ty.assert_ty_ref(Interner));
                    tys.cloned().map(|ty| (ty, PrivateUninhabitedField(false))).collect()
                }
                TyKind::Ref(.., rty) => single(rty.clone()),
                &TyKind::Adt(AdtId(adt), ref substs) => {
                    if is_box(self.db, adt) {
                        // The single field of `Box` has the type of its first generic argument.
                        let subst_ty = substs.at(Interner, 0).assert_ty_ref(Interner).clone();
                        single(subst_ty)
                    } else {
                        let variant = Self::variant_id_for_adt(self.db, ctor, adt).unwrap();
                        // Field visibilities are only consulted for uninhabited fields, so
                        // fetch them lazily.
                        let visibilities = LazyCell::new(|| self.db.field_visibilities(variant));
                        self.list_variant_fields(ty, variant)
                            .map(move |(fid, ty)| {
                                // Fields of enum variants always count as visible.
                                let is_visible = || {
                                    matches!(adt, hir_def::AdtId::EnumId(..))
                                        || visibilities[fid]
                                            .is_visible_from(self.db.upcast(), self.module)
                                };
                                let is_uninhabited = self.is_uninhabited(&ty);
                                let private_uninhabited = is_uninhabited && !is_visible();
                                (ty, PrivateUninhabitedField(private_uninhabited))
                            })
                            .collect()
                    }
                }
                ty_kind => {
                    never!("Unexpected type for `{:?}` constructor: {:?}", ctor, ty_kind);
                    // `ctor_arity` reports 0 fields for this case; yield none to agree.
                    smallvec![]
                }
            },
            Ref => match ty.kind(Interner) {
                TyKind::Ref(.., rty) => single(rty.clone()),
                ty_kind => {
                    never!("Unexpected type for `{:?}` constructor: {:?}", ctor, ty_kind);
                    // `ctor_arity` reports 1 field for `Ref` regardless of `ty`.
                    single(ty.clone())
                }
            },
            Slice(_) => unreachable!("Found a `Slice` constructor in match checking"),
            Never | Bool(..) | IntRange(..) | F16Range(..) | F32Range(..) | F64Range(..)
            | F128Range(..) | Str(..) | Opaque(..) | NonExhaustive | PrivateUninhabited
            | Hidden | Missing | Wildcard => {
                smallvec![]
            }
            Or => {
                never!("called `Fields::wildcards` on an `Or` ctor");
                smallvec![]
            }
        };
        tys.into_iter()
    }

    /// Lists the constructors a value of type `ty` can have.
    ///
    /// The following types are enumerated: `bool`, enums, unions, structs/tuples, references
    /// and `!`. Everything else, including `char`, integers, arrays and slices, is reported as
    /// `Unlistable`.
    ///
    /// For enums, each variant is marked `Empty` if it is uninhabited from `self.module`. An
    /// enum from another crate that carries `#[non_exhaustive]` is marked non-exhaustive.
    fn ctors_for_ty(
        &self,
        ty: &Self::Ty,
    ) -> Result<rustc_pattern_analysis::constructor::ConstructorSet<Self>, Self::Error> {
        let cx = self;
        let unhandled = || ConstructorSet::Unlistable;
        Ok(match ty.kind(Interner) {
            TyKind::Scalar(Scalar::Bool) => ConstructorSet::Bool,
            TyKind::Scalar(Scalar::Char) => unhandled(),
            TyKind::Scalar(Scalar::Int(..) | Scalar::Uint(..)) => unhandled(),
            TyKind::Array(..) | TyKind::Slice(..) => unhandled(),
            &TyKind::Adt(AdtId(adt @ hir_def::AdtId::EnumId(enum_id)), ref subst) => {
                let enum_data = cx.db.enum_data(enum_id);
                let is_declared_nonexhaustive = cx.is_foreign_non_exhaustive(adt);

                if enum_data.variants.is_empty() && !is_declared_nonexhaustive {
                    ConstructorSet::NoConstructors
                } else {
                    let mut variants = IndexVec::with_capacity(enum_data.variants.len());
                    for &(variant, _) in enum_data.variants.iter() {
                        let is_uninhabited =
                            is_enum_variant_uninhabited_from(cx.db, variant, subst, cx.module);
                        let visibility = if is_uninhabited {
                            VariantVisibility::Empty
                        } else {
                            VariantVisibility::Visible
                        };
                        variants.push(visibility);
                    }

                    ConstructorSet::Variants { variants, non_exhaustive: is_declared_nonexhaustive }
                }
            }
            TyKind::Adt(AdtId(hir_def::AdtId::UnionId(_)), _) => ConstructorSet::Union,
            TyKind::Adt(..) | TyKind::Tuple(..) => {
                ConstructorSet::Struct { empty: cx.is_uninhabited(ty) }
            }
            TyKind::Ref(..) => ConstructorSet::Ref,
            TyKind::Never => ConstructorSet::NoConstructors,
            _ => ConstructorSet::Unlistable,
        })
    }

    /// Variant names are not rendered here; a fixed placeholder is written instead.
    fn write_variant_name(
        f: &mut fmt::Formatter<'_>,
        _ctor: &Constructor<Self>,
        _ty: &Self::Ty,
    ) -> fmt::Result {
        write!(f, "<write_variant_name unsupported>")
    }

    /// Forwards internal inconsistencies reported by the analysis to `never!`.
    fn bug(&self, fmt: fmt::Arguments<'_>) {
        never!("{}", fmt)
    }

    /// Called by the analysis when its complexity limit is exceeded; aborts with `Err(())`.
    fn complexity_exceeded(&self) -> Result<(), Self::Error> {
        Err(())
    }
}
impl fmt::Debug for MatchCheckCtx<'_> {
    /// Prints only the type name; none of the context's fields (e.g. the database handle) are
    /// shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same output as a field-less `debug_struct(..).finish()`.
        f.write_str("MatchCheckCtx")
    }
}