hir_ty/diagnostics/
match_check.rs

1//! Validation of matches.
2//!
//! This module lowers [`hir_def::hir::Pat`] into [`self::Pat`] and provides the match
//! checking algorithm.
5//!
6//! It is modeled on the rustc module `rustc_mir_build::thir::pattern`.
7
8mod pat_util;
9
10pub(crate) mod pat_analysis;
11
12use hir_def::{
13    AdtId, EnumVariantId, LocalFieldId, Lookup, VariantId,
14    expr_store::{Body, path::Path},
15    hir::PatId,
16    item_tree::FieldsShape,
17};
18use hir_expand::name::Name;
19use rustc_type_ir::inherent::{IntoKind, SliceLike};
20use span::Edition;
21use stdx::{always, never, variance::PhantomCovariantLifetime};
22
23use crate::{
24    InferenceResult,
25    db::HirDatabase,
26    display::{HirDisplay, HirDisplayError, HirFormatter},
27    infer::BindingMode,
28    next_solver::{GenericArgs, Mutability, Ty, TyKind},
29};
30
31use self::pat_util::EnumerateAndAdjustIterator;
32
/// Problems recorded by [`PatCtxt`] while lowering a HIR pattern.
///
/// Whenever one of these is recorded, the offending pattern is lowered to a
/// wildcard (or, for extra tuple fields, to an empty subpattern list) so that
/// lowering can continue.
#[derive(Clone, Debug)]
pub(crate) enum PatternError {
    /// The pattern form (or literal kind) is not supported by the lowering yet.
    Unimplemented,
    /// The inferred type does not fit the pattern's shape.
    UnexpectedType,
    /// Inference did not resolve the pattern's path to a struct/union/enum variant.
    UnresolvedVariant,
    /// A record pattern names a field that the resolved variant does not have.
    MissingField,
    /// A positional pattern has more subpatterns than the variant/tuple has fields.
    ExtraFields,
}
41
42#[derive(Clone, Debug, PartialEq)]
43pub(crate) struct FieldPat<'db> {
44    pub(crate) field: LocalFieldId,
45    pub(crate) pattern: Pat<'db>,
46}
47
48#[derive(Clone, Debug, PartialEq)]
49pub(crate) struct Pat<'db> {
50    pub(crate) ty: Ty<'db>,
51    pub(crate) kind: Box<PatKind<'db>>,
52}
53
54/// Close relative to `rustc_mir_build::thir::pattern::PatKind`
55#[derive(Clone, Debug, PartialEq)]
56pub(crate) enum PatKind<'db> {
57    Wild,
58    Never,
59
60    /// `x`, `ref x`, `x @ P`, etc.
61    Binding {
62        name: Name,
63        subpattern: Option<Pat<'db>>,
64    },
65
66    /// `Foo(...)` or `Foo{...}` or `Foo`, where `Foo` is a variant name from an ADT with
67    /// multiple variants.
68    Variant {
69        substs: GenericArgs<'db>,
70        enum_variant: EnumVariantId,
71        subpatterns: Vec<FieldPat<'db>>,
72    },
73
74    /// `(...)`, `Foo(...)`, `Foo{...}`, or `Foo`, where `Foo` is a variant name from an ADT with
75    /// a single variant.
76    Leaf {
77        subpatterns: Vec<FieldPat<'db>>,
78    },
79
80    /// `&P`, `&mut P`, etc.
81    Deref {
82        subpattern: Pat<'db>,
83    },
84
85    // FIXME: for now, only bool literals are implemented
86    LiteralBool {
87        value: bool,
88    },
89
90    /// An or-pattern, e.g. `p | q`.
91    /// Invariant: `pats.len() >= 2`.
92    Or {
93        pats: Vec<Pat<'db>>,
94    },
95}
96
97pub(crate) struct PatCtxt<'a, 'db> {
98    db: &'db dyn HirDatabase,
99    infer: &'a InferenceResult<'db>,
100    body: &'a Body,
101    pub(crate) errors: Vec<PatternError>,
102}
103
104impl<'a, 'db> PatCtxt<'a, 'db> {
105    pub(crate) fn new(
106        db: &'db dyn HirDatabase,
107        infer: &'a InferenceResult<'db>,
108        body: &'a Body,
109    ) -> Self {
110        Self { db, infer, body, errors: Vec::new() }
111    }
112
113    pub(crate) fn lower_pattern(&mut self, pat: PatId) -> Pat<'db> {
114        // XXX(iDawer): Collecting pattern adjustments feels imprecise to me.
115        // When lowering of & and box patterns are implemented this should be tested
116        // in a manner of `match_ergonomics_issue_9095` test.
117        // Pattern adjustment is part of RFC 2005-match-ergonomics.
118        // More info https://github.com/rust-lang/rust/issues/42640#issuecomment-313535089
119        let unadjusted_pat = self.lower_pattern_unadjusted(pat);
120        self.infer.pat_adjustments.get(&pat).map(|it| &**it).unwrap_or_default().iter().rev().fold(
121            unadjusted_pat,
122            |subpattern, ref_ty| Pat { ty: *ref_ty, kind: Box::new(PatKind::Deref { subpattern }) },
123        )
124    }
125
126    fn lower_pattern_unadjusted(&mut self, pat: PatId) -> Pat<'db> {
127        let mut ty = self.infer[pat];
128        let variant = self.infer.variant_resolution_for_pat(pat);
129
130        let kind = match self.body[pat] {
131            hir_def::hir::Pat::Wild => PatKind::Wild,
132
133            hir_def::hir::Pat::Lit(expr) => self.lower_lit(expr),
134
135            hir_def::hir::Pat::Path(ref path) => {
136                return self.lower_path(pat, path);
137            }
138
139            hir_def::hir::Pat::Tuple { ref args, ellipsis } => {
140                let arity = match ty.kind() {
141                    TyKind::Tuple(tys) => tys.len(),
142                    _ => {
143                        never!("unexpected type for tuple pattern: {:?}", ty);
144                        self.errors.push(PatternError::UnexpectedType);
145                        return Pat { ty, kind: PatKind::Wild.into() };
146                    }
147                };
148                let subpatterns = self.lower_tuple_subpats(args, arity, ellipsis);
149                PatKind::Leaf { subpatterns }
150            }
151
152            hir_def::hir::Pat::Bind { id, subpat, .. } => {
153                let bm = self.infer.binding_modes[pat];
154                ty = self.infer[id];
155                let name = &self.body[id].name;
156                match (bm, ty.kind()) {
157                    (BindingMode::Ref(_), TyKind::Ref(_, rty, _)) => ty = rty,
158                    (BindingMode::Ref(_), _) => {
159                        never!(
160                            "`ref {}` has wrong type {:?}",
161                            name.display(self.db, Edition::LATEST),
162                            ty
163                        );
164                        self.errors.push(PatternError::UnexpectedType);
165                        return Pat { ty, kind: PatKind::Wild.into() };
166                    }
167                    _ => (),
168                }
169                PatKind::Binding { name: name.clone(), subpattern: self.lower_opt_pattern(subpat) }
170            }
171
172            hir_def::hir::Pat::TupleStruct { ref args, ellipsis, .. } if variant.is_some() => {
173                let expected_len = variant.unwrap().fields(self.db).fields().len();
174                let subpatterns = self.lower_tuple_subpats(args, expected_len, ellipsis);
175                self.lower_variant_or_leaf(pat, ty, subpatterns)
176            }
177
178            hir_def::hir::Pat::Record { ref args, .. } if variant.is_some() => {
179                let variant_data = variant.unwrap().fields(self.db);
180                let subpatterns = args
181                    .iter()
182                    .map(|field| {
183                        // XXX(iDawer): field lookup is inefficient
184                        variant_data.field(&field.name).map(|lfield_id| FieldPat {
185                            field: lfield_id,
186                            pattern: self.lower_pattern(field.pat),
187                        })
188                    })
189                    .collect();
190                match subpatterns {
191                    Some(subpatterns) => self.lower_variant_or_leaf(pat, ty, subpatterns),
192                    None => {
193                        self.errors.push(PatternError::MissingField);
194                        PatKind::Wild
195                    }
196                }
197            }
198            hir_def::hir::Pat::TupleStruct { .. } | hir_def::hir::Pat::Record { .. } => {
199                self.errors.push(PatternError::UnresolvedVariant);
200                PatKind::Wild
201            }
202
203            hir_def::hir::Pat::Or(ref pats) => PatKind::Or { pats: self.lower_patterns(pats) },
204
205            _ => {
206                self.errors.push(PatternError::Unimplemented);
207                PatKind::Wild
208            }
209        };
210
211        Pat { ty, kind: Box::new(kind) }
212    }
213
214    fn lower_tuple_subpats(
215        &mut self,
216        pats: &[PatId],
217        expected_len: usize,
218        ellipsis: Option<u32>,
219    ) -> Vec<FieldPat<'db>> {
220        if pats.len() > expected_len {
221            self.errors.push(PatternError::ExtraFields);
222            return Vec::new();
223        }
224
225        pats.iter()
226            .enumerate_and_adjust(expected_len, ellipsis.map(|it| it as usize))
227            .map(|(i, &subpattern)| FieldPat {
228                field: LocalFieldId::from_raw((i as u32).into()),
229                pattern: self.lower_pattern(subpattern),
230            })
231            .collect()
232    }
233
234    fn lower_patterns(&mut self, pats: &[PatId]) -> Vec<Pat<'db>> {
235        pats.iter().map(|&p| self.lower_pattern(p)).collect()
236    }
237
238    fn lower_opt_pattern(&mut self, pat: Option<PatId>) -> Option<Pat<'db>> {
239        pat.map(|p| self.lower_pattern(p))
240    }
241
242    fn lower_variant_or_leaf(
243        &mut self,
244        pat: PatId,
245        ty: Ty<'db>,
246        subpatterns: Vec<FieldPat<'db>>,
247    ) -> PatKind<'db> {
248        match self.infer.variant_resolution_for_pat(pat) {
249            Some(variant_id) => {
250                if let VariantId::EnumVariantId(enum_variant) = variant_id {
251                    let substs = match ty.kind() {
252                        TyKind::Adt(_, substs) => substs,
253                        kind => {
254                            always!(
255                                matches!(kind, TyKind::FnDef(..) | TyKind::Error(_)),
256                                "inappropriate type for def: {:?}",
257                                ty
258                            );
259                            self.errors.push(PatternError::UnexpectedType);
260                            return PatKind::Wild;
261                        }
262                    };
263                    PatKind::Variant { substs, enum_variant, subpatterns }
264                } else {
265                    PatKind::Leaf { subpatterns }
266                }
267            }
268            None => {
269                self.errors.push(PatternError::UnresolvedVariant);
270                PatKind::Wild
271            }
272        }
273    }
274
275    fn lower_path(&mut self, pat: PatId, _path: &Path) -> Pat<'db> {
276        let ty = self.infer[pat];
277
278        let pat_from_kind = |kind| Pat { ty, kind: Box::new(kind) };
279
280        match self.infer.variant_resolution_for_pat(pat) {
281            Some(_) => pat_from_kind(self.lower_variant_or_leaf(pat, ty, Vec::new())),
282            None => {
283                self.errors.push(PatternError::UnresolvedVariant);
284                pat_from_kind(PatKind::Wild)
285            }
286        }
287    }
288
289    fn lower_lit(&mut self, expr: hir_def::hir::ExprId) -> PatKind<'db> {
290        use hir_def::hir::{Expr, Literal::Bool};
291
292        match self.body[expr] {
293            Expr::Literal(Bool(value)) => PatKind::LiteralBool { value },
294            _ => {
295                self.errors.push(PatternError::Unimplemented);
296                PatKind::Wild
297            }
298        }
299    }
300}
301
302impl<'db> HirDisplay<'db> for Pat<'db> {
303    fn hir_fmt(&self, f: &mut HirFormatter<'_, 'db>) -> Result<(), HirDisplayError> {
304        match &*self.kind {
305            PatKind::Wild => write!(f, "_"),
306            PatKind::Never => write!(f, "!"),
307            PatKind::Binding { name, subpattern } => {
308                write!(f, "{}", name.display(f.db, f.edition()))?;
309                if let Some(subpattern) = subpattern {
310                    write!(f, " @ ")?;
311                    subpattern.hir_fmt(f)?;
312                }
313                Ok(())
314            }
315            PatKind::Variant { subpatterns, .. } | PatKind::Leaf { subpatterns } => {
316                let variant = match *self.kind {
317                    PatKind::Variant { enum_variant, .. } => Some(VariantId::from(enum_variant)),
318                    _ => self.ty.as_adt().and_then(|(adt, _)| match adt {
319                        AdtId::StructId(s) => Some(s.into()),
320                        AdtId::UnionId(u) => Some(u.into()),
321                        AdtId::EnumId(_) => None,
322                    }),
323                };
324
325                if let Some(variant) = variant {
326                    match variant {
327                        VariantId::EnumVariantId(v) => {
328                            let loc = v.lookup(f.db);
329                            write!(
330                                f,
331                                "{}",
332                                loc.parent.enum_variants(f.db).variants[loc.index as usize]
333                                    .1
334                                    .display(f.db, f.edition())
335                            )?;
336                        }
337                        VariantId::StructId(s) => write!(
338                            f,
339                            "{}",
340                            f.db.struct_signature(s).name.display(f.db, f.edition())
341                        )?,
342                        VariantId::UnionId(u) => write!(
343                            f,
344                            "{}",
345                            f.db.union_signature(u).name.display(f.db, f.edition())
346                        )?,
347                    };
348
349                    let variant_data = variant.fields(f.db);
350                    if variant_data.shape == FieldsShape::Record {
351                        write!(f, " {{ ")?;
352
353                        let mut printed = 0;
354                        let subpats = subpatterns
355                            .iter()
356                            .filter(|p| !matches!(*p.pattern.kind, PatKind::Wild))
357                            .map(|p| {
358                                printed += 1;
359                                WriteWith::new(|f| {
360                                    write!(
361                                        f,
362                                        "{}: ",
363                                        variant_data.fields()[p.field]
364                                            .name
365                                            .display(f.db, f.edition())
366                                    )?;
367                                    p.pattern.hir_fmt(f)
368                                })
369                            });
370                        f.write_joined(subpats, ", ")?;
371
372                        if printed < variant_data.fields().len() {
373                            write!(f, "{}..", if printed > 0 { ", " } else { "" })?;
374                        }
375
376                        return write!(f, " }}");
377                    }
378                }
379
380                let num_fields =
381                    variant.map_or(subpatterns.len(), |v| v.fields(f.db).fields().len());
382                if num_fields != 0 || variant.is_none() {
383                    write!(f, "(")?;
384                    let subpats = (0..num_fields).map(|i| {
385                        WriteWith::new(move |f| {
386                            let fid = LocalFieldId::from_raw((i as u32).into());
387                            if let Some(p) = subpatterns.get(i)
388                                && p.field == fid
389                            {
390                                return p.pattern.hir_fmt(f);
391                            }
392                            if let Some(p) = subpatterns.iter().find(|p| p.field == fid) {
393                                p.pattern.hir_fmt(f)
394                            } else {
395                                write!(f, "_")
396                            }
397                        })
398                    });
399                    f.write_joined(subpats, ", ")?;
400                    if let (TyKind::Tuple(..), 1) = (self.ty.kind(), num_fields) {
401                        write!(f, ",")?;
402                    }
403                    write!(f, ")")?;
404                }
405
406                Ok(())
407            }
408            PatKind::Deref { subpattern } => {
409                match self.ty.kind() {
410                    TyKind::Ref(.., mutbl) => {
411                        write!(f, "&{}", if mutbl == Mutability::Mut { "mut " } else { "" })?
412                    }
413                    _ => never!("{:?} is a bad Deref pattern type", self.ty),
414                }
415                subpattern.hir_fmt(f)
416            }
417            PatKind::LiteralBool { value } => write!(f, "{value}"),
418            PatKind::Or { pats } => f.write_joined(pats.iter(), " | "),
419        }
420    }
421}
422
423struct WriteWith<'db, F>(F, PhantomCovariantLifetime<'db>)
424where
425    F: Fn(&mut HirFormatter<'_, 'db>) -> Result<(), HirDisplayError>;
426
427impl<'db, F> WriteWith<'db, F>
428where
429    F: Fn(&mut HirFormatter<'_, 'db>) -> Result<(), HirDisplayError>,
430{
431    fn new(f: F) -> Self {
432        Self(f, PhantomCovariantLifetime::new())
433    }
434}
435
436impl<'db, F> HirDisplay<'db> for WriteWith<'db, F>
437where
438    F: Fn(&mut HirFormatter<'_, 'db>) -> Result<(), HirDisplayError>,
439{
440    fn hir_fmt(&self, f: &mut HirFormatter<'_, 'db>) -> Result<(), HirDisplayError> {
441        (self.0)(f)
442    }
443}