1use chalk_ir::{Mutability, Scalar, TyVariableKind, UintTy};
4use hir_def::{AdtId, hir::ExprId, signatures::TraitFlags};
5use stdx::never;
6
7use crate::{
8 Adjustment, Binders, DynTy, InferenceDiagnostic, Interner, PlaceholderIndex,
9 QuantifiedWhereClauses, Ty, TyExt, TyKind, TypeFlags, WhereClause,
10 db::HirDatabase,
11 from_chalk_trait_id,
12 infer::{coerce::CoerceNever, unify::InferenceTable},
13};
14
/// Integer-like category of a cast operand or target, as produced by [`CastTy::from_ty`].
#[derive(Debug)]
pub(crate) enum Int {
    /// A signed integer type.
    I,
    /// An unsigned integer type; the exact width is kept because `u8` is the only type
    /// that may be cast to `char`.
    U(UintTy),
    /// `bool`.
    Bool,
    /// `char`.
    Char,
    /// An enum whose variants carry no payload ("C-like" enum).
    CEnum,
    /// An integer inference variable whose concrete type is not known yet.
    InferenceVar,
}
24
/// Cast-relevant category of a type; types without a category are not directly castable.
#[derive(Debug)]
pub(crate) enum CastTy {
    /// Integer-like scalar (including `bool`, `char` and C-like enums).
    Int(Int),
    /// Floating-point type or float inference variable.
    Float,
    /// Function pointer.
    FnPtr,
    /// Raw pointer with the given pointee type and mutability.
    Ptr(Ty, Mutability),
}
33
34impl CastTy {
35 pub(crate) fn from_ty(db: &dyn HirDatabase, t: &Ty) -> Option<Self> {
36 match t.kind(Interner) {
37 TyKind::Scalar(Scalar::Bool) => Some(Self::Int(Int::Bool)),
38 TyKind::Scalar(Scalar::Char) => Some(Self::Int(Int::Char)),
39 TyKind::Scalar(Scalar::Int(_)) => Some(Self::Int(Int::I)),
40 TyKind::Scalar(Scalar::Uint(it)) => Some(Self::Int(Int::U(*it))),
41 TyKind::InferenceVar(_, TyVariableKind::Integer) => Some(Self::Int(Int::InferenceVar)),
42 TyKind::InferenceVar(_, TyVariableKind::Float) => Some(Self::Float),
43 TyKind::Scalar(Scalar::Float(_)) => Some(Self::Float),
44 TyKind::Adt(..) => {
45 let (AdtId::EnumId(id), _) = t.as_adt()? else {
46 return None;
47 };
48 let enum_data = id.enum_variants(db);
49 if enum_data.is_payload_free(db) { Some(Self::Int(Int::CEnum)) } else { None }
50 }
51 TyKind::Raw(m, ty) => Some(Self::Ptr(ty.clone(), *m)),
52 TyKind::Function(_) => Some(Self::FnPtr),
53 _ => None,
54 }
55 }
56}
57
/// Reason why a cast expression was rejected.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum CastError {
    /// The pointer kind of a pointee could not be computed, or a trait object's first
    /// bound is not a trait bound.
    Unknown,
    /// Cast to `bool`.
    CastToBool,
    /// Cast to `char` from anything other than `u8`.
    CastToChar,
    /// Pointer-to-pointer cast between pointees whose pointer metadata differs.
    DifferingKinds,
    /// Cast from a thin pointer to a pointer that needs metadata.
    SizedUnsizedCast,
    /// Cast that is never allowed (e.g. float to pointer, or an invalid reference cast).
    IllegalCast,
    /// Integer-to-pointer cast where the target pointer needs metadata.
    IntToFatCast,
    /// Reference to a number cast to a number; the reference must be dereferenced first.
    NeedDeref,
    /// Reference to a non-number cast to a number; it must go through a raw pointer first.
    NeedViaPtr,
    /// Pointer with metadata cast to an integer; it must go through a thin pointer first.
    NeedViaThinPtr,
    /// `bool`, `char` or C-like enum cast to a float; it must go through an integer first.
    NeedViaInt,
    /// Cast involving a non-castable type, or a cast into a C-like enum or function pointer.
    NonScalar,
}
76
77impl CastError {
78 fn into_diagnostic(self, expr: ExprId, expr_ty: Ty, cast_ty: Ty) -> InferenceDiagnostic {
79 InferenceDiagnostic::InvalidCast { expr, error: self, expr_ty, cast_ty }
80 }
81}
82
/// A pending validity check for a single cast expression.
#[derive(Clone, Debug)]
pub(super) struct CastCheck {
    /// The cast expression itself; diagnostics are reported against it.
    expr: ExprId,
    /// Expression that coercion adjustments (and the coercion-cast marker) are recorded
    /// against — presumably the cast operand; confirm with the caller.
    source_expr: ExprId,
    /// Type being cast from.
    expr_ty: Ty,
    /// Type being cast to.
    cast_ty: Ty,
}
90
impl CastCheck {
    /// Creates a cast check for the cast expression `expr`.
    ///
    /// `expr_ty` is the type being cast from and `cast_ty` the target type. `source_expr` is
    /// the expression that coercion adjustments are recorded against (presumably the cast
    /// operand — confirm with the caller).
    pub(super) fn new(expr: ExprId, source_expr: ExprId, expr_ty: Ty, cast_ty: Ty) -> Self {
        Self { expr, source_expr, expr_ty, cast_ty }
    }

