1use std::{cmp, ops::Bound};
4
5use hir_def::{
6 AdtId, VariantId,
7 attrs::AttrFlags,
8 signatures::{StructFlags, VariantFields},
9};
10use rustc_abi::{Integer, ReprOptions, TargetDataLayout};
11use rustc_index::IndexVec;
12use smallvec::SmallVec;
13use triomphe::Arc;
14
15use crate::{
16 ParamEnvAndCrate,
17 db::HirDatabase,
18 layout::{Layout, LayoutCx, LayoutError, field_ty},
19 next_solver::GenericArgs,
20};
21
/// Computes the memory layout of the ADT `def` instantiated with `args`.
///
/// Each variant's field layouts are computed through `db.layout_of_ty`. The resulting
/// per-variant field lists are then handed to the layout calculator (`cx.calc`):
/// `layout_of_union` for unions, `layout_of_struct_or_enum` for structs and enums.
///
/// # Errors
///
/// - [`LayoutError::TargetLayoutNotAvailable`] if the target data layout for the crate in
///   `trait_env` cannot be obtained.
/// - The first error produced while computing any field's layout.
/// - Any error returned by the layout calculator itself.
pub fn layout_of_adt_query<'db>(
    db: &'db dyn HirDatabase,
    def: AdtId,
    args: GenericArgs<'db>,
    trait_env: ParamEnvAndCrate<'db>,
) -> Result<Arc<Layout>, LayoutError> {
    let krate = trait_env.krate;
    let Ok(target) = db.target_data_layout(krate) else {
        return Err(LayoutError::TargetLayoutNotAvailable);
    };
    let dl = &*target;
    let cx = LayoutCx::new(dl);
    // Layouts of every field of one variant, in declaration order; short-circuits on the
    // first field whose layout fails.
    let handle_variant = |def: VariantId, var: &VariantFields| {
        var.fields()
            .iter()
            .map(|(fd, _)| db.layout_of_ty(field_ty(db, def, fd, &args), trait_env))
            .collect::<Result<Vec<_>, _>>()
    };
    // Gather (field layouts per variant, repr options, is_special_no_niche).
    // Structs and unions contribute exactly one "variant"; enums contribute one per variant.
    let (variants, repr, is_special_no_niche) = match def {
        AdtId::StructId(s) => {
            let sig = db.struct_signature(s);
            let mut r = SmallVec::<[_; 1]>::new();
            r.push(handle_variant(s.into(), s.fields(db))?);
            (
                r,
                AttrFlags::repr(db, s.into()).unwrap_or_default(),
                // Only structs flagged IS_UNSAFE_CELL or IS_UNSAFE_PINNED get the
                // special no-niche treatment.
                sig.flags.intersects(StructFlags::IS_UNSAFE_CELL | StructFlags::IS_UNSAFE_PINNED),
            )
        }
        AdtId::UnionId(id) => {
            let repr = AttrFlags::repr(db, id.into());
            let mut r = SmallVec::new();
            r.push(handle_variant(id.into(), id.fields(db))?);
            (r, repr.unwrap_or_default(), false)
        }
        AdtId::EnumId(e) => {
            let variants = e.enum_variants(db);
            let r = variants
                .variants
                .iter()
                .map(|&(v, _, _)| handle_variant(v.into(), v.fields(db)))
                .collect::<Result<SmallVec<_>, _>>()?;
            (r, AttrFlags::repr(db, e.into()).unwrap_or_default(), false)
        }
    };
    // First pass: borrow each owned `Arc<Layout>` as a plain `&Layout`.
    let variants = variants
        .iter()
        .map(|it| it.iter().map(|it| &**it).collect::<Vec<_>>())
        .collect::<SmallVec<[_; 1]>>();
    // Second pass: build the nested `IndexVec`s over references to those borrows.
    // The index types and the inner collection type are not spelled out here; they are
    // inferred from the calculator calls below, so keep this expression ahead of those calls.
    // NOTE(review): the double indirection presumably matches the field element type the
    // layout calculator expects — confirm against the `rustc_abi` signature before changing.
    let variants = variants.iter().map(|it| it.iter().collect()).collect::<IndexVec<_, _>>();
    let result = if matches!(def, AdtId::UnionId(..)) {
        cx.calc.layout_of_union(&repr, &variants)?
    } else {
        cx.calc.layout_of_struct_or_enum(
            &repr,
            &variants,
            matches!(def, AdtId::EnumId(..)),
            is_special_no_niche,
            layout_scalar_valid_range(db, def),
            // If `repr_discr` rejects the range (its only error is an explicit integer repr
            // that is too small), fall back to an unsigned `I8` instead of failing the layout.
            |min, max| repr_discr(dl, &repr, min, max).unwrap_or((Integer::I8, false)),
            // Explicit discriminant values, produced only for enums. Variants whose
            // discriminant fails to const-evaluate are simply left out of the iterator.
            variants.iter_enumerated().filter_map(|(id, _)| {
                let AdtId::EnumId(e) = def else { return None };
                let d = db.const_eval_discriminant(e.enum_variants(db).variants[id.0].0).ok()?;
                Some((id, d))
            }),
            // Always `false` for enums. For structs: `true` when the last field of the first
            // variant has a sized layout, or when there is no such field.
            // NOTE(review): presumably the calculator's "always sized" flag — confirm.
            !matches!(def, AdtId::EnumId(..))
                && variants
                    .iter()
                    .next()
                    .and_then(|it| it.iter().last().map(|it| !it.is_unsized()))
                    .unwrap_or(true),
        )?
    };
    Ok(Arc::new(result))
}
97
98pub(crate) fn layout_of_adt_cycle_result<'db>(
99 _: &'db dyn HirDatabase,
100 _def: AdtId,
101 _args: GenericArgs<'db>,
102 _trait_env: ParamEnvAndCrate<'db>,
103) -> Result<Arc<Layout>, LayoutError> {
104 Err(LayoutError::RecursiveTypeWithoutIndirection)
105}
106
107fn layout_scalar_valid_range(db: &dyn HirDatabase, def: AdtId) -> (Bound<u128>, Bound<u128>) {
108 let range = AttrFlags::rustc_layout_scalar_valid_range(db, def);
109 let get = |value| match value {
110 Some(it) => Bound::Included(it),
111 None => Bound::Unbounded,
112 };
113 (get(range.start), get(range.end))
114}
115
116fn repr_discr(
121 dl: &TargetDataLayout,
122 repr: &ReprOptions,
123 min: i128,
124 max: i128,
125) -> Result<(Integer, bool), LayoutError> {
126 let unsigned_fit = Integer::fit_unsigned(cmp::max(min as u128, max as u128));
131 let signed_fit = cmp::max(Integer::fit_signed(min), Integer::fit_signed(max));
132
133 if let Some(ity) = repr.int {
134 let discr = Integer::from_attr(dl, ity);
135 let fit = if ity.is_signed() { signed_fit } else { unsigned_fit };
136 if discr < fit {
137 return Err(LayoutError::UserReprTooSmall);
138 }
139 return Ok((discr, ity.is_signed()));
140 }
141
142 let at_least = if repr.c() {
143 dl.c_enum_min_size
146 } else {
147 Integer::I8
149 };
150
151 Ok(if min >= 0 {
153 (cmp::max(unsigned_fit, at_least), false)
154 } else {
155 (cmp::max(signed_fit, at_least), true)
156 })
157}