hir_ty/infer/
cast.rs

1//! Type cast logic. Basically coercion + additional casts.
2
3use chalk_ir::{Mutability, Scalar, TyVariableKind, UintTy};
4use hir_def::{AdtId, hir::ExprId, signatures::TraitFlags};
5use stdx::never;
6
7use crate::{
8    Adjustment, Binders, DynTy, InferenceDiagnostic, Interner, PlaceholderIndex,
9    QuantifiedWhereClauses, Ty, TyExt, TyKind, TypeFlags, WhereClause,
10    db::HirDatabase,
11    from_chalk_trait_id,
12    infer::{coerce::CoerceNever, unify::InferenceTable},
13};
14
15#[derive(Debug)]
16pub(crate) enum Int {
17    I,
18    U(UintTy),
19    Bool,
20    Char,
21    CEnum,
22    InferenceVar,
23}
24
25#[derive(Debug)]
26pub(crate) enum CastTy {
27    Int(Int),
28    Float,
29    FnPtr,
30    Ptr(Ty, Mutability),
31    // `DynStar` is Not supported yet in r-a
32}
33
34impl CastTy {
35    pub(crate) fn from_ty(db: &dyn HirDatabase, t: &Ty) -> Option<Self> {
36        match t.kind(Interner) {
37            TyKind::Scalar(Scalar::Bool) => Some(Self::Int(Int::Bool)),
38            TyKind::Scalar(Scalar::Char) => Some(Self::Int(Int::Char)),
39            TyKind::Scalar(Scalar::Int(_)) => Some(Self::Int(Int::I)),
40            TyKind::Scalar(Scalar::Uint(it)) => Some(Self::Int(Int::U(*it))),
41            TyKind::InferenceVar(_, TyVariableKind::Integer) => Some(Self::Int(Int::InferenceVar)),
42            TyKind::InferenceVar(_, TyVariableKind::Float) => Some(Self::Float),
43            TyKind::Scalar(Scalar::Float(_)) => Some(Self::Float),
44            TyKind::Adt(..) => {
45                let (AdtId::EnumId(id), _) = t.as_adt()? else {
46                    return None;
47                };
48                let enum_data = id.enum_variants(db);
49                if enum_data.is_payload_free(db) { Some(Self::Int(Int::CEnum)) } else { None }
50            }
51            TyKind::Raw(m, ty) => Some(Self::Ptr(ty.clone(), *m)),
52            TyKind::Function(_) => Some(Self::FnPtr),
53            _ => None,
54        }
55    }
56}
57
/// Why an `as` cast was rejected; carried by `InferenceDiagnostic::InvalidCast`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum CastError {
    /// The pointer kind could not be computed, or a trait object had no principal trait.
    Unknown,
    /// Casting anything to `bool`.
    CastToBool,
    /// Casting anything other than `u8` to `char`.
    CastToChar,
    /// Pointer-to-pointer cast between pointees with different metadata kinds.
    DifferingKinds,
    /// Pointer-to-pointer cast from a thin pointer to a fat one.
    SizedUnsizedCast,
    /// A cast that is not allowed at all (e.g. float to pointer).
    IllegalCast,
    /// Integer-to-pointer cast whose target pointer is fat.
    IntToFatCast,
    /// Reference to a numeric value cast to a number; the operand must be dereferenced.
    NeedDeref,
    /// Reference to a non-numeric value cast to a number; must go through a raw pointer.
    NeedViaPtr,
    /// Fat pointer cast to an integer; must go through a thin pointer.
    NeedViaThinPtr,
    /// `bool`, `char` or C-like enum cast to a float; must go through an integer.
    NeedViaInt,
    /// The operand or target is not a scalar type usable in a primitive cast.
    NonScalar,
    // rustc also distinguishes `UnknownCastPtrKind` and `UnknownExprPtrKind`. They are
    // intentionally omitted: we do not report errors involving unresolved types yet.
}
76
77impl CastError {
78    fn into_diagnostic(self, expr: ExprId, expr_ty: Ty, cast_ty: Ty) -> InferenceDiagnostic {
79        InferenceDiagnostic::InvalidCast { expr, error: self, expr_ty, cast_ty }
80    }
81}
82
83#[derive(Clone, Debug)]
84pub(super) struct CastCheck {
85    expr: ExprId,
86    source_expr: ExprId,
87    expr_ty: Ty,
88    cast_ty: Ty,
89}
90
91impl CastCheck {
92    pub(super) fn new(expr: ExprId, source_expr: ExprId, expr_ty: Ty, cast_ty: Ty) -> Self {
93        Self { expr, source_expr, expr_ty, cast_ty }
94    }
95
96    pub(super) fn check<F, G>(
97        &mut self,
98        table: &mut InferenceTable<'_>,
99        apply_adjustments: &mut F,
100        set_coercion_cast: &mut G,
101    ) -> Result<(), InferenceDiagnostic>
102    where
103        F: FnMut(ExprId, Vec<Adjustment>),
104        G: FnMut(ExprId),
105    {
106        self.expr_ty = table.eagerly_normalize_and_resolve_shallow_in(self.expr_ty.clone());
107        self.cast_ty = table.eagerly_normalize_and_resolve_shallow_in(self.cast_ty.clone());
108
109        if self.expr_ty.contains_unknown() || self.cast_ty.contains_unknown() {
110            return Ok(());
111        }
112
113        if !self.cast_ty.data(Interner).flags.contains(TypeFlags::HAS_TY_INFER)
114            && !table.is_sized(&self.cast_ty)
115        {
116            return Err(InferenceDiagnostic::CastToUnsized {
117                expr: self.expr,
118                cast_ty: self.cast_ty.clone(),
119            });
120        }
121
122        // Chalk doesn't support trait upcasting and fails to solve some obvious goals
123        // when the trait environment contains some recursive traits (See issue #18047)
124        // We skip cast checks for such cases for now, until the next-gen solver.
125        if contains_dyn_trait(&self.cast_ty) {
126            return Ok(());
127        }
128
129        if let Ok((adj, _)) = table.coerce(&self.expr_ty, &self.cast_ty, CoerceNever::Yes) {
130            apply_adjustments(self.source_expr, adj);
131            set_coercion_cast(self.source_expr);
132            return Ok(());
133        }
134
135        self.do_check(table, apply_adjustments)
136            .map_err(|e| e.into_diagnostic(self.expr, self.expr_ty.clone(), self.cast_ty.clone()))
137    }
138
139    fn do_check<F>(
140        &self,
141        table: &mut InferenceTable<'_>,
142        apply_adjustments: &mut F,
143    ) -> Result<(), CastError>
144    where
145        F: FnMut(ExprId, Vec<Adjustment>),
146    {
147        let (t_from, t_cast) = match (
148            CastTy::from_ty(table.db, &self.expr_ty),
149            CastTy::from_ty(table.db, &self.cast_ty),
150        ) {
151            (Some(t_from), Some(t_cast)) => (t_from, t_cast),
152            (None, Some(t_cast)) => match self.expr_ty.kind(Interner) {
153                TyKind::FnDef(..) => {
154                    let sig = self.expr_ty.callable_sig(table.db).expect("FnDef had no sig");
155                    let sig = table.eagerly_normalize_and_resolve_shallow_in(sig);
156                    let fn_ptr = TyKind::Function(sig.to_fn_ptr()).intern(Interner);
157                    if let Ok((adj, _)) = table.coerce(&self.expr_ty, &fn_ptr, CoerceNever::Yes) {
158                        apply_adjustments(self.source_expr, adj);
159                    } else {
160                        return Err(CastError::IllegalCast);
161                    }
162
163                    (CastTy::FnPtr, t_cast)
164                }
165                TyKind::Ref(mutbl, _, inner_ty) => {
166                    return match t_cast {
167                        CastTy::Int(_) | CastTy::Float => match inner_ty.kind(Interner) {
168                            TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_) | Scalar::Float(_))
169                            | TyKind::InferenceVar(
170                                _,
171                                TyVariableKind::Integer | TyVariableKind::Float,
172                            ) => Err(CastError::NeedDeref),
173
174                            _ => Err(CastError::NeedViaPtr),
175                        },
176                        // array-ptr-cast
177                        CastTy::Ptr(t, m) => {
178                            let t = table.eagerly_normalize_and_resolve_shallow_in(t);
179                            if !table.is_sized(&t) {
180                                return Err(CastError::IllegalCast);
181                            }
182                            self.check_ref_cast(table, inner_ty, *mutbl, &t, m, apply_adjustments)
183                        }
184                        _ => Err(CastError::NonScalar),
185                    };
186                }
187                _ => return Err(CastError::NonScalar),
188            },
189            _ => return Err(CastError::NonScalar),
190        };
191
192        // rustc checks whether the `expr_ty` is foreign adt with `non_exhaustive` sym
193
194        match (t_from, t_cast) {
195            (_, CastTy::Int(Int::CEnum) | CastTy::FnPtr) => Err(CastError::NonScalar),
196            (_, CastTy::Int(Int::Bool)) => Err(CastError::CastToBool),
197            (CastTy::Int(Int::U(UintTy::U8)), CastTy::Int(Int::Char)) => Ok(()),
198            (_, CastTy::Int(Int::Char)) => Err(CastError::CastToChar),
199            (CastTy::Int(Int::Bool | Int::CEnum | Int::Char), CastTy::Float) => {
200                Err(CastError::NeedViaInt)
201            }
202            (CastTy::Int(Int::Bool | Int::CEnum | Int::Char) | CastTy::Float, CastTy::Ptr(..))
203            | (CastTy::Ptr(..) | CastTy::FnPtr, CastTy::Float) => Err(CastError::IllegalCast),
204            (CastTy::Ptr(src, _), CastTy::Ptr(dst, _)) => {
205                self.check_ptr_ptr_cast(table, &src, &dst)
206            }
207            (CastTy::Ptr(src, _), CastTy::Int(_)) => self.check_ptr_addr_cast(table, &src),
208            (CastTy::Int(_), CastTy::Ptr(dst, _)) => self.check_addr_ptr_cast(table, &dst),
209            (CastTy::FnPtr, CastTy::Ptr(dst, _)) => self.check_fptr_ptr_cast(table, &dst),
210            (CastTy::Int(Int::CEnum), CastTy::Int(_)) => Ok(()),
211            (CastTy::Int(Int::Char | Int::Bool), CastTy::Int(_)) => Ok(()),
212            (CastTy::Int(_) | CastTy::Float, CastTy::Int(_) | CastTy::Float) => Ok(()),
213            (CastTy::FnPtr, CastTy::Int(_)) => Ok(()),
214        }
215    }
216
217    fn check_ref_cast<F>(
218        &self,
219        table: &mut InferenceTable<'_>,
220        t_expr: &Ty,
221        m_expr: Mutability,
222        t_cast: &Ty,
223        m_cast: Mutability,
224        apply_adjustments: &mut F,
225    ) -> Result<(), CastError>
226    where
227        F: FnMut(ExprId, Vec<Adjustment>),
228    {
229        // Mutability order is opposite to rustc. `Mut < Not`
230        if m_expr <= m_cast
231            && let TyKind::Array(ety, _) = t_expr.kind(Interner)
232        {
233            // Coerce to a raw pointer so that we generate RawPtr in MIR.
234            let array_ptr_type = TyKind::Raw(m_expr, t_expr.clone()).intern(Interner);
235            if let Ok((adj, _)) = table.coerce(&self.expr_ty, &array_ptr_type, CoerceNever::Yes) {
236                apply_adjustments(self.source_expr, adj);
237            } else {
238                never!(
239                    "could not cast from reference to array to pointer to array ({:?} to {:?})",
240                    self.expr_ty,
241                    array_ptr_type
242                );
243            }
244
245            // This is a less strict condition than rustc's `demand_eqtype`,
246            // but false negative is better than false positive
247            if table.coerce(ety, t_cast, CoerceNever::Yes).is_ok() {
248                return Ok(());
249            }
250        }
251
252        Err(CastError::IllegalCast)
253    }
254
255    fn check_ptr_ptr_cast(
256        &self,
257        table: &mut InferenceTable<'_>,
258        src: &Ty,
259        dst: &Ty,
260    ) -> Result<(), CastError> {
261        let src_kind = pointer_kind(src, table).map_err(|_| CastError::Unknown)?;
262        let dst_kind = pointer_kind(dst, table).map_err(|_| CastError::Unknown)?;
263
264        match (src_kind, dst_kind) {
265            (Some(PointerKind::Error), _) | (_, Some(PointerKind::Error)) => Ok(()),
266            // (_, None) => Err(CastError::UnknownCastPtrKind),
267            // (None, _) => Err(CastError::UnknownExprPtrKind),
268            (_, None) | (None, _) => Ok(()),
269            (_, Some(PointerKind::Thin)) => Ok(()),
270            (Some(PointerKind::Thin), _) => Err(CastError::SizedUnsizedCast),
271            (Some(PointerKind::VTable(src_tty)), Some(PointerKind::VTable(dst_tty))) => {
272                let principal = |tty: &Binders<QuantifiedWhereClauses>| {
273                    tty.skip_binders().as_slice(Interner).first().and_then(|pred| {
274                        if let WhereClause::Implemented(tr) = pred.skip_binders() {
275                            Some(tr.trait_id)
276                        } else {
277                            None
278                        }
279                    })
280                };
281                match (principal(&src_tty), principal(&dst_tty)) {
282                    (Some(src_principal), Some(dst_principal)) => {
283                        if src_principal == dst_principal {
284                            return Ok(());
285                        }
286                        let src_principal =
287                            table.db.trait_signature(from_chalk_trait_id(src_principal));
288                        let dst_principal =
289                            table.db.trait_signature(from_chalk_trait_id(dst_principal));
290                        if src_principal.flags.contains(TraitFlags::AUTO)
291                            && dst_principal.flags.contains(TraitFlags::AUTO)
292                        {
293                            Ok(())
294                        } else {
295                            Err(CastError::DifferingKinds)
296                        }
297                    }
298                    _ => Err(CastError::Unknown),
299                }
300            }
301            (Some(src_kind), Some(dst_kind)) if src_kind == dst_kind => Ok(()),
302            (_, _) => Err(CastError::DifferingKinds),
303        }
304    }
305
306    fn check_ptr_addr_cast(
307        &self,
308        table: &mut InferenceTable<'_>,
309        expr_ty: &Ty,
310    ) -> Result<(), CastError> {
311        match pointer_kind(expr_ty, table).map_err(|_| CastError::Unknown)? {
312            // None => Err(CastError::UnknownExprPtrKind),
313            None => Ok(()),
314            Some(PointerKind::Error) => Ok(()),
315            Some(PointerKind::Thin) => Ok(()),
316            _ => Err(CastError::NeedViaThinPtr),
317        }
318    }
319
320    fn check_addr_ptr_cast(
321        &self,
322        table: &mut InferenceTable<'_>,
323        cast_ty: &Ty,
324    ) -> Result<(), CastError> {
325        match pointer_kind(cast_ty, table).map_err(|_| CastError::Unknown)? {
326            // None => Err(CastError::UnknownCastPtrKind),
327            None => Ok(()),
328            Some(PointerKind::Error) => Ok(()),
329            Some(PointerKind::Thin) => Ok(()),
330            Some(PointerKind::VTable(_)) => Err(CastError::IntToFatCast),
331            Some(PointerKind::Length) => Err(CastError::IntToFatCast),
332            Some(PointerKind::OfAlias | PointerKind::OfParam(_)) => Err(CastError::IntToFatCast),
333        }
334    }
335
336    fn check_fptr_ptr_cast(
337        &self,
338        table: &mut InferenceTable<'_>,
339        cast_ty: &Ty,
340    ) -> Result<(), CastError> {
341        match pointer_kind(cast_ty, table).map_err(|_| CastError::Unknown)? {
342            // None => Err(CastError::UnknownCastPtrKind),
343            None => Ok(()),
344            Some(PointerKind::Error) => Ok(()),
345            Some(PointerKind::Thin) => Ok(()),
346            _ => Err(CastError::IllegalCast),
347        }
348    }
349}
350
351#[derive(Debug, PartialEq, Eq)]
352enum PointerKind {
353    // thin pointer
354    Thin,
355    // trait object
356    VTable(Binders<QuantifiedWhereClauses>),
357    // slice
358    Length,
359    OfAlias,
360    OfParam(PlaceholderIndex),
361    Error,
362}
363
364fn pointer_kind(ty: &Ty, table: &mut InferenceTable<'_>) -> Result<Option<PointerKind>, ()> {
365    let ty = table.eagerly_normalize_and_resolve_shallow_in(ty.clone());
366
367    if table.is_sized(&ty) {
368        return Ok(Some(PointerKind::Thin));
369    }
370
371    match ty.kind(Interner) {
372        TyKind::Slice(_) | TyKind::Str => Ok(Some(PointerKind::Length)),
373        TyKind::Dyn(DynTy { bounds, .. }) => Ok(Some(PointerKind::VTable(bounds.clone()))),
374        TyKind::Adt(chalk_ir::AdtId(id), subst) => {
375            let AdtId::StructId(id) = *id else {
376                never!("`{:?}` should be sized but is not?", ty);
377                return Err(());
378            };
379
380            let struct_data = id.fields(table.db);
381            if let Some((last_field, _)) = struct_data.fields().iter().last() {
382                let last_field_ty =
383                    table.db.field_types(id.into())[last_field].clone().substitute(Interner, subst);
384                pointer_kind(&last_field_ty, table)
385            } else {
386                Ok(Some(PointerKind::Thin))
387            }
388        }
389        TyKind::Tuple(_, subst) => {
390            match subst.iter(Interner).last().and_then(|arg| arg.ty(Interner)) {
391                None => Ok(Some(PointerKind::Thin)),
392                Some(ty) => pointer_kind(ty, table),
393            }
394        }
395        TyKind::Foreign(_) => Ok(Some(PointerKind::Thin)),
396        TyKind::Alias(_) | TyKind::AssociatedType(..) | TyKind::OpaqueType(..) => {
397            Ok(Some(PointerKind::OfAlias))
398        }
399        TyKind::Error => Ok(Some(PointerKind::Error)),
400        TyKind::Placeholder(idx) => Ok(Some(PointerKind::OfParam(*idx))),
401        TyKind::BoundVar(_) | TyKind::InferenceVar(..) => Ok(None),
402        TyKind::Scalar(_)
403        | TyKind::Array(..)
404        | TyKind::CoroutineWitness(..)
405        | TyKind::Raw(..)
406        | TyKind::Ref(..)
407        | TyKind::FnDef(..)
408        | TyKind::Function(_)
409        | TyKind::Closure(..)
410        | TyKind::Coroutine(..)
411        | TyKind::Never => {
412            never!("`{:?}` should be sized but is not?", ty);
413            Err(())
414        }
415    }
416}
417
418fn contains_dyn_trait(ty: &Ty) -> bool {
419    use std::ops::ControlFlow;
420
421    use chalk_ir::{
422        DebruijnIndex,
423        visit::{TypeSuperVisitable, TypeVisitable, TypeVisitor},
424    };
425
426    struct DynTraitVisitor;
427
428    impl TypeVisitor<Interner> for DynTraitVisitor {
429        type BreakTy = ();
430
431        fn as_dyn(&mut self) -> &mut dyn TypeVisitor<Interner, BreakTy = Self::BreakTy> {
432            self
433        }
434
435        fn interner(&self) -> Interner {
436            Interner
437        }
438
439        fn visit_ty(&mut self, ty: &Ty, outer_binder: DebruijnIndex) -> ControlFlow<Self::BreakTy> {
440            match ty.kind(Interner) {
441                TyKind::Dyn(_) => ControlFlow::Break(()),
442                _ => ty.super_visit_with(self.as_dyn(), outer_binder),
443            }
444        }
445    }
446
447    ty.visit_with(DynTraitVisitor.as_dyn(), DebruijnIndex::INNERMOST).is_break()
448}