    /// Checks whether the cast is valid.
    ///
    /// Both types are normalized and shallowly resolved in `table` and stored back into
    /// `self`. The check succeeds without further work when either type contains an unknown
    /// type, or when the target type contains a `dyn` type.
    ///
    /// If `expr_ty` coerces to `cast_ty`, the coercion's adjustments are passed to
    /// `apply_adjustments` and `set_coercion_cast` is called for `source_expr`. Otherwise
    /// [`Self::do_check`] decides.
    ///
    /// # Errors
    /// * `InferenceDiagnostic::CastToUnsized` if the target type contains no type inference
    ///   variables and is not `Sized`.
    /// * `InferenceDiagnostic::InvalidCast` wrapping the [`CastError`] from `do_check`.
    pub(super) fn check<F, G>(
        &mut self,
        table: &mut InferenceTable<'_>,
        apply_adjustments: &mut F,
        set_coercion_cast: &mut G,
    ) -> Result<(), InferenceDiagnostic>
    where
        F: FnMut(ExprId, Vec<Adjustment>),
        G: FnMut(ExprId),
    {
        self.expr_ty = table.eagerly_normalize_and_resolve_shallow_in(self.expr_ty.clone());
        self.cast_ty = table.eagerly_normalize_and_resolve_shallow_in(self.cast_ty.clone());

        // Nothing is reported while either side still contains an unknown type.
        if self.expr_ty.contains_unknown() || self.cast_ty.contains_unknown() {
            return Ok(());
        }

        // An unsized target is only reported once it contains no type inference variables.
        if !self.cast_ty.data(Interner).flags.contains(TypeFlags::HAS_TY_INFER)
            && !table.is_sized(&self.cast_ty)
        {
            return Err(InferenceDiagnostic::CastToUnsized {
                expr: self.expr,
                cast_ty: self.cast_ty.clone(),
            });
        }

        // NOTE(review): casts whose target mentions a `dyn` type are accepted unchecked. The
        // reason is not visible here (presumably a trait-solver limitation); document it.
        if contains_dyn_trait(&self.cast_ty) {
            return Ok(());
        }

        // If the cast is also a valid coercion, record that coercion and accept the cast.
        if let Ok((adj, _)) = table.coerce(&self.expr_ty, &self.cast_ty, CoerceNever::Yes) {
            apply_adjustments(self.source_expr, adj);
            set_coercion_cast(self.source_expr);
            return Ok(());
        }

        self.do_check(table, apply_adjustments)
            .map_err(|e| e.into_diagnostic(self.expr, self.expr_ty.clone(), self.cast_ty.clone()))
    }

