1use hir_def::{AdtId, hir::ExprId, signatures::TraitFlags};
4use rustc_ast_ir::Mutability;
5use rustc_type_ir::{
6 Flags, InferTy, TypeFlags, UintTy,
7 inherent::{AdtDef, BoundExistentialPredicates as _, IntoKind, SliceLike, Ty as _},
8};
9use stdx::never;
10
11use crate::{
12 InferenceDiagnostic,
13 db::HirDatabase,
14 infer::{AllowTwoPhase, InferenceContext, expr::ExprIsRead},
15 next_solver::{BoundExistentialPredicates, DbInterner, ParamTy, Ty, TyKind},
16};
17
/// Integer-like classification of a cast operand or target (see `CastTy::from_ty`).
#[derive(Debug)]
pub(crate) enum Int {
    /// A signed integer type.
    I,
    /// An unsigned integer type; the width is kept so that `u8 as char` can be recognized.
    U(UintTy),
    /// `bool`.
    Bool,
    /// `char`.
    Char,
    /// An enum whose variants carry no payload (a C-like enum).
    CEnum,
    /// An integer inference variable.
    InferenceVar,
}
27
/// Classification of a type for cast checking, produced by `CastTy::from_ty`.
///
/// Types that fall outside these categories are not primitive-castable.
#[derive(Debug)]
pub(crate) enum CastTy<'db> {
    /// Integers and integer-like types (`bool`, `char`, C-like enums, integer variables).
    Int(Int),
    /// Floating-point types and float inference variables.
    Float,
    /// Function pointers.
    FnPtr,
    /// Raw pointers, with their pointee type and mutability.
    Ptr(Ty<'db>, Mutability),
}
36
37impl<'db> CastTy<'db> {
38 pub(crate) fn from_ty(db: &dyn HirDatabase, t: Ty<'db>) -> Option<Self> {
39 match t.kind() {
40 TyKind::Bool => Some(Self::Int(Int::Bool)),
41 TyKind::Char => Some(Self::Int(Int::Char)),
42 TyKind::Int(_) => Some(Self::Int(Int::I)),
43 TyKind::Uint(it) => Some(Self::Int(Int::U(it))),
44 TyKind::Infer(InferTy::IntVar(_)) => Some(Self::Int(Int::InferenceVar)),
45 TyKind::Infer(InferTy::FloatVar(_)) => Some(Self::Float),
46 TyKind::Float(_) => Some(Self::Float),
47 TyKind::Adt(..) => {
48 let (AdtId::EnumId(id), _) = t.as_adt()? else {
49 return None;
50 };
51 let enum_data = id.enum_variants(db);
52 if enum_data.is_payload_free(db) { Some(Self::Int(Int::CEnum)) } else { None }
53 }
54 TyKind::RawPtr(ty, m) => Some(Self::Ptr(ty, m)),
55 TyKind::FnPtr(..) => Some(Self::FnPtr),
56 _ => None,
57 }
58 }
59}
60
/// The reason a cast was rejected; carried by `InferenceDiagnostic::InvalidCast`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum CastError {
    /// The cast could not be classified: the pointee's metadata kind could not be computed,
    /// or a `dyn` pointee had no principal trait.
    Unknown,
    /// The target type is `bool`.
    CastToBool,
    /// The target type is `char` and the source is not `u8`.
    CastToChar,
    /// A pointer-to-pointer cast between pointees with different metadata (including `dyn`
    /// pointees with different principal traits).
    DifferingKinds,
    /// A cast from a thin pointer to a pointer that carries metadata.
    SizedUnsizedCast,
    /// A cast that is not allowed between the two classified types.
    IllegalCast,
    /// An integer cast to a pointer that carries (or may carry) metadata.
    IntToFatCast,
    /// A reference to a number cast to a number; the reference must be dereferenced first.
    NeedDeref,
    /// A reference to a non-numeric type cast to a number; it must go through a raw pointer.
    NeedViaPtr,
    /// A pointer with metadata cast to an integer; it must go through a thin pointer.
    NeedViaThinPtr,
    /// `bool`, `char` or a C-like enum cast to a float; it must go through an integer.
    NeedViaInt,
    /// A cast involving a non-primitive type, or targeting a C-like enum or a function pointer.
    NonScalar,
}
79
80impl CastError {
81 fn into_diagnostic<'db>(
82 self,
83 expr: ExprId,
84 expr_ty: Ty<'db>,
85 cast_ty: Ty<'db>,
86 ) -> InferenceDiagnostic<'db> {
87 InferenceDiagnostic::InvalidCast { expr, error: self, expr_ty, cast_ty }
88 }
89}
90
/// The inputs needed to check one `as` cast; build with `CastCheck::new`, run with
/// `CastCheck::check`.
#[derive(Clone, Debug)]
pub(super) struct CastCheck<'db> {
    /// The expression that cast diagnostics are reported on.
    expr: ExprId,
    /// The cast operand (of type `expr_ty`); coercions are attempted on this expression.
    source_expr: ExprId,
    /// The type of the operand.
    expr_ty: Ty<'db>,
    /// The target type of the cast.
    cast_ty: Ty<'db>,
}
98
99impl<'db> CastCheck<'db> {
100 pub(super) fn new(
101 expr: ExprId,
102 source_expr: ExprId,
103 expr_ty: Ty<'db>,
104 cast_ty: Ty<'db>,
105 ) -> Self {
106 Self { expr, source_expr, expr_ty, cast_ty }
107 }
108
109 pub(super) fn check(
110 &mut self,
111 ctx: &mut InferenceContext<'_, 'db>,
112 ) -> Result<(), InferenceDiagnostic<'db>> {
113 self.expr_ty = ctx.table.try_structurally_resolve_type(self.expr_ty);
114 self.cast_ty = ctx.table.try_structurally_resolve_type(self.cast_ty);
115
116 if ctx
118 .coerce(
119 self.source_expr.into(),
120 self.expr_ty,
121 self.cast_ty,
122 AllowTwoPhase::No,
123 ExprIsRead::Yes,
124 )
125 .is_ok()
126 {
127 ctx.result.coercion_casts.insert(self.source_expr);
128 return Ok(());
129 }
130
131 if self.expr_ty.references_non_lt_error() || self.cast_ty.references_non_lt_error() {
132 return Ok(());
133 }
134
135 if !self.cast_ty.flags().contains(TypeFlags::HAS_TY_INFER)
136 && !ctx.table.is_sized(self.cast_ty)
137 {
138 return Err(InferenceDiagnostic::CastToUnsized {
139 expr: self.expr,
140 cast_ty: self.cast_ty,
141 });
142 }
143
144 if contains_dyn_trait(self.cast_ty) {
148 return Ok(());
149 }
150
151 self.do_check(ctx).map_err(|e| e.into_diagnostic(self.expr, self.expr_ty, self.cast_ty))
152 }
153
154 fn do_check(&self, ctx: &mut InferenceContext<'_, 'db>) -> Result<(), CastError> {
155 let (t_from, t_cast) =
156 match (CastTy::from_ty(ctx.db, self.expr_ty), CastTy::from_ty(ctx.db, self.cast_ty)) {
157 (Some(t_from), Some(t_cast)) => (t_from, t_cast),
158 (None, Some(t_cast)) => match self.expr_ty.kind() {
159 TyKind::FnDef(..) => {
160 let sig =
161 self.expr_ty.callable_sig(ctx.interner()).expect("FnDef had no sig");
162 let sig = ctx.table.normalize_associated_types_in(sig);
163 let fn_ptr = Ty::new_fn_ptr(ctx.interner(), sig);
164 if ctx
165 .coerce(
166 self.source_expr.into(),
167 self.expr_ty,
168 fn_ptr,
169 AllowTwoPhase::No,
170 ExprIsRead::Yes,
171 )
172 .is_ok()
173 {
174 } else {
175 return Err(CastError::IllegalCast);
176 }
177
178 (CastTy::FnPtr, t_cast)
179 }
180 TyKind::Ref(_, inner_ty, mutbl) => {
181 return match t_cast {
182 CastTy::Int(_) | CastTy::Float => match inner_ty.kind() {
183 TyKind::Int(_)
184 | TyKind::Uint(_)
185 | TyKind::Float(_)
186 | TyKind::Infer(InferTy::IntVar(_) | InferTy::FloatVar(_)) => {
187 Err(CastError::NeedDeref)
188 }
189
190 _ => Err(CastError::NeedViaPtr),
191 },
192 CastTy::Ptr(t, m) => {
194 let t = ctx.table.try_structurally_resolve_type(t);
195 if !ctx.table.is_sized(t) {
196 return Err(CastError::IllegalCast);
197 }
198 self.check_ref_cast(ctx, inner_ty, mutbl, t, m)
199 }
200 _ => Err(CastError::NonScalar),
201 };
202 }
203 _ => return Err(CastError::NonScalar),
204 },
205 _ => return Err(CastError::NonScalar),
206 };
207
208 match (t_from, t_cast) {
211 (_, CastTy::Int(Int::CEnum) | CastTy::FnPtr) => Err(CastError::NonScalar),
212 (_, CastTy::Int(Int::Bool)) => Err(CastError::CastToBool),
213 (CastTy::Int(Int::U(UintTy::U8)), CastTy::Int(Int::Char)) => Ok(()),
214 (_, CastTy::Int(Int::Char)) => Err(CastError::CastToChar),
215 (CastTy::Int(Int::Bool | Int::CEnum | Int::Char), CastTy::Float) => {
216 Err(CastError::NeedViaInt)
217 }
218 (CastTy::Int(Int::Bool | Int::CEnum | Int::Char) | CastTy::Float, CastTy::Ptr(..))
219 | (CastTy::Ptr(..) | CastTy::FnPtr, CastTy::Float) => Err(CastError::IllegalCast),
220 (CastTy::Ptr(src, _), CastTy::Ptr(dst, _)) => self.check_ptr_ptr_cast(ctx, src, dst),
221 (CastTy::Ptr(src, _), CastTy::Int(_)) => self.check_ptr_addr_cast(ctx, src),
222 (CastTy::Int(_), CastTy::Ptr(dst, _)) => self.check_addr_ptr_cast(ctx, dst),
223 (CastTy::FnPtr, CastTy::Ptr(dst, _)) => self.check_fptr_ptr_cast(ctx, dst),
224 (CastTy::Int(Int::CEnum), CastTy::Int(_)) => Ok(()),
225 (CastTy::Int(Int::Char | Int::Bool), CastTy::Int(_)) => Ok(()),
226 (CastTy::Int(_) | CastTy::Float, CastTy::Int(_) | CastTy::Float) => Ok(()),
227 (CastTy::FnPtr, CastTy::Int(_)) => Ok(()),
228 }
229 }
230
231 fn check_ref_cast(
232 &self,
233 ctx: &mut InferenceContext<'_, 'db>,
234 t_expr: Ty<'db>,
235 m_expr: Mutability,
236 t_cast: Ty<'db>,
237 m_cast: Mutability,
238 ) -> Result<(), CastError> {
239 if m_expr <= m_cast
241 && let TyKind::Array(ety, _) = t_expr.kind()
242 {
243 let array_ptr_type = Ty::new_ptr(ctx.interner(), t_expr, m_expr);
245 if ctx
246 .coerce(
247 self.source_expr.into(),
248 self.expr_ty,
249 array_ptr_type,
250 AllowTwoPhase::No,
251 ExprIsRead::Yes,
252 )
253 .is_ok()
254 {
255 } else {
256 never!(
257 "could not cast from reference to array to pointer to array ({:?} to {:?})",
258 self.expr_ty,
259 array_ptr_type
260 );
261 }
262
263 if ctx
266 .coerce(self.source_expr.into(), ety, t_cast, AllowTwoPhase::No, ExprIsRead::Yes)
267 .is_ok()
268 {
269 return Ok(());
270 }
271 }
272
273 Err(CastError::IllegalCast)
274 }
275
276 fn check_ptr_ptr_cast(
277 &self,
278 ctx: &mut InferenceContext<'_, 'db>,
279 src: Ty<'db>,
280 dst: Ty<'db>,
281 ) -> Result<(), CastError> {
282 let src_kind = pointer_kind(src, ctx).map_err(|_| CastError::Unknown)?;
283 let dst_kind = pointer_kind(dst, ctx).map_err(|_| CastError::Unknown)?;
284
285 match (src_kind, dst_kind) {
286 (Some(PointerKind::Error), _) | (_, Some(PointerKind::Error)) => Ok(()),
287 (_, None) | (None, _) => Ok(()),
290 (_, Some(PointerKind::Thin)) => Ok(()),
291 (Some(PointerKind::Thin), _) => Err(CastError::SizedUnsizedCast),
292 (Some(PointerKind::VTable(src_tty)), Some(PointerKind::VTable(dst_tty))) => {
293 match (src_tty.principal_def_id(), dst_tty.principal_def_id()) {
294 (Some(src_principal), Some(dst_principal)) => {
295 if src_principal == dst_principal {
296 return Ok(());
297 }
298 let src_principal = ctx.db.trait_signature(src_principal.0);
299 let dst_principal = ctx.db.trait_signature(dst_principal.0);
300 if src_principal.flags.contains(TraitFlags::AUTO)
301 && dst_principal.flags.contains(TraitFlags::AUTO)
302 {
303 Ok(())
304 } else {
305 Err(CastError::DifferingKinds)
306 }
307 }
308 _ => Err(CastError::Unknown),
309 }
310 }
311 (Some(src_kind), Some(dst_kind)) if src_kind == dst_kind => Ok(()),
312 (_, _) => Err(CastError::DifferingKinds),
313 }
314 }
315
316 fn check_ptr_addr_cast(
317 &self,
318 ctx: &mut InferenceContext<'_, 'db>,
319 expr_ty: Ty<'db>,
320 ) -> Result<(), CastError> {
321 match pointer_kind(expr_ty, ctx).map_err(|_| CastError::Unknown)? {
322 None => Ok(()),
324 Some(PointerKind::Error) => Ok(()),
325 Some(PointerKind::Thin) => Ok(()),
326 _ => Err(CastError::NeedViaThinPtr),
327 }
328 }
329
330 fn check_addr_ptr_cast(
331 &self,
332 ctx: &mut InferenceContext<'_, 'db>,
333 cast_ty: Ty<'db>,
334 ) -> Result<(), CastError> {
335 match pointer_kind(cast_ty, ctx).map_err(|_| CastError::Unknown)? {
336 None => Ok(()),
338 Some(PointerKind::Error) => Ok(()),
339 Some(PointerKind::Thin) => Ok(()),
340 Some(PointerKind::VTable(_)) => Err(CastError::IntToFatCast),
341 Some(PointerKind::Length) => Err(CastError::IntToFatCast),
342 Some(PointerKind::OfAlias | PointerKind::OfParam(_)) => Err(CastError::IntToFatCast),
343 }
344 }
345
346 fn check_fptr_ptr_cast(
347 &self,
348 ctx: &mut InferenceContext<'_, 'db>,
349 cast_ty: Ty<'db>,
350 ) -> Result<(), CastError> {
351 match pointer_kind(cast_ty, ctx).map_err(|_| CastError::Unknown)? {
352 None => Ok(()),
354 Some(PointerKind::Error) => Ok(()),
355 Some(PointerKind::Thin) => Ok(()),
356 _ => Err(CastError::IllegalCast),
357 }
358 }
359}
360
/// The metadata carried by a raw pointer to a given pointee, as computed by `pointer_kind`.
#[derive(Debug, PartialEq, Eq)]
enum PointerKind<'db> {
    /// No metadata (sized pointees and extern types).
    Thin,
    /// A vtable for a `dyn` pointee with these bounds.
    VTable(BoundExistentialPredicates<'db>),
    /// A length, for slices and `str`.
    Length,
    /// The metadata of an alias type, which is not known here.
    OfAlias,
    /// The metadata of an unsized type parameter, which is not known here.
    OfParam(ParamTy),
    /// The pointee is an error type.
    Error,
}
373
374fn pointer_kind<'db>(
375 ty: Ty<'db>,
376 ctx: &mut InferenceContext<'_, 'db>,
377) -> Result<Option<PointerKind<'db>>, ()> {
378 let ty = ctx.table.try_structurally_resolve_type(ty);
379
380 if ctx.table.is_sized(ty) {
381 return Ok(Some(PointerKind::Thin));
382 }
383
384 match ty.kind() {
385 TyKind::Slice(_) | TyKind::Str => Ok(Some(PointerKind::Length)),
386 TyKind::Dynamic(bounds, _) => Ok(Some(PointerKind::VTable(bounds))),
387 TyKind::Adt(adt_def, subst) => {
388 let id = adt_def.def_id().0;
389 let AdtId::StructId(id) = id else {
390 never!("`{:?}` should be sized but is not?", ty);
391 return Err(());
392 };
393
394 let struct_data = id.fields(ctx.db);
395 if let Some((last_field, _)) = struct_data.fields().iter().last() {
396 let last_field_ty =
397 ctx.db.field_types(id.into())[last_field].instantiate(ctx.interner(), subst);
398 pointer_kind(last_field_ty, ctx)
399 } else {
400 Ok(Some(PointerKind::Thin))
401 }
402 }
403 TyKind::Tuple(subst) => match subst.iter().next_back() {
404 None => Ok(Some(PointerKind::Thin)),
405 Some(ty) => pointer_kind(ty, ctx),
406 },
407 TyKind::Foreign(_) => Ok(Some(PointerKind::Thin)),
408 TyKind::Alias(..) => Ok(Some(PointerKind::OfAlias)),
409 TyKind::Error(_) => Ok(Some(PointerKind::Error)),
410 TyKind::Param(idx) => Ok(Some(PointerKind::OfParam(idx))),
411 TyKind::Bound(..) | TyKind::Placeholder(..) | TyKind::Infer(..) => Ok(None),
412 TyKind::Int(_)
413 | TyKind::Uint(_)
414 | TyKind::Float(_)
415 | TyKind::Bool
416 | TyKind::Char
417 | TyKind::Array(..)
418 | TyKind::CoroutineWitness(..)
419 | TyKind::RawPtr(..)
420 | TyKind::Ref(..)
421 | TyKind::FnDef(..)
422 | TyKind::FnPtr(..)
423 | TyKind::Closure(..)
424 | TyKind::Coroutine(..)
425 | TyKind::CoroutineClosure(..)
426 | TyKind::Never => {
427 never!("`{:?}` should be sized but is not?", ty);
428 Err(())
429 }
430 TyKind::UnsafeBinder(..) | TyKind::Pat(..) => {
431 never!("we don't produce these types: {ty:?}");
432 Err(())
433 }
434 }
435}
436
437fn contains_dyn_trait<'db>(ty: Ty<'db>) -> bool {
438 use std::ops::ControlFlow;
439
440 use rustc_type_ir::{TypeSuperVisitable, TypeVisitable, TypeVisitor};
441
442 struct DynTraitVisitor;
443
444 impl<'db> TypeVisitor<DbInterner<'db>> for DynTraitVisitor {
445 type Result = ControlFlow<()>;
446
447 fn visit_ty(&mut self, ty: Ty<'db>) -> ControlFlow<()> {
448 match ty.kind() {
449 TyKind::Dynamic(..) => ControlFlow::Break(()),
450 _ => ty.super_visit_with(self),
451 }
452 }
453 }
454
455 ty.visit_with(&mut DynTraitVisitor).is_break()
456}