hir_ty/mir/eval/shim/simd.rs

1//! Shim implementation for simd intrinsics
2
3use std::cmp::Ordering;
4
5use crate::TyKind;
6use crate::consteval::try_const_usize;
7
8use super::*;
9
/// Decodes `$value` (a byte slice) as a little-endian `$ty`, returning
/// `Err(MirEvalError::InternalError("mismatched size"))` from the enclosing
/// function when the slice length does not match the size of `$ty`.
macro_rules! from_bytes {
    ($ty:tt, $value:expr) => {{
        let Ok(bytes) = ($value).try_into() else {
            return Err(MirEvalError::InternalError("mismatched size".into()));
        };
        $ty::from_le_bytes(bytes)
    }};
}
18
/// Returns `Err(MirEvalError::NotSupported(..))` from the enclosing function,
/// with the message built by `format!` from the given literal (implicit
/// `{name}` captures refer to the caller's variables).
macro_rules! not_supported {
    ($msg:expr) => {{
        let message = format!($msg);
        return Err(MirEvalError::NotSupported(message));
    }};
}
24
25impl Evaluator<'_> {
26    fn detect_simd_ty(&self, ty: &Ty) -> Result<(usize, Ty)> {
27        match ty.kind(Interner) {
28            TyKind::Adt(id, subst) => {
29                let len = match subst.as_slice(Interner).get(1).and_then(|it| it.constant(Interner))
30                {
31                    Some(len) => len,
32                    _ => {
33                        if let AdtId::StructId(id) = id.0 {
34                            let struct_data = self.db.variant_fields(id.into());
35                            let fields = struct_data.fields();
36                            let Some((first_field, _)) = fields.iter().next() else {
37                                not_supported!("simd type with no field");
38                            };
39                            let field_ty = self.db.field_types(id.into())[first_field]
40                                .clone()
41                                .substitute(Interner, subst);
42                            return Ok((fields.len(), field_ty));
43                        }
44                        return Err(MirEvalError::InternalError(
45                            "simd type with no len param".into(),
46                        ));
47                    }
48                };
49                match try_const_usize(self.db, len) {
50                    Some(len) => {
51                        let Some(ty) =
52                            subst.as_slice(Interner).first().and_then(|it| it.ty(Interner))
53                        else {
54                            return Err(MirEvalError::InternalError(
55                                "simd type with no ty param".into(),
56                            ));
57                        };
58                        Ok((len as usize, ty.clone()))
59                    }
60                    None => Err(MirEvalError::InternalError(
61                        "simd type with unevaluatable len param".into(),
62                    )),
63                }
64            }
65            _ => Err(MirEvalError::InternalError("simd type which is not a struct".into())),
66        }
67    }
68
69    pub(super) fn exec_simd_intrinsic(
70        &mut self,
71        name: &str,
72        args: &[IntervalAndTy],
73        _generic_args: &Substitution,
74        destination: Interval,
75        _locals: &Locals,
76        _span: MirSpan,
77    ) -> Result<()> {
78        match name {
79            "and" | "or" | "xor" => {
80                let [left, right] = args else {
81                    return Err(MirEvalError::InternalError(
82                        "simd bit op args are not provided".into(),
83                    ));
84                };
85                let result = left
86                    .get(self)?
87                    .iter()
88                    .zip(right.get(self)?)
89                    .map(|(&it, &y)| match name {
90                        "and" => it & y,
91                        "or" => it | y,
92                        "xor" => it ^ y,
93                        _ => unreachable!(),
94                    })
95                    .collect::<Vec<_>>();
96                destination.write_from_bytes(self, &result)
97            }
98            "eq" | "ne" | "lt" | "le" | "gt" | "ge" => {
99                let [left, right] = args else {
100                    return Err(MirEvalError::InternalError("simd args are not provided".into()));
101                };
102                let (len, ty) = self.detect_simd_ty(&left.ty)?;
103                let is_signed = matches!(ty.as_builtin(), Some(BuiltinType::Int(_)));
104                let size = left.interval.size / len;
105                let dest_size = destination.size / len;
106                let mut destination_bytes = vec![];
107                let vector = left.get(self)?.chunks(size).zip(right.get(self)?.chunks(size));
108                for (l, r) in vector {
109                    let mut result = Ordering::Equal;
110                    for (l, r) in l.iter().zip(r).rev() {
111                        let it = l.cmp(r);
112                        if it != Ordering::Equal {
113                            result = it;
114                            break;
115                        }
116                    }
117                    if is_signed {
118                        if let Some((&l, &r)) = l.iter().zip(r).next_back() {
119                            if l != r {
120                                result = (l as i8).cmp(&(r as i8));
121                            }
122                        }
123                    }
124                    let result = match result {
125                        Ordering::Less => ["lt", "le", "ne"].contains(&name),
126                        Ordering::Equal => ["ge", "le", "eq"].contains(&name),
127                        Ordering::Greater => ["ge", "gt", "ne"].contains(&name),
128                    };
129                    let result = if result { 255 } else { 0 };
130                    destination_bytes.extend(std::iter::repeat_n(result, dest_size));
131                }
132
133                destination.write_from_bytes(self, &destination_bytes)
134            }
135            "bitmask" => {
136                let [op] = args else {
137                    return Err(MirEvalError::InternalError(
138                        "simd_bitmask args are not provided".into(),
139                    ));
140                };
141                let (op_len, _) = self.detect_simd_ty(&op.ty)?;
142                let op_count = op.interval.size / op_len;
143                let mut result: u64 = 0;
144                for (i, val) in op.get(self)?.chunks(op_count).enumerate() {
145                    if !val.iter().all(|&it| it == 0) {
146                        result |= 1 << i;
147                    }
148                }
149                destination.write_from_bytes(self, &result.to_le_bytes()[0..destination.size])
150            }
151            "shuffle" => {
152                let [left, right, index] = args else {
153                    return Err(MirEvalError::InternalError(
154                        "simd_shuffle args are not provided".into(),
155                    ));
156                };
157                let TyKind::Array(_, index_len) = index.ty.kind(Interner) else {
158                    return Err(MirEvalError::InternalError(
159                        "simd_shuffle index argument has non-array type".into(),
160                    ));
161                };
162                let index_len = match try_const_usize(self.db, index_len) {
163                    Some(it) => it as usize,
164                    None => {
165                        return Err(MirEvalError::InternalError(
166                            "simd type with unevaluatable len param".into(),
167                        ));
168                    }
169                };
170                let (left_len, _) = self.detect_simd_ty(&left.ty)?;
171                let left_size = left.interval.size / left_len;
172                let vector =
173                    left.get(self)?.chunks(left_size).chain(right.get(self)?.chunks(left_size));
174                let mut result = vec![];
175                for index in index.get(self)?.chunks(index.interval.size / index_len) {
176                    let index = from_bytes!(u32, index) as usize;
177                    let val = match vector.clone().nth(index) {
178                        Some(it) => it,
179                        None => {
180                            return Err(MirEvalError::InternalError(
181                                "out of bound access in simd shuffle".into(),
182                            ));
183                        }
184                    };
185                    result.extend(val);
186                }
187                destination.write_from_bytes(self, &result)
188            }
189            _ => not_supported!("unknown simd intrinsic {name}"),
190        }
191    }
192}