    /// Checks the cast using the [`CastTy`] classification of both types.
    ///
    /// Sometimes only the target type has a category:
    /// * A function-item operand is first coerced to a function pointer, and its adjustments
    ///   are passed to `apply_adjustments`.
    /// * A reference operand is rejected with a targeted error, or, for raw-pointer targets,
    ///   checked by [`Self::check_ref_cast`].
    ///
    /// Any other uncategorized operand or target yields `CastError::NonScalar`.
    ///
    /// # Errors
    /// Returns the [`CastError`] describing why the cast is rejected.
    fn do_check<F>(
        &self,
        table: &mut InferenceTable<'_>,
        apply_adjustments: &mut F,
    ) -> Result<(), CastError>
    where
        F: FnMut(ExprId, Vec<Adjustment>),
    {
        let (t_from, t_cast) = match (
            CastTy::from_ty(table.db, &self.expr_ty),
            CastTy::from_ty(table.db, &self.cast_ty),
        ) {
            (Some(t_from), Some(t_cast)) => (t_from, t_cast),
            (None, Some(t_cast)) => match self.expr_ty.kind(Interner) {
                // Function items are cast through their function-pointer type.
                TyKind::FnDef(..) => {
                    let sig = self.expr_ty.callable_sig(table.db).expect("FnDef had no sig");
                    let sig = table.eagerly_normalize_and_resolve_shallow_in(sig);
                    let fn_ptr = TyKind::Function(sig.to_fn_ptr()).intern(Interner);
                    if let Ok((adj, _)) = table.coerce(&self.expr_ty, &fn_ptr, CoerceNever::Yes) {
                        apply_adjustments(self.source_expr, adj);
                    } else {
                        return Err(CastError::IllegalCast);
                    }

                    (CastTy::FnPtr, t_cast)
                }
                TyKind::Ref(mutbl, _, inner_ty) => {
                    return match t_cast {
                        // `&number as number` needs a dereference; any other reference must
                        // go through a raw pointer first.
                        CastTy::Int(_) | CastTy::Float => match inner_ty.kind(Interner) {
                            TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_) | Scalar::Float(_))
                            | TyKind::InferenceVar(
                                _,
                                TyVariableKind::Integer | TyVariableKind::Float,
                            ) => Err(CastError::NeedDeref),

                            _ => Err(CastError::NeedViaPtr),
                        },
                        CastTy::Ptr(t, m) => {
                            // Reference-to-pointer casts require a `Sized` pointee.
                            let t = table.eagerly_normalize_and_resolve_shallow_in(t);
                            if !table.is_sized(&t) {
                                return Err(CastError::IllegalCast);
                            }
                            self.check_ref_cast(table, inner_ty, *mutbl, &t, m, apply_adjustments)
                        }
                        _ => Err(CastError::NonScalar),
                    };
                }
                _ => return Err(CastError::NonScalar),
            },
            _ => return Err(CastError::NonScalar),
        };

        // Arm order matters: earlier, more specific arms take precedence over later ones.
        match (t_from, t_cast) {
            // C-like enums and function pointers can never be cast *to*.
            (_, CastTy::Int(Int::CEnum) | CastTy::FnPtr) => Err(CastError::NonScalar),
            (_, CastTy::Int(Int::Bool)) => Err(CastError::CastToBool),
            // `u8` is the only type that may be cast to `char`.
            (CastTy::Int(Int::U(UintTy::U8)), CastTy::Int(Int::Char)) => Ok(()),
            (_, CastTy::Int(Int::Char)) => Err(CastError::CastToChar),
            (CastTy::Int(Int::Bool | Int::CEnum | Int::Char), CastTy::Float) => {
                Err(CastError::NeedViaInt)
            }
            (CastTy::Int(Int::Bool | Int::CEnum | Int::Char) | CastTy::Float, CastTy::Ptr(..))
            | (CastTy::Ptr(..) | CastTy::FnPtr, CastTy::Float) => Err(CastError::IllegalCast),
            // Casts involving raw pointers depend on the pointee's pointer metadata.
            (CastTy::Ptr(src, _), CastTy::Ptr(dst, _)) => {
                self.check_ptr_ptr_cast(table, &src, &dst)
            }
            (CastTy::Ptr(src, _), CastTy::Int(_)) => self.check_ptr_addr_cast(table, &src),
            (CastTy::Int(_), CastTy::Ptr(dst, _)) => self.check_addr_ptr_cast(table, &dst),
            (CastTy::FnPtr, CastTy::Ptr(dst, _)) => self.check_fptr_ptr_cast(table, &dst),
            // The remaining primitive-to-primitive casts are always allowed.
            (CastTy::Int(Int::CEnum), CastTy::Int(_)) => Ok(()),
            (CastTy::Int(Int::Char | Int::Bool), CastTy::Int(_)) => Ok(()),
            (CastTy::Int(_) | CastTy::Float, CastTy::Int(_) | CastTy::Float) => Ok(()),
            (CastTy::FnPtr, CastTy::Int(_)) => Ok(()),
        }
    }

    /// Checks a cast from a reference to `t_expr` (mutability `m_expr`) to a raw pointer to
    /// `t_cast` (mutability `m_cast`).
    ///
    /// Only array-to-element-pointer casts are accepted. `t_expr` must be an array whose
    /// element type coerces to `t_cast`, and the mutabilities must be compatible. In that
    /// case the operand is first coerced to a raw pointer to the array. Those adjustments are
    /// passed to `apply_adjustments` even if the element check then fails.
    ///
    /// # Errors
    /// Returns `CastError::IllegalCast` for every other reference-to-pointer cast.
    fn check_ref_cast<F>(
        &self,
        table: &mut InferenceTable<'_>,
        t_expr: &Ty,
        m_expr: Mutability,
        t_cast: &Ty,
        m_cast: Mutability,
        apply_adjustments: &mut F,
    ) -> Result<(), CastError>
    where
        F: FnMut(ExprId, Vec<Adjustment>),
    {
        // NOTE(review): this relies on the `Ord` impl of chalk's `Mutability`. It is only
        // correct if `Mut` orders before `Not`, so that `m_expr <= m_cast` means "the
        // reference is at least as mutable as the pointer" — confirm.
        if m_expr <= m_cast
            && let TyKind::Array(ety, _) = t_expr.kind(Interner)
        {
            // Coerce the operand to a raw pointer to the whole array first.
            let array_ptr_type = TyKind::Raw(m_expr, t_expr.clone()).intern(Interner);
            if let Ok((adj, _)) = table.coerce(&self.expr_ty, &array_ptr_type, CoerceNever::Yes) {
                apply_adjustments(self.source_expr, adj);
            } else {
                never!(
                    "could not cast from reference to array to pointer to array ({:?} to {:?})",
                    self.expr_ty,
                    array_ptr_type
                );
            }

            // Accept when the element type coerces to the target pointee type.
            if table.coerce(ety, t_cast, CoerceNever::Yes).is_ok() {
                return Ok(());
            }
        }

        Err(CastError::IllegalCast)
    }

    /// Checks a raw-pointer-to-raw-pointer cast from pointee `src` to pointee `dst`.
    ///
    /// The cast is accepted when:
    /// * either pointee's kind is unknown or erroneous;
    /// * the target is thin;
    /// * both are trait objects with the same first (principal) trait, or whose first traits
    ///   are both auto traits;
    /// * both pointees have the same kind.
    ///
    /// # Errors
    /// * `CastError::Unknown` if a pointer kind cannot be computed, or if a trait object's
    ///   first bound is not a trait bound.
    /// * `CastError::SizedUnsizedCast` for a cast from a thin pointer to one needing metadata.
    /// * `CastError::DifferingKinds` when the pointer metadata kinds differ.
    fn check_ptr_ptr_cast(
        &self,
        table: &mut InferenceTable<'_>,
        src: &Ty,
        dst: &Ty,
    ) -> Result<(), CastError> {
        let src_kind = pointer_kind(src, table).map_err(|_| CastError::Unknown)?;
        let dst_kind = pointer_kind(dst, table).map_err(|_| CastError::Unknown)?;

        match (src_kind, dst_kind) {
            (Some(PointerKind::Error), _) | (_, Some(PointerKind::Error)) => Ok(()),
            // Unknown pointee kinds are not reported.
            (_, None) | (None, _) => Ok(()),
            // Any pointer may be cast to a thin pointer.
            (_, Some(PointerKind::Thin)) => Ok(()),
            (Some(PointerKind::Thin), _) => Err(CastError::SizedUnsizedCast),
            (Some(PointerKind::VTable(src_tty)), Some(PointerKind::VTable(dst_tty))) => {
                // The first bound of a trait object is treated as its principal trait.
                // NOTE(review): assumes `dyn` bounds are stored principal-first.
                let principal = |tty: &Binders<QuantifiedWhereClauses>| {
                    tty.skip_binders().as_slice(Interner).first().and_then(|pred| {
                        if let WhereClause::Implemented(tr) = pred.skip_binders() {
                            Some(tr.trait_id)
                        } else {
                            None
                        }
                    })
                };
                match (principal(&src_tty), principal(&dst_tty)) {
                    (Some(src_principal), Some(dst_principal)) => {
                        if src_principal == dst_principal {
                            return Ok(());
                        }
                        let src_principal =
                            table.db.trait_signature(from_chalk_trait_id(src_principal));
                        let dst_principal =
                            table.db.trait_signature(from_chalk_trait_id(dst_principal));
                        // Casting between objects made only of auto traits is allowed.
                        if src_principal.flags.contains(TraitFlags::AUTO)
                            && dst_principal.flags.contains(TraitFlags::AUTO)
                        {
                            Ok(())
                        } else {
                            Err(CastError::DifferingKinds)
                        }
                    }
                    _ => Err(CastError::Unknown),
                }
            }
            (Some(src_kind), Some(dst_kind)) if src_kind == dst_kind => Ok(()),
            (_, _) => Err(CastError::DifferingKinds),
        }
    }

    /// Checks a pointer-to-integer cast whose pointee type is `expr_ty`.
    ///
    /// Accepted when the pointer is thin, or when the pointee's kind is unknown or erroneous.
    ///
    /// # Errors
    /// * `CastError::Unknown` if the pointer kind cannot be computed.
    /// * `CastError::NeedViaThinPtr` if the pointer carries metadata.
    fn check_ptr_addr_cast(
        &self,
        table: &mut InferenceTable<'_>,
        expr_ty: &Ty,
    ) -> Result<(), CastError> {
        match pointer_kind(expr_ty, table).map_err(|_| CastError::Unknown)? {
            None => Ok(()),
            Some(PointerKind::Error) => Ok(()),
            Some(PointerKind::Thin) => Ok(()),
            _ => Err(CastError::NeedViaThinPtr),
        }
    }

    /// Checks an integer-to-pointer cast whose target pointee type is `cast_ty`.
    ///
    /// Accepted when the target pointer is thin, or when the pointee's kind is unknown or
    /// erroneous.
    ///
    /// # Errors
    /// * `CastError::Unknown` if the pointer kind cannot be computed.
    /// * `CastError::IntToFatCast` if the target pointer needs metadata.
    fn check_addr_ptr_cast(
        &self,
        table: &mut InferenceTable<'_>,
        cast_ty: &Ty,
    ) -> Result<(), CastError> {
        match pointer_kind(cast_ty, table).map_err(|_| CastError::Unknown)? {
            None => Ok(()),
            Some(PointerKind::Error) => Ok(()),
            Some(PointerKind::Thin) => Ok(()),
            Some(PointerKind::VTable(_)) => Err(CastError::IntToFatCast),
            Some(PointerKind::Length) => Err(CastError::IntToFatCast),
            Some(PointerKind::OfAlias | PointerKind::OfParam(_)) => Err(CastError::IntToFatCast),
        }
    }

    /// Checks a function-pointer-to-raw-pointer cast whose target pointee type is `cast_ty`.
    ///
    /// Accepted when the target pointer is thin, or when the pointee's kind is unknown or
    /// erroneous.
    ///
    /// # Errors
    /// * `CastError::Unknown` if the pointer kind cannot be computed.
    /// * `CastError::IllegalCast` if the target pointer needs metadata.
    fn check_fptr_ptr_cast(
        &self,
        table: &mut InferenceTable<'_>,
        cast_ty: &Ty,
    ) -> Result<(), CastError> {
        match pointer_kind(cast_ty, table).map_err(|_| CastError::Unknown)? {
            None => Ok(()),
            Some(PointerKind::Error) => Ok(()),
            Some(PointerKind::Thin) => Ok(()),
            _ => Err(CastError::IllegalCast),
        }
    }
}
350
/// Which metadata a pointer to some pointee type carries, as computed by `pointer_kind`.
#[derive(Debug, PartialEq, Eq)]
enum PointerKind {
    /// The pointee is `Sized` (or a foreign type); the pointer needs no metadata.
    Thin,
    /// The pointee is (or ends in) a `dyn` trait object with these bounds.
    VTable(Binders<QuantifiedWhereClauses>),
    /// The pointee is (or ends in) a slice or `str`.
    Length,
    /// The pointee ends in an alias, associated or opaque type whose metadata is not known.
    OfAlias,
    /// The pointee ends in a placeholder type (presumably a generic type parameter).
    OfParam(PlaceholderIndex),
    /// The pointee is an error type.
    Error,
}
363
364fn pointer_kind(ty: &Ty, table: &mut InferenceTable<'_>) -> Result<Option<PointerKind>, ()> {
365 let ty = table.eagerly_normalize_and_resolve_shallow_in(ty.clone());
366
367 if table.is_sized(&ty) {
368 return Ok(Some(PointerKind::Thin));
369 }
370
371 match ty.kind(Interner) {
372 TyKind::Slice(_) | TyKind::Str => Ok(Some(PointerKind::Length)),
373 TyKind::Dyn(DynTy { bounds, .. }) => Ok(Some(PointerKind::VTable(bounds.clone()))),
374 TyKind::Adt(chalk_ir::AdtId(id), subst) => {
375 let AdtId::StructId(id) = *id else {
376 never!("`{:?}` should be sized but is not?", ty);
377 return Err(());
378 };
379
380 let struct_data = id.fields(table.db);
381 if let Some((last_field, _)) = struct_data.fields().iter().last() {
382 let last_field_ty =
383 table.db.field_types(id.into())[last_field].clone().substitute(Interner, subst);
384 pointer_kind(&last_field_ty, table)
385 } else {
386 Ok(Some(PointerKind::Thin))
387 }
388 }
389 TyKind::Tuple(_, subst) => {
390 match subst.iter(Interner).last().and_then(|arg| arg.ty(Interner)) {
391 None => Ok(Some(PointerKind::Thin)),
392 Some(ty) => pointer_kind(ty, table),
393 }
394 }
395 TyKind::Foreign(_) => Ok(Some(PointerKind::Thin)),
396 TyKind::Alias(_) | TyKind::AssociatedType(..) | TyKind::OpaqueType(..) => {
397 Ok(Some(PointerKind::OfAlias))
398 }
399 TyKind::Error => Ok(Some(PointerKind::Error)),
400 TyKind::Placeholder(idx) => Ok(Some(PointerKind::OfParam(*idx))),
401 TyKind::BoundVar(_) | TyKind::InferenceVar(..) => Ok(None),
402 TyKind::Scalar(_)
403 | TyKind::Array(..)
404 | TyKind::CoroutineWitness(..)
405 | TyKind::Raw(..)
406 | TyKind::Ref(..)
407 | TyKind::FnDef(..)
408 | TyKind::Function(_)
409 | TyKind::Closure(..)
410 | TyKind::Coroutine(..)
411 | TyKind::Never => {
412 never!("`{:?}` should be sized but is not?", ty);
413 Err(())
414 }
415 }
416}
417
418fn contains_dyn_trait(ty: &Ty) -> bool {
419 use std::ops::ControlFlow;
420
421 use chalk_ir::{
422 DebruijnIndex,
423 visit::{TypeSuperVisitable, TypeVisitable, TypeVisitor},
424 };
425
426 struct DynTraitVisitor;
427
428 impl TypeVisitor<Interner> for DynTraitVisitor {
429 type BreakTy = ();
430
431 fn as_dyn(&mut self) -> &mut dyn TypeVisitor<Interner, BreakTy = Self::BreakTy> {
432 self
433 }
434
435 fn interner(&self) -> Interner {
436 Interner
437 }
438
439 fn visit_ty(&mut self, ty: &Ty, outer_binder: DebruijnIndex) -> ControlFlow<Self::BreakTy> {
440 match ty.kind(Interner) {
441 TyKind::Dyn(_) => ControlFlow::Break(()),
442 _ => ty.super_visit_with(self.as_dyn(), outer_binder),
443 }
444 }
445 }
446
447 ty.visit_with(DynTraitVisitor.as_dyn(), DebruijnIndex::INNERMOST).is_break()
448}