// core_simd/ops.rs

use crate::simd::{Select, Simd, SimdElement, cmp::SimdPartialEq};
use core::ops::{Add, Mul};
use core::ops::{BitAnd, BitOr, BitXor};
use core::ops::{Div, Rem, Sub};
use core::ops::{Shl, Shr};

mod assign;
mod deref;
mod shift_scalar;
mod unary;

impl<I, T, const N: usize> core::ops::Index<I> for Simd<T, N>
where
    T: SimdElement,
    I: core::slice::SliceIndex<[T]>,
{
    type Output = I::Output;
    #[inline]
    fn index(&self, index: I) -> &Self::Output {
        &self.as_array()[index]
    }
}

impl<I, T, const N: usize> core::ops::IndexMut<I> for Simd<T, N>
where
    T: SimdElement,
    I: core::slice::SliceIndex<[T]>,
{
    #[inline]
    fn index_mut(&mut self, index: I) -> &mut Self::Output {
        &mut self.as_mut_array()[index]
    }
}
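
// A minimal usage sketch (not part of the upstream file; the test module and
// function names are hypothetical): these impls let a vector be indexed like
// the array it wraps, with the usual slice panics on out-of-bounds access.
#[cfg(test)]
mod index_sketch {
    use crate::simd::Simd;

    #[test]
    fn index_reads_and_writes_lanes() {
        let mut v = Simd::from_array([10i32, 20, 30, 40]);
        // `Index` with a usize yields one lane; with a range, a subslice.
        assert_eq!(v[2], 30);
        assert_eq!(&v[1..3], &[20, 30]);
        // `IndexMut` writes through to the underlying array.
        v[0] = 99;
        assert_eq!(v.to_array(), [99, 20, 30, 40]);
    }
}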

macro_rules! unsafe_base {
    ($lhs:ident, $rhs:ident, {$simd_call:ident}, $($_:tt)*) => {
        // Safety: $lhs and $rhs are vectors
        unsafe { core::intrinsics::simd::$simd_call($lhs, $rhs) }
    };
}
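
// Schematically (an illustration, not part of the upstream file), a use like
// `unsafe_base!(self, rhs, { simd_add }, i32)` expands to nothing more than
// the raw element-wise intrinsic call:
//
//     unsafe { core::intrinsics::simd::simd_add(self, rhs) }
//
// The trailing `$($_:tt)*` arm swallows the scalar type token, which this
// macro does not need but which `for_base_types!` always passes.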

/// SAFETY: This macro should not be used for anything except Shl or Shr, and must be passed the matching shift intrinsic.
/// It performs a bitand in addition to calling the shift operator, so that the result
/// is well-defined: LLVM may return a poison value if you shl, lshr, or ashr with `rhs >= <Int>::BITS`.
/// At worst, this may add another instruction and cycle;
/// at best, it may open up more optimization opportunities,
/// or simply be elided entirely, especially on SIMD ISAs which default to this behavior.
///
// FIXME: Consider implementing this in cg_llvm instead?
// cg_clif defaults to this, and scalar MIR shifts also default to wrapping
macro_rules! wrap_bitshift {
    ($lhs:ident, $rhs:ident, {$simd_call:ident}, $int:ident) => {
        #[allow(clippy::suspicious_arithmetic_impl)]
        // Safety: $lhs and the bitand result are vectors
        unsafe {
            core::intrinsics::simd::$simd_call(
                $lhs,
                $rhs.bitand(Simd::splat(<$int>::BITS as $int - 1)),
            )
        }
    };
}
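
// A behavioral sketch (hypothetical test, not in the upstream file): after the
// bitand above, shift amounts are taken modulo the element's bit width, so no
// shift amount can reach LLVM's poison condition.
#[cfg(test)]
mod wrap_bitshift_sketch {
    use crate::simd::Simd;

    #[test]
    fn shift_amounts_wrap() {
        let x = Simd::from_array([1u32, 2, 3, 4]);
        // 33 & (u32::BITS - 1) == 1, so shifting by 33 acts like shifting by 1.
        assert_eq!(x << Simd::splat(33), x << Simd::splat(1));
        // 32 & 31 == 0: a full-width shift leaves the value unchanged.
        assert_eq!(x >> Simd::splat(32), x);
    }
}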

/// SAFETY: This macro must only be used to impl Div or Rem and given the matching intrinsic.
/// It guards against LLVM's UB conditions for integer div or rem using masks and selects,
/// thus guaranteeing that a Rust value is returned instead.
///
/// |                  | LLVM | Rust
/// | :--------------: | :--- | :----------
/// | N {/,%} 0        | UB   | panic!()
/// | <$int>::MIN / -1 | UB   | <$int>::MIN
/// | <$int>::MIN % -1 | UB   | 0
///
macro_rules! int_divrem_guard {
    (   $lhs:ident,
        $rhs:ident,
        {   const PANIC_ZERO: &'static str = $zero:literal;
            $simd_call:ident, $op:tt
        },
        $int:ident ) => {
        if $rhs.simd_eq(Simd::splat(0 as _)).any() {
            panic!($zero);
        } else {
            // Prevent otherwise-UB overflow on the MIN / -1 case.
            let rhs = if <$int>::MIN != 0 {
                // This should, at worst, optimize to a few branchless logical ops
                // Ideally, this entire conditional should evaporate
                // Fire LLVM and implement those manually if it doesn't get the hint
                ($lhs.simd_eq(Simd::splat(<$int>::MIN))
                // type inference can break here, so cut an SInt to size
                & $rhs.simd_eq(Simd::splat(-1i64 as _)))
                .select(Simd::splat(1 as _), $rhs)
            } else {
                // Nice base case to make it easy to const-fold away the other branch.
                $rhs
            };

            // aarch64 div fails for arbitrary `v % 0`, mod fails when rhs is MIN, for non-powers-of-two
            // these operations aren't vectorized on aarch64 anyway
            #[cfg(target_arch = "aarch64")]
            {
                let mut out = Simd::splat(0 as _);
                for i in 0..Self::LEN {
                    out[i] = $lhs[i] $op rhs[i];
                }
                out
            }

            #[cfg(not(target_arch = "aarch64"))]
            {
                // Safety: $lhs and rhs are vectors
                unsafe { core::intrinsics::simd::$simd_call($lhs, rhs) }
            }
        }
    };
}
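
// A behavioral sketch (hypothetical test, not in the upstream file) of the
// guarantees in the table above: the two UB-prone lane patterns produce the
// defined Rust results, and a zero divisor in any lane panics.
#[cfg(test)]
mod divrem_guard_sketch {
    use crate::simd::Simd;

    #[test]
    fn min_over_neg_one_is_defined() {
        let lhs = Simd::from_array([i32::MIN, i32::MIN, 7, 7]);
        let rhs = Simd::from_array([-1, -1, 2, 2]);
        // MIN / -1 yields MIN and MIN % -1 yields 0, instead of LLVM UB.
        assert_eq!((lhs / rhs).to_array(), [i32::MIN, i32::MIN, 3, 3]);
        assert_eq!((lhs % rhs).to_array(), [0, 0, 1, 1]);
    }

    #[test]
    #[should_panic(expected = "attempt to divide by zero")]
    fn any_zero_divisor_panics() {
        let _ = Simd::from_array([1i32, 2, 3, 4]) / Simd::splat(0);
    }
}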

macro_rules! for_base_types {
    (   T = ($($scalar:ident),*);
        type Lhs = Simd<T, N>;
        type Rhs = Simd<T, N>;
        type Output = $out:ty;

        impl $op:ident::$call:ident {
            $macro_impl:ident $inner:tt
        }) => {
            $(
                impl<const N: usize> $op<Self> for Simd<$scalar, N>
                where
                    $scalar: SimdElement,
                {
                    type Output = $out;

                    #[inline]
                    // TODO: only useful for int Div::div, but we hope that this
                    // will essentially always get inlined anyway.
                    #[track_caller]
                    fn $call(self, rhs: Self) -> Self::Output {
                        $macro_impl!(self, rhs, $inner, $scalar)
                    }
                }
            )*
    }
}
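
// To make the expansion concrete (an illustration, not part of the upstream
// file): for `T = (i32)` and `impl Add::add { unsafe_base { simd_add } }`,
// the macro above emits roughly
//
//     impl<const N: usize> Add<Self> for Simd<i32, N>
//     where
//         i32: SimdElement,
//     {
//         type Output = Self;
//
//         #[inline]
//         #[track_caller]
//         fn add(self, rhs: Self) -> Self::Output {
//             unsafe { core::intrinsics::simd::simd_add(self, rhs) }
//         }
//     }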

// A "TokenTree muncher": takes a set of scalar types `T = {};`
// type parameters for the ops it implements, `Op::fn` names,
// and a macro that expands into an expr, substituting in an intrinsic.
// It passes that to for_base_types, which expands an impl for the types,
// using the expanded expr in the function, and recurses with itself.
//
// tl;dr impls a set of ops::{Traits} for a set of types
macro_rules! for_base_ops {
    (
        T = $types:tt;
        type Lhs = Simd<T, N>;
        type Rhs = Simd<T, N>;
        type Output = $out:ident;
        impl $op:ident::$call:ident
            $inner:tt
        $($rest:tt)*
    ) => {
        for_base_types! {
            T = $types;
            type Lhs = Simd<T, N>;
            type Rhs = Simd<T, N>;
            type Output = $out;
            impl $op::$call
                $inner
        }
        for_base_ops! {
            T = $types;
            type Lhs = Simd<T, N>;
            type Rhs = Simd<T, N>;
            type Output = $out;
            $($rest)*
        }
    };
    ($($done:tt)*) => {
        // Done.
    }
}
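
// The recursion is easiest to see one step at a time (an illustration, not
// part of the upstream file): each step peels off the first
// `impl Op::fn { ... }` block, hands it to `for_base_types!`, and recurses on
// the rest until the catch-all arm matches the empty input:
//
//     for_base_ops! { T = (..); .. impl Add::add { .. } impl Mul::mul { .. } }
//     // step 1: for_base_types! { .. impl Add::add { .. } }
//     //         for_base_ops!   { T = (..); .. impl Mul::mul { .. } }
//     // step 2: for_base_types! { .. impl Mul::mul { .. } }
//     //         for_base_ops!   { T = (..); .. }  // matches `$($done:tt)*`: done.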

// Integers can always accept add, mul, sub, bitand, bitor, and bitxor.
// For all of these operations, simd_* intrinsics apply wrapping logic.
for_base_ops! {
    T = (i8, i16, i32, i64, isize, u8, u16, u32, u64, usize);
    type Lhs = Simd<T, N>;
    type Rhs = Simd<T, N>;
    type Output = Self;

    impl Add::add {
        unsafe_base { simd_add }
    }

    impl Mul::mul {
        unsafe_base { simd_mul }
    }

    impl Sub::sub {
        unsafe_base { simd_sub }
    }

    impl BitAnd::bitand {
        unsafe_base { simd_and }
    }

    impl BitOr::bitor {
        unsafe_base { simd_or }
    }

    impl BitXor::bitxor {
        unsafe_base { simd_xor }
    }

    impl Div::div {
        int_divrem_guard {
            const PANIC_ZERO: &'static str = "attempt to divide by zero";
            simd_div, /
        }
    }

    impl Rem::rem {
        int_divrem_guard {
            const PANIC_ZERO: &'static str = "attempt to calculate the remainder with a divisor of zero";
            simd_rem, %
        }
    }

    // The only question is how to handle shifts >= <Int>::BITS.
    // Our current solution uses wrapping logic: the shift amount is masked to the bit width.
    impl Shl::shl {
        wrap_bitshift { simd_shl }
    }

    impl Shr::shr {
        wrap_bitshift {
            // This automatically monomorphizes to lshr or ashr, depending on signedness,
            // so it's fine to use it for both UInts and SInts.
            simd_shr
        }
    }
}
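
// A quick check of the wrapping claim above (hypothetical test, not in the
// upstream file): unlike scalar `+`, which panics on overflow in debug builds,
// these element-wise ops always wrap.
#[cfg(test)]
mod int_wrapping_sketch {
    use crate::simd::Simd;

    #[test]
    fn add_and_mul_wrap() {
        let x = Simd::from_array([i8::MAX, 100, 0, -1]);
        // i8::MAX + 1 wraps to i8::MIN; 100 * 2 wraps to -56.
        assert_eq!((x + Simd::splat(1)).to_array(), [i8::MIN, 101, 1, 0]);
        assert_eq!((x * Simd::splat(2)).to_array(), [-2, -56, 0, -2]);
    }
}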

// We don't need any special precautions here:
// Floats always accept arithmetic ops, but may become NaN.
for_base_ops! {
    T = (f32, f64);
    type Lhs = Simd<T, N>;
    type Rhs = Simd<T, N>;
    type Output = Self;

    impl Add::add {
        unsafe_base { simd_add }
    }

    impl Mul::mul {
        unsafe_base { simd_mul }
    }

    impl Sub::sub {
        unsafe_base { simd_sub }
    }

    impl Div::div {
        unsafe_base { simd_div }
    }

    impl Rem::rem {
        unsafe_base { simd_rem }
    }
}
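
// A behavioral sketch (hypothetical test, not in the upstream file): float
// division never panics or needs a guard; the 0.0 / 0.0 lane simply becomes
// NaN while the other lanes divide per IEEE 754.
#[cfg(test)]
mod float_ops_sketch {
    use crate::simd::Simd;

    #[test]
    fn div_by_zero_yields_nan_or_inf() {
        let x = Simd::from_array([0.0f32, 1.0, -1.0, 8.0]);
        let y = Simd::from_array([0.0f32, 0.0, 0.0, 2.0]);
        let q = (x / y).to_array();
        assert!(q[0].is_nan());
        assert_eq!(q[1], f32::INFINITY);
        assert_eq!(q[2], f32::NEG_INFINITY);
        assert_eq!(q[3], 4.0);
    }
}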