1mod pat_util;
9
10pub(crate) mod pat_analysis;
11
12use hir_def::{
13 AdtId, EnumVariantId, LocalFieldId, Lookup, VariantId,
14 expr_store::{Body, path::Path},
15 hir::PatId,
16 item_tree::FieldsShape,
17};
18use hir_expand::name::Name;
19use rustc_type_ir::inherent::{IntoKind, SliceLike};
20use span::Edition;
21use stdx::{always, never, variance::PhantomCovariantLifetime};
22
23use crate::{
24 InferenceResult,
25 db::HirDatabase,
26 display::{HirDisplay, HirDisplayError, HirFormatter},
27 infer::BindingMode,
28 next_solver::{GenericArgs, Mutability, Ty, TyKind},
29};
30
31use self::pat_util::EnumerateAndAdjustIterator;
32
/// Problems recorded by [`PatCtxt`] while lowering patterns.
///
/// Lowering never aborts on these: the error is pushed onto `PatCtxt::errors` and a fallback
/// is produced instead (usually `PatKind::Wild`, or an empty subpattern list for
/// `ExtraFields`).
#[derive(Clone, Debug)]
pub(crate) enum PatternError {
    /// The pattern form, or the literal kind, is not supported by the lowering.
    Unimplemented,
    /// The inferred type does not fit the pattern (e.g. a tuple pattern on a non-tuple type).
    UnexpectedType,
    /// A path, tuple-struct or record pattern did not resolve to a variant.
    UnresolvedVariant,
    /// A record pattern names a field that the resolved variant does not have.
    MissingField,
    /// A tuple or tuple-struct pattern has more subpatterns than the type has fields.
    ExtraFields,
}
41
/// A subpattern paired with the field of the matched value it applies to.
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct FieldPat<'db> {
    /// The matched field. For tuple-like patterns this is the positional index of the field.
    pub(crate) field: LocalFieldId,
    /// The pattern that the field's value is matched against.
    pub(crate) pattern: Pat<'db>,
}
47
/// A lowered pattern: the type it matches plus its structural kind.
#[derive(Clone, Debug, PartialEq)]
pub(crate) struct Pat<'db> {
    /// The type of the value this pattern matches.
    pub(crate) ty: Ty<'db>,
    /// What the pattern matches. Boxed because `PatKind` holds `Pat`s by value
    /// (e.g. `PatKind::Deref`), which would otherwise make the type infinitely sized.
    pub(crate) kind: Box<PatKind<'db>>,
}
53
/// The structural kind of a lowered [`Pat`].
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum PatKind<'db> {
    /// `_`. Also the fallback produced when lowering fails or is unsupported.
    Wild,
    /// Printed as `!`. Not produced by `PatCtxt` lowering in this module.
    Never,

    /// A binding such as `x` or `ref x`, optionally with a subpattern (`x @ P`).
    Binding {
        /// The bound name.
        name: Name,
        /// The pattern after `@`, if any.
        subpattern: Option<Pat<'db>>,
    },

    /// An enum variant pattern: `Foo(..)`, `Foo { .. }` or `Foo`.
    Variant {
        /// Generic arguments of the enum type being matched.
        substs: GenericArgs<'db>,
        /// The variant being matched.
        enum_variant: EnumVariantId,
        /// Subpatterns for the variant's fields.
        subpatterns: Vec<FieldPat<'db>>,
    },

    /// A pattern with subpatterns but no enum variant: tuples, structs and unions.
    Leaf {
        /// Subpatterns for the fields/elements.
        subpatterns: Vec<FieldPat<'db>>,
    },

    /// A dereference, printed as `&P` or `&mut P`.
    ///
    /// Inserted by `PatCtxt::lower_pattern` for each entry of
    /// `InferenceResult::pat_adjustments` (presumably implicit derefs from default binding
    /// modes; confirm against inference).
    Deref {
        /// The pattern matched against the referent.
        subpattern: Pat<'db>,
    },

    /// A `true`/`false` literal; the only literal kind lowered so far.
    LiteralBool {
        /// The literal value.
        value: bool,
    },

    /// An or-pattern `p | q | ...`.
    Or {
        /// The alternatives.
        pats: Vec<Pat<'db>>,
    },
}
96
/// Lowers the HIR patterns of one body into [`Pat`] trees, using that body's inference
/// results for types, binding modes, adjustments and variant resolutions.
pub(crate) struct PatCtxt<'a, 'db> {
    db: &'db dyn HirDatabase,
    /// Inference results for `body`.
    infer: &'a InferenceResult<'db>,
    /// The body whose patterns are being lowered.
    body: &'a Body,
    /// Errors recorded during lowering; see [`PatternError`].
    pub(crate) errors: Vec<PatternError>,
}
103
104impl<'a, 'db> PatCtxt<'a, 'db> {
105 pub(crate) fn new(
106 db: &'db dyn HirDatabase,
107 infer: &'a InferenceResult<'db>,
108 body: &'a Body,
109 ) -> Self {
110 Self { db, infer, body, errors: Vec::new() }
111 }
112
113 pub(crate) fn lower_pattern(&mut self, pat: PatId) -> Pat<'db> {
114 let unadjusted_pat = self.lower_pattern_unadjusted(pat);
120 self.infer.pat_adjustments.get(&pat).map(|it| &**it).unwrap_or_default().iter().rev().fold(
121 unadjusted_pat,
122 |subpattern, ref_ty| Pat { ty: *ref_ty, kind: Box::new(PatKind::Deref { subpattern }) },
123 )
124 }
125
126 fn lower_pattern_unadjusted(&mut self, pat: PatId) -> Pat<'db> {
127 let mut ty = self.infer[pat];
128 let variant = self.infer.variant_resolution_for_pat(pat);
129
130 let kind = match self.body[pat] {
131 hir_def::hir::Pat::Wild => PatKind::Wild,
132
133 hir_def::hir::Pat::Lit(expr) => self.lower_lit(expr),
134
135 hir_def::hir::Pat::Path(ref path) => {
136 return self.lower_path(pat, path);
137 }
138
139 hir_def::hir::Pat::Tuple { ref args, ellipsis } => {
140 let arity = match ty.kind() {
141 TyKind::Tuple(tys) => tys.len(),
142 _ => {
143 never!("unexpected type for tuple pattern: {:?}", ty);
144 self.errors.push(PatternError::UnexpectedType);
145 return Pat { ty, kind: PatKind::Wild.into() };
146 }
147 };
148 let subpatterns = self.lower_tuple_subpats(args, arity, ellipsis);
149 PatKind::Leaf { subpatterns }
150 }
151
152 hir_def::hir::Pat::Bind { id, subpat, .. } => {
153 let bm = self.infer.binding_modes[pat];
154 ty = self.infer[id];
155 let name = &self.body[id].name;
156 match (bm, ty.kind()) {
157 (BindingMode::Ref(_), TyKind::Ref(_, rty, _)) => ty = rty,
158 (BindingMode::Ref(_), _) => {
159 never!(
160 "`ref {}` has wrong type {:?}",
161 name.display(self.db, Edition::LATEST),
162 ty
163 );
164 self.errors.push(PatternError::UnexpectedType);
165 return Pat { ty, kind: PatKind::Wild.into() };
166 }
167 _ => (),
168 }
169 PatKind::Binding { name: name.clone(), subpattern: self.lower_opt_pattern(subpat) }
170 }
171
172 hir_def::hir::Pat::TupleStruct { ref args, ellipsis, .. } if variant.is_some() => {
173 let expected_len = variant.unwrap().fields(self.db).fields().len();
174 let subpatterns = self.lower_tuple_subpats(args, expected_len, ellipsis);
175 self.lower_variant_or_leaf(pat, ty, subpatterns)
176 }
177
178 hir_def::hir::Pat::Record { ref args, .. } if variant.is_some() => {
179 let variant_data = variant.unwrap().fields(self.db);
180 let subpatterns = args
181 .iter()
182 .map(|field| {
183 variant_data.field(&field.name).map(|lfield_id| FieldPat {
185 field: lfield_id,
186 pattern: self.lower_pattern(field.pat),
187 })
188 })
189 .collect();
190 match subpatterns {
191 Some(subpatterns) => self.lower_variant_or_leaf(pat, ty, subpatterns),
192 None => {
193 self.errors.push(PatternError::MissingField);
194 PatKind::Wild
195 }
196 }
197 }
198 hir_def::hir::Pat::TupleStruct { .. } | hir_def::hir::Pat::Record { .. } => {
199 self.errors.push(PatternError::UnresolvedVariant);
200 PatKind::Wild
201 }
202
203 hir_def::hir::Pat::Or(ref pats) => PatKind::Or { pats: self.lower_patterns(pats) },
204
205 _ => {
206 self.errors.push(PatternError::Unimplemented);
207 PatKind::Wild
208 }
209 };
210
211 Pat { ty, kind: Box::new(kind) }
212 }
213
214 fn lower_tuple_subpats(
215 &mut self,
216 pats: &[PatId],
217 expected_len: usize,
218 ellipsis: Option<u32>,
219 ) -> Vec<FieldPat<'db>> {
220 if pats.len() > expected_len {
221 self.errors.push(PatternError::ExtraFields);
222 return Vec::new();
223 }
224
225 pats.iter()
226 .enumerate_and_adjust(expected_len, ellipsis.map(|it| it as usize))
227 .map(|(i, &subpattern)| FieldPat {
228 field: LocalFieldId::from_raw((i as u32).into()),
229 pattern: self.lower_pattern(subpattern),
230 })
231 .collect()
232 }
233
234 fn lower_patterns(&mut self, pats: &[PatId]) -> Vec<Pat<'db>> {
235 pats.iter().map(|&p| self.lower_pattern(p)).collect()
236 }
237
238 fn lower_opt_pattern(&mut self, pat: Option<PatId>) -> Option<Pat<'db>> {
239 pat.map(|p| self.lower_pattern(p))
240 }
241
242 fn lower_variant_or_leaf(
243 &mut self,
244 pat: PatId,
245 ty: Ty<'db>,
246 subpatterns: Vec<FieldPat<'db>>,
247 ) -> PatKind<'db> {
248 match self.infer.variant_resolution_for_pat(pat) {
249 Some(variant_id) => {
250 if let VariantId::EnumVariantId(enum_variant) = variant_id {
251 let substs = match ty.kind() {
252 TyKind::Adt(_, substs) => substs,
253 kind => {
254 always!(
255 matches!(kind, TyKind::FnDef(..) | TyKind::Error(_)),
256 "inappropriate type for def: {:?}",
257 ty
258 );
259 self.errors.push(PatternError::UnexpectedType);
260 return PatKind::Wild;
261 }
262 };
263 PatKind::Variant { substs, enum_variant, subpatterns }
264 } else {
265 PatKind::Leaf { subpatterns }
266 }
267 }
268 None => {
269 self.errors.push(PatternError::UnresolvedVariant);
270 PatKind::Wild
271 }
272 }
273 }
274
275 fn lower_path(&mut self, pat: PatId, _path: &Path) -> Pat<'db> {
276 let ty = self.infer[pat];
277
278 let pat_from_kind = |kind| Pat { ty, kind: Box::new(kind) };
279
280 match self.infer.variant_resolution_for_pat(pat) {
281 Some(_) => pat_from_kind(self.lower_variant_or_leaf(pat, ty, Vec::new())),
282 None => {
283 self.errors.push(PatternError::UnresolvedVariant);
284 pat_from_kind(PatKind::Wild)
285 }
286 }
287 }
288
289 fn lower_lit(&mut self, expr: hir_def::hir::ExprId) -> PatKind<'db> {
290 use hir_def::hir::{Expr, Literal::Bool};
291
292 match self.body[expr] {
293 Expr::Literal(Bool(value)) => PatKind::LiteralBool { value },
294 _ => {
295 self.errors.push(PatternError::Unimplemented);
296 PatKind::Wild
297 }
298 }
299 }
300}
301
302impl<'db> HirDisplay<'db> for Pat<'db> {
303 fn hir_fmt(&self, f: &mut HirFormatter<'_, 'db>) -> Result<(), HirDisplayError> {
304 match &*self.kind {
305 PatKind::Wild => write!(f, "_"),
306 PatKind::Never => write!(f, "!"),
307 PatKind::Binding { name, subpattern } => {
308 write!(f, "{}", name.display(f.db, f.edition()))?;
309 if let Some(subpattern) = subpattern {
310 write!(f, " @ ")?;
311 subpattern.hir_fmt(f)?;
312 }
313 Ok(())
314 }
315 PatKind::Variant { subpatterns, .. } | PatKind::Leaf { subpatterns } => {
316 let variant = match *self.kind {
317 PatKind::Variant { enum_variant, .. } => Some(VariantId::from(enum_variant)),
318 _ => self.ty.as_adt().and_then(|(adt, _)| match adt {
319 AdtId::StructId(s) => Some(s.into()),
320 AdtId::UnionId(u) => Some(u.into()),
321 AdtId::EnumId(_) => None,
322 }),
323 };
324
325 if let Some(variant) = variant {
326 match variant {
327 VariantId::EnumVariantId(v) => {
328 let loc = v.lookup(f.db);
329 write!(
330 f,
331 "{}",
332 loc.parent.enum_variants(f.db).variants[loc.index as usize]
333 .1
334 .display(f.db, f.edition())
335 )?;
336 }
337 VariantId::StructId(s) => write!(
338 f,
339 "{}",
340 f.db.struct_signature(s).name.display(f.db, f.edition())
341 )?,
342 VariantId::UnionId(u) => write!(
343 f,
344 "{}",
345 f.db.union_signature(u).name.display(f.db, f.edition())
346 )?,
347 };
348
349 let variant_data = variant.fields(f.db);
350 if variant_data.shape == FieldsShape::Record {
351 write!(f, " {{ ")?;
352
353 let mut printed = 0;
354 let subpats = subpatterns
355 .iter()
356 .filter(|p| !matches!(*p.pattern.kind, PatKind::Wild))
357 .map(|p| {
358 printed += 1;
359 WriteWith::new(|f| {
360 write!(
361 f,
362 "{}: ",
363 variant_data.fields()[p.field]
364 .name
365 .display(f.db, f.edition())
366 )?;
367 p.pattern.hir_fmt(f)
368 })
369 });
370 f.write_joined(subpats, ", ")?;
371
372 if printed < variant_data.fields().len() {
373 write!(f, "{}..", if printed > 0 { ", " } else { "" })?;
374 }
375
376 return write!(f, " }}");
377 }
378 }
379
380 let num_fields =
381 variant.map_or(subpatterns.len(), |v| v.fields(f.db).fields().len());
382 if num_fields != 0 || variant.is_none() {
383 write!(f, "(")?;
384 let subpats = (0..num_fields).map(|i| {
385 WriteWith::new(move |f| {
386 let fid = LocalFieldId::from_raw((i as u32).into());
387 if let Some(p) = subpatterns.get(i)
388 && p.field == fid
389 {
390 return p.pattern.hir_fmt(f);
391 }
392 if let Some(p) = subpatterns.iter().find(|p| p.field == fid) {
393 p.pattern.hir_fmt(f)
394 } else {
395 write!(f, "_")
396 }
397 })
398 });
399 f.write_joined(subpats, ", ")?;
400 if let (TyKind::Tuple(..), 1) = (self.ty.kind(), num_fields) {
401 write!(f, ",")?;
402 }
403 write!(f, ")")?;
404 }
405
406 Ok(())
407 }
408 PatKind::Deref { subpattern } => {
409 match self.ty.kind() {
410 TyKind::Ref(.., mutbl) => {
411 write!(f, "&{}", if mutbl == Mutability::Mut { "mut " } else { "" })?
412 }
413 _ => never!("{:?} is a bad Deref pattern type", self.ty),
414 }
415 subpattern.hir_fmt(f)
416 }
417 PatKind::LiteralBool { value } => write!(f, "{value}"),
418 PatKind::Or { pats } => f.write_joined(pats.iter(), " | "),
419 }
420 }
421}
422
/// Adapter that lets a formatting closure be used wherever a `HirDisplay` value is expected,
/// e.g. as the items passed to `HirFormatter::write_joined`.
///
/// The second field is a marker that carries the `'db` lifetime (covariantly, going by the
/// type's name) without storing anything.
struct WriteWith<'db, F>(F, PhantomCovariantLifetime<'db>)
where
    F: Fn(&mut HirFormatter<'_, 'db>) -> Result<(), HirDisplayError>;
426
427impl<'db, F> WriteWith<'db, F>
428where
429 F: Fn(&mut HirFormatter<'_, 'db>) -> Result<(), HirDisplayError>,
430{
431 fn new(f: F) -> Self {
432 Self(f, PhantomCovariantLifetime::new())
433 }
434}
435
436impl<'db, F> HirDisplay<'db> for WriteWith<'db, F>
437where
438 F: Fn(&mut HirFormatter<'_, 'db>) -> Result<(), HirDisplayError>,
439{
440 fn hir_fmt(&self, f: &mut HirFormatter<'_, 'db>) -> Result<(), HirDisplayError> {
441 (self.0)(f)
442 }
443}