hir_ty/diagnostics/
expr.rs

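//! Various diagnostics for expressions that are collected together in one pass
//! through the body using inference results: missing record fields, non-exhaustive
//! `match`es and `let`s, and lints such as `filter_map(..).next()`, trailing `return`
//! and unnecessary `else`.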
4
5use std::fmt;
6
7use base_db::Crate;
8use either::Either;
9use hir_def::{
10    AdtId, AssocItemId, DefWithBodyId, HasModule, ItemContainerId, Lookup,
11    lang_item::LangItem,
12    resolver::{HasResolver, ValueNs},
13};
14use intern::sym;
15use itertools::Itertools;
16use rustc_hash::FxHashSet;
17use rustc_pattern_analysis::constructor::Constructor;
18use syntax::{
19    AstNode,
20    ast::{self, UnaryOp},
21};
22use tracing::debug;
23use triomphe::Arc;
24use typed_arena::Arena;
25
26use crate::{
27    Adjust, InferenceResult, Interner, TraitEnvironment, Ty, TyExt, TyKind,
28    db::HirDatabase,
29    diagnostics::match_check::{
30        self,
31        pat_analysis::{self, DeconstructedPat, MatchCheckCtx, WitnessPat},
32    },
33    display::{DisplayTarget, HirDisplay},
34};
35
36pub(crate) use hir_def::{
37    LocalFieldId, VariantId,
38    expr_store::Body,
39    hir::{Expr, ExprId, MatchArm, Pat, PatId, Statement},
40};
41
42pub enum BodyValidationDiagnostic {
43    RecordMissingFields {
44        record: Either<ExprId, PatId>,
45        variant: VariantId,
46        missed_fields: Vec<LocalFieldId>,
47    },
48    ReplaceFilterMapNextWithFindMap {
49        method_call_expr: ExprId,
50    },
51    MissingMatchArms {
52        match_expr: ExprId,
53        uncovered_patterns: String,
54    },
55    NonExhaustiveLet {
56        pat: PatId,
57        uncovered_patterns: String,
58    },
59    RemoveTrailingReturn {
60        return_expr: ExprId,
61    },
62    RemoveUnnecessaryElse {
63        if_expr: ExprId,
64    },
65}
66
67impl BodyValidationDiagnostic {
68    pub fn collect(
69        db: &dyn HirDatabase,
70        owner: DefWithBodyId,
71        validate_lints: bool,
72    ) -> Vec<BodyValidationDiagnostic> {
73        let _p = tracing::info_span!("BodyValidationDiagnostic::collect").entered();
74        let infer = db.infer(owner);
75        let body = db.body(owner);
76        let env = db.trait_environment_for_body(owner);
77        let mut validator =
78            ExprValidator { owner, body, infer, diagnostics: Vec::new(), validate_lints, env };
79        validator.validate_body(db);
80        validator.diagnostics
81    }
82}
83
84struct ExprValidator {
85    owner: DefWithBodyId,
86    body: Arc<Body>,
87    infer: Arc<InferenceResult>,
88    env: Arc<TraitEnvironment>,
89    diagnostics: Vec<BodyValidationDiagnostic>,
90    validate_lints: bool,
91}
92
93impl ExprValidator {
94    fn validate_body(&mut self, db: &dyn HirDatabase) {
95        let mut filter_map_next_checker = None;
96        // we'll pass &mut self while iterating over body.exprs, so they need to be disjoint
97        let body = Arc::clone(&self.body);
98
99        if matches!(self.owner, DefWithBodyId::FunctionId(_)) {
100            self.check_for_trailing_return(body.body_expr, &body);
101        }
102
103        for (id, expr) in body.exprs() {
104            if let Some((variant, missed_fields, true)) =
105                record_literal_missing_fields(db, &self.infer, id, expr)
106            {
107                self.diagnostics.push(BodyValidationDiagnostic::RecordMissingFields {
108                    record: Either::Left(id),
109                    variant,
110                    missed_fields,
111                });
112            }
113
114            match expr {
115                Expr::Match { expr, arms } => {
116                    self.validate_match(id, *expr, arms, db);
117                }
118                Expr::Call { .. } | Expr::MethodCall { .. } => {
119                    self.validate_call(db, id, expr, &mut filter_map_next_checker);
120                }
121                Expr::Closure { body: body_expr, .. } => {
122                    self.check_for_trailing_return(*body_expr, &body);
123                }
124                Expr::If { .. } => {
125                    self.check_for_unnecessary_else(id, expr, db);
126                }
127                Expr::Block { .. } | Expr::Async { .. } | Expr::Unsafe { .. } => {
128                    self.validate_block(db, expr);
129                }
130                _ => {}
131            }
132        }
133
134        for (id, pat) in body.pats() {
135            if let Some((variant, missed_fields, true)) =
136                record_pattern_missing_fields(db, &self.infer, id, pat)
137            {
138                self.diagnostics.push(BodyValidationDiagnostic::RecordMissingFields {
139                    record: Either::Right(id),
140                    variant,
141                    missed_fields,
142                });
143            }
144        }
145    }
146
147    fn validate_call(
148        &mut self,
149        db: &dyn HirDatabase,
150        call_id: ExprId,
151        expr: &Expr,
152        filter_map_next_checker: &mut Option<FilterMapNextChecker>,
153    ) {
154        if !self.validate_lints {
155            return;
156        }
157        // Check that the number of arguments matches the number of parameters.
158
159        if self.infer.expr_type_mismatches().next().is_some() {
160            // FIXME: Due to shortcomings in the current type system implementation, only emit
161            // this diagnostic if there are no type mismatches in the containing function.
162        } else if let Expr::MethodCall { receiver, .. } = expr {
163            let (callee, _) = match self.infer.method_resolution(call_id) {
164                Some(it) => it,
165                None => return,
166            };
167
168            let checker = filter_map_next_checker
169                .get_or_insert_with(|| FilterMapNextChecker::new(&self.owner.resolver(db), db));
170
171            if checker.check(call_id, receiver, &callee).is_some() {
172                self.diagnostics.push(BodyValidationDiagnostic::ReplaceFilterMapNextWithFindMap {
173                    method_call_expr: call_id,
174                });
175            }
176
177            if let Some(receiver_ty) = self.infer.type_of_expr_with_adjust(*receiver) {
178                checker.prev_receiver_ty = Some(receiver_ty.clone());
179            }
180        }
181    }
182
183    fn validate_match(
184        &mut self,
185        match_expr: ExprId,
186        scrutinee_expr: ExprId,
187        arms: &[MatchArm],
188        db: &dyn HirDatabase,
189    ) {
190        let Some(scrut_ty) = self.infer.type_of_expr_with_adjust(scrutinee_expr) else {
191            return;
192        };
193        if scrut_ty.contains_unknown() {
194            return;
195        }
196
197        let cx = MatchCheckCtx::new(self.owner.module(db), self.owner, db, self.env.clone());
198
199        let pattern_arena = Arena::new();
200        let mut m_arms = Vec::with_capacity(arms.len());
201        let mut has_lowering_errors = false;
202        // Note: Skipping the entire diagnostic rather than just not including a faulty match arm is
203        // preferred to avoid the chance of false positives.
204        for arm in arms {
205            let Some(pat_ty) = self.infer.type_of_pat_with_adjust(arm.pat) else {
206                return;
207            };
208            if pat_ty.contains_unknown() {
209                return;
210            }
211
212            // We only include patterns whose type matches the type
213            // of the scrutinee expression. If we had an InvalidMatchArmPattern
214            // diagnostic or similar we could raise that in an else
215            // block here.
216            //
217            // When comparing the types, we also have to consider that rustc
218            // will automatically de-reference the scrutinee expression type if
219            // necessary.
220            //
221            // FIXME we should use the type checker for this.
222            if (pat_ty == scrut_ty
223                || scrut_ty
224                    .as_reference()
225                    .map(|(match_expr_ty, ..)| match_expr_ty == pat_ty)
226                    .unwrap_or(false))
227                && types_of_subpatterns_do_match(arm.pat, &self.body, &self.infer)
228            {
229                // If we had a NotUsefulMatchArm diagnostic, we could
230                // check the usefulness of each pattern as we added it
231                // to the matrix here.
232                let pat = self.lower_pattern(&cx, arm.pat, db, &mut has_lowering_errors);
233                let m_arm = pat_analysis::MatchArm {
234                    pat: pattern_arena.alloc(pat),
235                    has_guard: arm.guard.is_some(),
236                    arm_data: (),
237                };
238                m_arms.push(m_arm);
239                if !has_lowering_errors {
240                    continue;
241                }
242            }
243            // If the pattern type doesn't fit the match expression, we skip this diagnostic.
244            cov_mark::hit!(validate_match_bailed_out);
245            return;
246        }
247
248        let known_valid_scrutinee = Some(self.is_known_valid_scrutinee(scrutinee_expr, db));
249        let report = match cx.compute_match_usefulness(
250            m_arms.as_slice(),
251            scrut_ty.clone(),
252            known_valid_scrutinee,
253        ) {
254            Ok(report) => report,
255            Err(()) => return,
256        };
257
258        // FIXME Report unreachable arms
259        // https://github.com/rust-lang/rust/blob/f31622a50/compiler/rustc_mir_build/src/thir/pattern/check_match.rs#L200
260
261        let witnesses = report.non_exhaustiveness_witnesses;
262        if !witnesses.is_empty() {
263            self.diagnostics.push(BodyValidationDiagnostic::MissingMatchArms {
264                match_expr,
265                uncovered_patterns: missing_match_arms(
266                    &cx,
267                    scrut_ty,
268                    witnesses,
269                    m_arms.is_empty(),
270                    self.owner.krate(db),
271                ),
272            });
273        }
274    }
275
276    // [rustc's `is_known_valid_scrutinee`](https://github.com/rust-lang/rust/blob/c9bd03cb724e13cca96ad320733046cbdb16fbbe/compiler/rustc_mir_build/src/thir/pattern/check_match.rs#L288)
277    //
278    // While the above function in rustc uses thir exprs, r-a doesn't have them.
279    // So, the logic here is getting same result as "hir lowering + match with lowered thir"
280    // with "hir only"
281    fn is_known_valid_scrutinee(&self, scrutinee_expr: ExprId, db: &dyn HirDatabase) -> bool {
282        if self
283            .infer
284            .expr_adjustments
285            .get(&scrutinee_expr)
286            .is_some_and(|adjusts| adjusts.iter().any(|a| matches!(a.kind, Adjust::Deref(..))))
287        {
288            return false;
289        }
290
291        match &self.body[scrutinee_expr] {
292            Expr::UnaryOp { op: UnaryOp::Deref, .. } => false,
293            Expr::Path(path) => {
294                let value_or_partial = self.owner.resolver(db).resolve_path_in_value_ns_fully(
295                    db,
296                    path,
297                    self.body.expr_path_hygiene(scrutinee_expr),
298                );
299                value_or_partial.is_none_or(|v| !matches!(v, ValueNs::StaticId(_)))
300            }
301            Expr::Field { expr, .. } => match self.infer.type_of_expr[*expr].kind(Interner) {
302                TyKind::Adt(adt, ..) if matches!(adt.0, AdtId::UnionId(_)) => false,
303                _ => self.is_known_valid_scrutinee(*expr, db),
304            },
305            Expr::Index { base, .. } => self.is_known_valid_scrutinee(*base, db),
306            Expr::Cast { expr, .. } => self.is_known_valid_scrutinee(*expr, db),
307            Expr::Missing => false,
308            _ => true,
309        }
310    }
311
312    fn validate_block(&mut self, db: &dyn HirDatabase, expr: &Expr) {
313        let (Expr::Block { statements, .. }
314        | Expr::Async { statements, .. }
315        | Expr::Unsafe { statements, .. }) = expr
316        else {
317            return;
318        };
319        let pattern_arena = Arena::new();
320        let cx = MatchCheckCtx::new(self.owner.module(db), self.owner, db, self.env.clone());
321        for stmt in &**statements {
322            let &Statement::Let { pat, initializer, else_branch: None, .. } = stmt else {
323                continue;
324            };
325            if self.infer.type_mismatch_for_pat(pat).is_some() {
326                continue;
327            }
328            let Some(initializer) = initializer else { continue };
329            let Some(ty) = self.infer.type_of_expr_with_adjust(initializer) else { continue };
330            if ty.contains_unknown() {
331                continue;
332            }
333
334            let mut have_errors = false;
335            let deconstructed_pat = self.lower_pattern(&cx, pat, db, &mut have_errors);
336
337            // optimization, wildcard trivially hold
338            if have_errors || matches!(deconstructed_pat.ctor(), Constructor::Wildcard) {
339                continue;
340            }
341
342            let match_arm = rustc_pattern_analysis::MatchArm {
343                pat: pattern_arena.alloc(deconstructed_pat),
344                has_guard: false,
345                arm_data: (),
346            };
347            let report = match cx.compute_match_usefulness(&[match_arm], ty.clone(), None) {
348                Ok(v) => v,
349                Err(e) => {
350                    debug!(?e, "match usefulness error");
351                    continue;
352                }
353            };
354            let witnesses = report.non_exhaustiveness_witnesses;
355            if !witnesses.is_empty() {
356                self.diagnostics.push(BodyValidationDiagnostic::NonExhaustiveLet {
357                    pat,
358                    uncovered_patterns: missing_match_arms(
359                        &cx,
360                        ty,
361                        witnesses,
362                        false,
363                        self.owner.krate(db),
364                    ),
365                });
366            }
367        }
368    }
369
370    fn lower_pattern<'p>(
371        &self,
372        cx: &MatchCheckCtx<'p>,
373        pat: PatId,
374        db: &dyn HirDatabase,
375        have_errors: &mut bool,
376    ) -> DeconstructedPat<'p> {
377        let mut patcx = match_check::PatCtxt::new(db, &self.infer, &self.body);
378        let pattern = patcx.lower_pattern(pat);
379        let pattern = cx.lower_pat(&pattern);
380        if !patcx.errors.is_empty() {
381            *have_errors = true;
382        }
383        pattern
384    }
385
386    fn check_for_trailing_return(&mut self, body_expr: ExprId, body: &Body) {
387        if !self.validate_lints {
388            return;
389        }
390        match &body[body_expr] {
391            Expr::Block { statements, tail, .. } => {
392                let last_stmt = tail.or_else(|| match statements.last()? {
393                    Statement::Expr { expr, .. } => Some(*expr),
394                    _ => None,
395                });
396                if let Some(last_stmt) = last_stmt {
397                    self.check_for_trailing_return(last_stmt, body);
398                }
399            }
400            Expr::If { then_branch, else_branch, .. } => {
401                self.check_for_trailing_return(*then_branch, body);
402                if let Some(else_branch) = else_branch {
403                    self.check_for_trailing_return(*else_branch, body);
404                }
405            }
406            Expr::Match { arms, .. } => {
407                for arm in arms.iter() {
408                    let MatchArm { expr, .. } = arm;
409                    self.check_for_trailing_return(*expr, body);
410                }
411            }
412            Expr::Return { .. } => {
413                self.diagnostics.push(BodyValidationDiagnostic::RemoveTrailingReturn {
414                    return_expr: body_expr,
415                });
416            }
417            _ => (),
418        }
419    }
420
421    fn check_for_unnecessary_else(&mut self, id: ExprId, expr: &Expr, db: &dyn HirDatabase) {
422        if !self.validate_lints {
423            return;
424        }
425        if let Expr::If { condition: _, then_branch, else_branch } = expr {
426            if else_branch.is_none() {
427                return;
428            }
429            if let Expr::Block { statements, tail, .. } = &self.body[*then_branch] {
430                let last_then_expr = tail.or_else(|| match statements.last()? {
431                    Statement::Expr { expr, .. } => Some(*expr),
432                    _ => None,
433                });
434                if let Some(last_then_expr) = last_then_expr
435                    && let Some(last_then_expr_ty) =
436                        self.infer.type_of_expr_with_adjust(last_then_expr)
437                    && last_then_expr_ty.is_never()
438                {
439                    // Only look at sources if the then branch diverges and we have an else branch.
440                    let source_map = db.body_with_source_map(self.owner).1;
441                    let Ok(source_ptr) = source_map.expr_syntax(id) else {
442                        return;
443                    };
444                    let root = source_ptr.file_syntax(db);
445                    let either::Left(ast::Expr::IfExpr(if_expr)) = source_ptr.value.to_node(&root)
446                    else {
447                        return;
448                    };
449                    let mut top_if_expr = if_expr;
450                    loop {
451                        let parent = top_if_expr.syntax().parent();
452                        let has_parent_expr_stmt_or_stmt_list =
453                            parent.as_ref().is_some_and(|node| {
454                                ast::ExprStmt::can_cast(node.kind())
455                                    | ast::StmtList::can_cast(node.kind())
456                            });
457                        if has_parent_expr_stmt_or_stmt_list {
458                            // Only emit diagnostic if parent or direct ancestor is either
459                            // an expr stmt or a stmt list.
460                            break;
461                        }
462                        let Some(parent_if_expr) = parent.and_then(ast::IfExpr::cast) else {
463                            // Bail if parent is neither an if expr, an expr stmt nor a stmt list.
464                            return;
465                        };
466                        // Check parent if expr.
467                        top_if_expr = parent_if_expr;
468                    }
469
470                    self.diagnostics
471                        .push(BodyValidationDiagnostic::RemoveUnnecessaryElse { if_expr: id })
472                }
473            }
474        }
475    }
476}
477
478struct FilterMapNextChecker {
479    filter_map_function_id: Option<hir_def::FunctionId>,
480    next_function_id: Option<hir_def::FunctionId>,
481    prev_filter_map_expr_id: Option<ExprId>,
482    prev_receiver_ty: Option<chalk_ir::Ty<Interner>>,
483}
484
485impl FilterMapNextChecker {
486    fn new(resolver: &hir_def::resolver::Resolver<'_>, db: &dyn HirDatabase) -> Self {
487        // Find and store the FunctionIds for Iterator::filter_map and Iterator::next
488        let (next_function_id, filter_map_function_id) = match LangItem::IteratorNext
489            .resolve_function(db, resolver.krate())
490        {
491            Some(next_function_id) => (
492                Some(next_function_id),
493                match next_function_id.lookup(db).container {
494                    ItemContainerId::TraitId(iterator_trait_id) => {
495                        let iterator_trait_items = &iterator_trait_id.trait_items(db).items;
496                        iterator_trait_items.iter().find_map(|(name, it)| match it {
497                            &AssocItemId::FunctionId(id) if *name == sym::filter_map => Some(id),
498                            _ => None,
499                        })
500                    }
501                    _ => None,
502                },
503            ),
504            None => (None, None),
505        };
506        Self {
507            filter_map_function_id,
508            next_function_id,
509            prev_filter_map_expr_id: None,
510            prev_receiver_ty: None,
511        }
512    }
513
514    // check for instances of .filter_map(..).next()
515    fn check(
516        &mut self,
517        current_expr_id: ExprId,
518        receiver_expr_id: &ExprId,
519        function_id: &hir_def::FunctionId,
520    ) -> Option<()> {
521        if *function_id == self.filter_map_function_id? {
522            self.prev_filter_map_expr_id = Some(current_expr_id);
523            return None;
524        }
525
526        if *function_id == self.next_function_id?
527            && let Some(prev_filter_map_expr_id) = self.prev_filter_map_expr_id
528        {
529            let is_dyn_trait = self
530                .prev_receiver_ty
531                .as_ref()
532                .is_some_and(|it| it.strip_references().dyn_trait().is_some());
533            if *receiver_expr_id == prev_filter_map_expr_id && !is_dyn_trait {
534                return Some(());
535            }
536        }
537
538        self.prev_filter_map_expr_id = None;
539        None
540    }
541}
542
543pub fn record_literal_missing_fields(
544    db: &dyn HirDatabase,
545    infer: &InferenceResult,
546    id: ExprId,
547    expr: &Expr,
548) -> Option<(VariantId, Vec<LocalFieldId>, /*exhaustive*/ bool)> {
549    let (fields, exhaustive) = match expr {
550        Expr::RecordLit { fields, spread, .. } => (fields, spread.is_none()),
551        _ => return None,
552    };
553
554    let variant_def = infer.variant_resolution_for_expr(id)?;
555    if let VariantId::UnionId(_) = variant_def {
556        return None;
557    }
558
559    let variant_data = variant_def.fields(db);
560
561    let specified_fields: FxHashSet<_> = fields.iter().map(|f| &f.name).collect();
562    let missed_fields: Vec<LocalFieldId> = variant_data
563        .fields()
564        .iter()
565        .filter_map(|(f, d)| if specified_fields.contains(&d.name) { None } else { Some(f) })
566        .collect();
567    if missed_fields.is_empty() {
568        return None;
569    }
570    Some((variant_def, missed_fields, exhaustive))
571}
572
573pub fn record_pattern_missing_fields(
574    db: &dyn HirDatabase,
575    infer: &InferenceResult,
576    id: PatId,
577    pat: &Pat,
578) -> Option<(VariantId, Vec<LocalFieldId>, /*exhaustive*/ bool)> {
579    let (fields, exhaustive) = match pat {
580        Pat::Record { path: _, args, ellipsis } => (args, !ellipsis),
581        _ => return None,
582    };
583
584    let variant_def = infer.variant_resolution_for_pat(id)?;
585    if let VariantId::UnionId(_) = variant_def {
586        return None;
587    }
588
589    let variant_data = variant_def.fields(db);
590
591    let specified_fields: FxHashSet<_> = fields.iter().map(|f| &f.name).collect();
592    let missed_fields: Vec<LocalFieldId> = variant_data
593        .fields()
594        .iter()
595        .filter_map(|(f, d)| if specified_fields.contains(&d.name) { None } else { Some(f) })
596        .collect();
597    if missed_fields.is_empty() {
598        return None;
599    }
600    Some((variant_def, missed_fields, exhaustive))
601}
602
603fn types_of_subpatterns_do_match(pat: PatId, body: &Body, infer: &InferenceResult) -> bool {
604    fn walk(pat: PatId, body: &Body, infer: &InferenceResult, has_type_mismatches: &mut bool) {
605        match infer.type_mismatch_for_pat(pat) {
606            Some(_) => *has_type_mismatches = true,
607            None if *has_type_mismatches => (),
608            None => {
609                let pat = &body[pat];
610                if let Pat::ConstBlock(expr) | Pat::Lit(expr) = *pat {
611                    *has_type_mismatches |= infer.type_mismatch_for_expr(expr).is_some();
612                    if *has_type_mismatches {
613                        return;
614                    }
615                }
616                pat.walk_child_pats(|subpat| walk(subpat, body, infer, has_type_mismatches))
617            }
618        }
619    }
620
621    let mut has_type_mismatches = false;
622    walk(pat, body, infer, &mut has_type_mismatches);
623    !has_type_mismatches
624}
625
626fn missing_match_arms<'p>(
627    cx: &MatchCheckCtx<'p>,
628    scrut_ty: &Ty,
629    witnesses: Vec<WitnessPat<'p>>,
630    arms_is_empty: bool,
631    krate: Crate,
632) -> String {
633    struct DisplayWitness<'a, 'p>(&'a WitnessPat<'p>, &'a MatchCheckCtx<'p>, DisplayTarget);
634    impl fmt::Display for DisplayWitness<'_, '_> {
635        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
636            let DisplayWitness(witness, cx, display_target) = *self;
637            let pat = cx.hoist_witness_pat(witness);
638            write!(f, "{}", pat.display(cx.db, display_target))
639        }
640    }
641
642    let non_empty_enum = match scrut_ty.as_adt() {
643        Some((AdtId::EnumId(e), _)) => !e.enum_variants(cx.db).variants.is_empty(),
644        _ => false,
645    };
646    let display_target = DisplayTarget::from_crate(cx.db, krate);
647    if arms_is_empty && !non_empty_enum {
648        format!("type `{}` is non-empty", scrut_ty.display(cx.db, display_target))
649    } else {
650        let pat_display = |witness| DisplayWitness(witness, cx, display_target);
651        const LIMIT: usize = 3;
652        match &*witnesses {
653            [witness] => format!("`{}` not covered", pat_display(witness)),
654            [head @ .., tail] if head.len() < LIMIT => {
655                let head = head.iter().map(pat_display);
656                format!("`{}` and `{}` not covered", head.format("`, `"), pat_display(tail))
657            }
658            _ => {
659                let (head, tail) = witnesses.split_at(LIMIT);
660                let head = head.iter().map(pat_display);
661                format!("`{}` and {} more not covered", head.format("`, `"), tail.len())
662            }
663        }
664    }
665}