hir_ty/infer/
cast.rs

1//! Type cast logic. Basically coercion + additional casts.
2
3use hir_def::{AdtId, hir::ExprId, signatures::TraitFlags};
4use rustc_ast_ir::Mutability;
5use rustc_type_ir::{
6    Flags, InferTy, TypeFlags, UintTy,
7    inherent::{AdtDef, BoundExistentialPredicates as _, IntoKind, SliceLike, Ty as _},
8};
9use stdx::never;
10
11use crate::{
12    InferenceDiagnostic,
13    db::HirDatabase,
14    infer::{AllowTwoPhase, InferenceContext, expr::ExprIsRead},
15    next_solver::{BoundExistentialPredicates, DbInterner, ParamTy, Ty, TyKind},
16};
17
18#[derive(Debug)]
19pub(crate) enum Int {
20    I,
21    U(UintTy),
22    Bool,
23    Char,
24    CEnum,
25    InferenceVar,
26}
27
28#[derive(Debug)]
29pub(crate) enum CastTy<'db> {
30    Int(Int),
31    Float,
32    FnPtr,
33    Ptr(Ty<'db>, Mutability),
34    // `DynStar` is Not supported yet in r-a
35}
36
37impl<'db> CastTy<'db> {
38    pub(crate) fn from_ty(db: &dyn HirDatabase, t: Ty<'db>) -> Option<Self> {
39        match t.kind() {
40            TyKind::Bool => Some(Self::Int(Int::Bool)),
41            TyKind::Char => Some(Self::Int(Int::Char)),
42            TyKind::Int(_) => Some(Self::Int(Int::I)),
43            TyKind::Uint(it) => Some(Self::Int(Int::U(it))),
44            TyKind::Infer(InferTy::IntVar(_)) => Some(Self::Int(Int::InferenceVar)),
45            TyKind::Infer(InferTy::FloatVar(_)) => Some(Self::Float),
46            TyKind::Float(_) => Some(Self::Float),
47            TyKind::Adt(..) => {
48                let (AdtId::EnumId(id), _) = t.as_adt()? else {
49                    return None;
50                };
51                let enum_data = id.enum_variants(db);
52                if enum_data.is_payload_free(db) { Some(Self::Int(Int::CEnum)) } else { None }
53            }
54            TyKind::RawPtr(ty, m) => Some(Self::Ptr(ty, m)),
55            TyKind::FnPtr(..) => Some(Self::FnPtr),
56            _ => None,
57        }
58    }
59}
60
/// Why a cast was rejected by the cast checker.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum CastError {
    /// The check could not be completed, e.g. pointer metadata could not be computed.
    Unknown,
    /// The cast targets `bool`.
    CastToBool,
    /// The cast targets `char` from something other than `u8`.
    CastToChar,
    /// A pointer-to-pointer cast between pointees with different metadata kinds.
    DifferingKinds,
    /// A cast from a thin pointer to a pointer that needs metadata.
    SizedUnsizedCast,
    /// A cast that is never allowed between the two categories involved.
    IllegalCast,
    /// An integer cast to a pointer that needs metadata.
    IntToFatCast,
    /// A reference to a number cast to a number; it must be dereferenced first.
    NeedDeref,
    /// A reference cast to a number; it must go through a raw pointer first.
    NeedViaPtr,
    /// A pointer carrying metadata cast to an integer; it must go through a thin pointer.
    NeedViaThinPtr,
    /// A `bool`, `char` or C-like enum cast to a float; it must go through an integer.
    NeedViaInt,
    /// A non-primitive operand, or a cast targeting a C-like enum or function pointer.
    NonScalar,
    // Errors involving unknown types are intentionally not reported yet, so these stay
    // disabled:
    // UnknownCastPtrKind,
    // UnknownExprPtrKind,
}
79
80impl CastError {
81    fn into_diagnostic<'db>(
82        self,
83        expr: ExprId,
84        expr_ty: Ty<'db>,
85        cast_ty: Ty<'db>,
86    ) -> InferenceDiagnostic<'db> {
87        InferenceDiagnostic::InvalidCast { expr, error: self, expr_ty, cast_ty }
88    }
89}
90
91#[derive(Clone, Debug)]
92pub(super) struct CastCheck<'db> {
93    expr: ExprId,
94    source_expr: ExprId,
95    expr_ty: Ty<'db>,
96    cast_ty: Ty<'db>,
97}
98
99impl<'db> CastCheck<'db> {
100    pub(super) fn new(
101        expr: ExprId,
102        source_expr: ExprId,
103        expr_ty: Ty<'db>,
104        cast_ty: Ty<'db>,
105    ) -> Self {
106        Self { expr, source_expr, expr_ty, cast_ty }
107    }
108
109    pub(super) fn check(
110        &mut self,
111        ctx: &mut InferenceContext<'_, 'db>,
112    ) -> Result<(), InferenceDiagnostic<'db>> {
113        self.expr_ty = ctx.table.try_structurally_resolve_type(self.expr_ty);
114        self.cast_ty = ctx.table.try_structurally_resolve_type(self.cast_ty);
115
116        // This should always come first so that we apply the coercion, which impacts infer vars.
117        if ctx
118            .coerce(
119                self.source_expr.into(),
120                self.expr_ty,
121                self.cast_ty,
122                AllowTwoPhase::No,
123                ExprIsRead::Yes,
124            )
125            .is_ok()
126        {
127            ctx.result.coercion_casts.insert(self.source_expr);
128            return Ok(());
129        }
130
131        if self.expr_ty.references_non_lt_error() || self.cast_ty.references_non_lt_error() {
132            return Ok(());
133        }
134
135        if !self.cast_ty.flags().contains(TypeFlags::HAS_TY_INFER)
136            && !ctx.table.is_sized(self.cast_ty)
137        {
138            return Err(InferenceDiagnostic::CastToUnsized {
139                expr: self.expr,
140                cast_ty: self.cast_ty,
141            });
142        }
143
144        // Chalk doesn't support trait upcasting and fails to solve some obvious goals
145        // when the trait environment contains some recursive traits (See issue #18047)
146        // We skip cast checks for such cases for now, until the next-gen solver.
147        if contains_dyn_trait(self.cast_ty) {
148            return Ok(());
149        }
150
151        self.do_check(ctx).map_err(|e| e.into_diagnostic(self.expr, self.expr_ty, self.cast_ty))
152    }
153
154    fn do_check(&self, ctx: &mut InferenceContext<'_, 'db>) -> Result<(), CastError> {
155        let (t_from, t_cast) =
156            match (CastTy::from_ty(ctx.db, self.expr_ty), CastTy::from_ty(ctx.db, self.cast_ty)) {
157                (Some(t_from), Some(t_cast)) => (t_from, t_cast),
158                (None, Some(t_cast)) => match self.expr_ty.kind() {
159                    TyKind::FnDef(..) => {
160                        let sig =
161                            self.expr_ty.callable_sig(ctx.interner()).expect("FnDef had no sig");
162                        let sig = ctx.table.normalize_associated_types_in(sig);
163                        let fn_ptr = Ty::new_fn_ptr(ctx.interner(), sig);
164                        if ctx
165                            .coerce(
166                                self.source_expr.into(),
167                                self.expr_ty,
168                                fn_ptr,
169                                AllowTwoPhase::No,
170                                ExprIsRead::Yes,
171                            )
172                            .is_ok()
173                        {
174                        } else {
175                            return Err(CastError::IllegalCast);
176                        }
177
178                        (CastTy::FnPtr, t_cast)
179                    }
180                    TyKind::Ref(_, inner_ty, mutbl) => {
181                        return match t_cast {
182                            CastTy::Int(_) | CastTy::Float => match inner_ty.kind() {
183                                TyKind::Int(_)
184                                | TyKind::Uint(_)
185                                | TyKind::Float(_)
186                                | TyKind::Infer(InferTy::IntVar(_) | InferTy::FloatVar(_)) => {
187                                    Err(CastError::NeedDeref)
188                                }
189
190                                _ => Err(CastError::NeedViaPtr),
191                            },
192                            // array-ptr-cast
193                            CastTy::Ptr(t, m) => {
194                                let t = ctx.table.try_structurally_resolve_type(t);
195                                if !ctx.table.is_sized(t) {
196                                    return Err(CastError::IllegalCast);
197                                }
198                                self.check_ref_cast(ctx, inner_ty, mutbl, t, m)
199                            }
200                            _ => Err(CastError::NonScalar),
201                        };
202                    }
203                    _ => return Err(CastError::NonScalar),
204                },
205                _ => return Err(CastError::NonScalar),
206            };
207
208        // rustc checks whether the `expr_ty` is foreign adt with `non_exhaustive` sym
209
210        match (t_from, t_cast) {
211            (_, CastTy::Int(Int::CEnum) | CastTy::FnPtr) => Err(CastError::NonScalar),
212            (_, CastTy::Int(Int::Bool)) => Err(CastError::CastToBool),
213            (CastTy::Int(Int::U(UintTy::U8)), CastTy::Int(Int::Char)) => Ok(()),
214            (_, CastTy::Int(Int::Char)) => Err(CastError::CastToChar),
215            (CastTy::Int(Int::Bool | Int::CEnum | Int::Char), CastTy::Float) => {
216                Err(CastError::NeedViaInt)
217            }
218            (CastTy::Int(Int::Bool | Int::CEnum | Int::Char) | CastTy::Float, CastTy::Ptr(..))
219            | (CastTy::Ptr(..) | CastTy::FnPtr, CastTy::Float) => Err(CastError::IllegalCast),
220            (CastTy::Ptr(src, _), CastTy::Ptr(dst, _)) => self.check_ptr_ptr_cast(ctx, src, dst),
221            (CastTy::Ptr(src, _), CastTy::Int(_)) => self.check_ptr_addr_cast(ctx, src),
222            (CastTy::Int(_), CastTy::Ptr(dst, _)) => self.check_addr_ptr_cast(ctx, dst),
223            (CastTy::FnPtr, CastTy::Ptr(dst, _)) => self.check_fptr_ptr_cast(ctx, dst),
224            (CastTy::Int(Int::CEnum), CastTy::Int(_)) => Ok(()),
225            (CastTy::Int(Int::Char | Int::Bool), CastTy::Int(_)) => Ok(()),
226            (CastTy::Int(_) | CastTy::Float, CastTy::Int(_) | CastTy::Float) => Ok(()),
227            (CastTy::FnPtr, CastTy::Int(_)) => Ok(()),
228        }
229    }
230
231    fn check_ref_cast(
232        &self,
233        ctx: &mut InferenceContext<'_, 'db>,
234        t_expr: Ty<'db>,
235        m_expr: Mutability,
236        t_cast: Ty<'db>,
237        m_cast: Mutability,
238    ) -> Result<(), CastError> {
239        // Mutability order is opposite to rustc. `Mut < Not`
240        if m_expr <= m_cast
241            && let TyKind::Array(ety, _) = t_expr.kind()
242        {
243            // Coerce to a raw pointer so that we generate RawPtr in MIR.
244            let array_ptr_type = Ty::new_ptr(ctx.interner(), t_expr, m_expr);
245            if ctx
246                .coerce(
247                    self.source_expr.into(),
248                    self.expr_ty,
249                    array_ptr_type,
250                    AllowTwoPhase::No,
251                    ExprIsRead::Yes,
252                )
253                .is_ok()
254            {
255            } else {
256                never!(
257                    "could not cast from reference to array to pointer to array ({:?} to {:?})",
258                    self.expr_ty,
259                    array_ptr_type
260                );
261            }
262
263            // This is a less strict condition than rustc's `demand_eqtype`,
264            // but false negative is better than false positive
265            if ctx
266                .coerce(self.source_expr.into(), ety, t_cast, AllowTwoPhase::No, ExprIsRead::Yes)
267                .is_ok()
268            {
269                return Ok(());
270            }
271        }
272
273        Err(CastError::IllegalCast)
274    }
275
276    fn check_ptr_ptr_cast(
277        &self,
278        ctx: &mut InferenceContext<'_, 'db>,
279        src: Ty<'db>,
280        dst: Ty<'db>,
281    ) -> Result<(), CastError> {
282        let src_kind = pointer_kind(src, ctx).map_err(|_| CastError::Unknown)?;
283        let dst_kind = pointer_kind(dst, ctx).map_err(|_| CastError::Unknown)?;
284
285        match (src_kind, dst_kind) {
286            (Some(PointerKind::Error), _) | (_, Some(PointerKind::Error)) => Ok(()),
287            // (_, None) => Err(CastError::UnknownCastPtrKind),
288            // (None, _) => Err(CastError::UnknownExprPtrKind),
289            (_, None) | (None, _) => Ok(()),
290            (_, Some(PointerKind::Thin)) => Ok(()),
291            (Some(PointerKind::Thin), _) => Err(CastError::SizedUnsizedCast),
292            (Some(PointerKind::VTable(src_tty)), Some(PointerKind::VTable(dst_tty))) => {
293                match (src_tty.principal_def_id(), dst_tty.principal_def_id()) {
294                    (Some(src_principal), Some(dst_principal)) => {
295                        if src_principal == dst_principal {
296                            return Ok(());
297                        }
298                        let src_principal = ctx.db.trait_signature(src_principal.0);
299                        let dst_principal = ctx.db.trait_signature(dst_principal.0);
300                        if src_principal.flags.contains(TraitFlags::AUTO)
301                            && dst_principal.flags.contains(TraitFlags::AUTO)
302                        {
303                            Ok(())
304                        } else {
305                            Err(CastError::DifferingKinds)
306                        }
307                    }
308                    _ => Err(CastError::Unknown),
309                }
310            }
311            (Some(src_kind), Some(dst_kind)) if src_kind == dst_kind => Ok(()),
312            (_, _) => Err(CastError::DifferingKinds),
313        }
314    }
315
316    fn check_ptr_addr_cast(
317        &self,
318        ctx: &mut InferenceContext<'_, 'db>,
319        expr_ty: Ty<'db>,
320    ) -> Result<(), CastError> {
321        match pointer_kind(expr_ty, ctx).map_err(|_| CastError::Unknown)? {
322            // None => Err(CastError::UnknownExprPtrKind),
323            None => Ok(()),
324            Some(PointerKind::Error) => Ok(()),
325            Some(PointerKind::Thin) => Ok(()),
326            _ => Err(CastError::NeedViaThinPtr),
327        }
328    }
329
330    fn check_addr_ptr_cast(
331        &self,
332        ctx: &mut InferenceContext<'_, 'db>,
333        cast_ty: Ty<'db>,
334    ) -> Result<(), CastError> {
335        match pointer_kind(cast_ty, ctx).map_err(|_| CastError::Unknown)? {
336            // None => Err(CastError::UnknownCastPtrKind),
337            None => Ok(()),
338            Some(PointerKind::Error) => Ok(()),
339            Some(PointerKind::Thin) => Ok(()),
340            Some(PointerKind::VTable(_)) => Err(CastError::IntToFatCast),
341            Some(PointerKind::Length) => Err(CastError::IntToFatCast),
342            Some(PointerKind::OfAlias | PointerKind::OfParam(_)) => Err(CastError::IntToFatCast),
343        }
344    }
345
346    fn check_fptr_ptr_cast(
347        &self,
348        ctx: &mut InferenceContext<'_, 'db>,
349        cast_ty: Ty<'db>,
350    ) -> Result<(), CastError> {
351        match pointer_kind(cast_ty, ctx).map_err(|_| CastError::Unknown)? {
352            // None => Err(CastError::UnknownCastPtrKind),
353            None => Ok(()),
354            Some(PointerKind::Error) => Ok(()),
355            Some(PointerKind::Thin) => Ok(()),
356            _ => Err(CastError::IllegalCast),
357        }
358    }
359}
360
361#[derive(Debug, PartialEq, Eq)]
362enum PointerKind<'db> {
363    // thin pointer
364    Thin,
365    // trait object
366    VTable(BoundExistentialPredicates<'db>),
367    // slice
368    Length,
369    OfAlias,
370    OfParam(ParamTy),
371    Error,
372}
373
374fn pointer_kind<'db>(
375    ty: Ty<'db>,
376    ctx: &mut InferenceContext<'_, 'db>,
377) -> Result<Option<PointerKind<'db>>, ()> {
378    let ty = ctx.table.try_structurally_resolve_type(ty);
379
380    if ctx.table.is_sized(ty) {
381        return Ok(Some(PointerKind::Thin));
382    }
383
384    match ty.kind() {
385        TyKind::Slice(_) | TyKind::Str => Ok(Some(PointerKind::Length)),
386        TyKind::Dynamic(bounds, _) => Ok(Some(PointerKind::VTable(bounds))),
387        TyKind::Adt(adt_def, subst) => {
388            let id = adt_def.def_id().0;
389            let AdtId::StructId(id) = id else {
390                never!("`{:?}` should be sized but is not?", ty);
391                return Err(());
392            };
393
394            let struct_data = id.fields(ctx.db);
395            if let Some((last_field, _)) = struct_data.fields().iter().last() {
396                let last_field_ty =
397                    ctx.db.field_types(id.into())[last_field].instantiate(ctx.interner(), subst);
398                pointer_kind(last_field_ty, ctx)
399            } else {
400                Ok(Some(PointerKind::Thin))
401            }
402        }
403        TyKind::Tuple(subst) => match subst.iter().next_back() {
404            None => Ok(Some(PointerKind::Thin)),
405            Some(ty) => pointer_kind(ty, ctx),
406        },
407        TyKind::Foreign(_) => Ok(Some(PointerKind::Thin)),
408        TyKind::Alias(..) => Ok(Some(PointerKind::OfAlias)),
409        TyKind::Error(_) => Ok(Some(PointerKind::Error)),
410        TyKind::Param(idx) => Ok(Some(PointerKind::OfParam(idx))),
411        TyKind::Bound(..) | TyKind::Placeholder(..) | TyKind::Infer(..) => Ok(None),
412        TyKind::Int(_)
413        | TyKind::Uint(_)
414        | TyKind::Float(_)
415        | TyKind::Bool
416        | TyKind::Char
417        | TyKind::Array(..)
418        | TyKind::CoroutineWitness(..)
419        | TyKind::RawPtr(..)
420        | TyKind::Ref(..)
421        | TyKind::FnDef(..)
422        | TyKind::FnPtr(..)
423        | TyKind::Closure(..)
424        | TyKind::Coroutine(..)
425        | TyKind::CoroutineClosure(..)
426        | TyKind::Never => {
427            never!("`{:?}` should be sized but is not?", ty);
428            Err(())
429        }
430        TyKind::UnsafeBinder(..) | TyKind::Pat(..) => {
431            never!("we don't produce these types: {ty:?}");
432            Err(())
433        }
434    }
435}
436
437fn contains_dyn_trait<'db>(ty: Ty<'db>) -> bool {
438    use std::ops::ControlFlow;
439
440    use rustc_type_ir::{TypeSuperVisitable, TypeVisitable, TypeVisitor};
441
442    struct DynTraitVisitor;
443
444    impl<'db> TypeVisitor<DbInterner<'db>> for DynTraitVisitor {
445        type Result = ControlFlow<()>;
446
447        fn visit_ty(&mut self, ty: Ty<'db>) -> ControlFlow<()> {
448            match ty.kind() {
449                TyKind::Dynamic(..) => ControlFlow::Break(()),
450                _ => ty.super_visit_with(self),
451            }
452        }
453    }
454
455    ty.visit_with(&mut DynTraitVisitor).is_break()
456}