Skip to main content

core_simd/
swizzle_dyn.rs

1use crate::simd::Simd;
2use core::mem;
3
4impl<const N: usize> Simd<u8, N> {
5    /// Swizzle a vector of bytes according to the index vector.
6    /// Indices within range select the appropriate byte.
7    /// Indices "out of bounds" instead select 0.
8    ///
9    /// Note that the current implementation is selected during build-time
10    /// of the standard library, so `cargo build -Zbuild-std` may be necessary
11    /// to unlock better performance, especially for larger vectors.
12    /// A planned compiler improvement will enable using `#[target_feature]` instead.
13    #[inline]
14    pub fn swizzle_dyn(self, idxs: Simd<u8, N>) -> Self {
15        #![allow(unused_imports, unused_unsafe)]
16        #[cfg(all(
17            target_arch = "arm",
18            target_feature = "v7",
19            target_feature = "neon",
20            target_endian = "little"
21        ))]
22        use core::arch::arm::{uint8x8_t, vtbl1_u8};
23        #[cfg(target_arch = "wasm32")]
24        use core::arch::wasm32 as wasm;
25        #[cfg(target_arch = "wasm64")]
26        use core::arch::wasm64 as wasm;
27        #[cfg(target_arch = "x86")]
28        use core::arch::x86;
29        #[cfg(target_arch = "x86_64")]
30        use core::arch::x86_64 as x86;
31        // SAFETY: Intrinsics covered by cfg
32        unsafe {
33            match N {
34                #[cfg(all(
35                    any(target_arch = "aarch64", target_arch = "arm64ec"),
36                    target_feature = "neon",
37                    target_endian = "little"
38                ))]
39                8 | 16 | 24 | 32 | 48 | 64 => aarch64_swizzle(self, idxs),
40                #[cfg(target_feature = "ssse3")]
41                16 => transize(x86::_mm_shuffle_epi8, self, zeroing_idxs(idxs)),
42                #[cfg(target_feature = "simd128")]
43                16 => transize(wasm::i8x16_swizzle, self, idxs),
44                #[cfg(all(
45                    target_arch = "arm",
46                    target_feature = "v7",
47                    target_feature = "neon",
48                    target_endian = "little"
49                ))]
50                16 => transize(armv7_neon_swizzle_u8x16, self, idxs),
51                #[cfg(all(target_arch = "loongarch64", target_feature = "lsx"))]
52                16 => transize(loong64_lsx_swizzle, self, idxs),
53                #[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi")))]
54                32 => transize(avx2_pshufb, self, idxs),
55                #[cfg(all(target_feature = "avx512vl", target_feature = "avx512vbmi"))]
56                32 => {
57                    // Unlike vpshufb, vpermb doesn't zero out values in the result based on the index high bit
58                    let swizzler = |bytes, idxs| {
59                        let mask = x86::_mm256_cmp_epu8_mask::<{ x86::_MM_CMPINT_LT }>(
60                            idxs,
61                            Simd::<u8, 32>::splat(N as u8).into(),
62                        );
63                        x86::_mm256_maskz_permutexvar_epi8(mask, idxs, bytes)
64                    };
65                    transize(swizzler, self, idxs)
66                }
67                #[cfg(all(target_arch = "loongarch64", target_feature = "lasx"))]
68                32 => transize(loong64_lasx_swizzle, self, idxs),
69                // Notable absence: avx512bw pshufb shuffle
70                #[cfg(all(target_feature = "avx512vl", target_feature = "avx512vbmi"))]
71                64 => {
72                    // Unlike vpshufb, vpermb doesn't zero out values in the result based on the index high bit
73                    let swizzler = |bytes, idxs| {
74                        let mask = x86::_mm512_cmp_epu8_mask::<{ x86::_MM_CMPINT_LT }>(
75                            idxs,
76                            Simd::<u8, 64>::splat(N as u8).into(),
77                        );
78                        x86::_mm512_maskz_permutexvar_epi8(mask, idxs, bytes)
79                    };
80                    transize(swizzler, self, idxs)
81                }
82                _ => {
83                    let mut array = [0; N];
84                    for (i, k) in idxs.to_array().into_iter().enumerate() {
85                        if (k as usize) < N {
86                            array[i] = self[k as usize];
87                        };
88                    }
89                    array.into()
90                }
91            }
92        }
93    }
94}
95
/// Swizzles a `u8x16` on armv7 NEON.
///
/// The sixteen source bytes are split into two `u8x8` halves that together form a
/// `u8x8x2` lookup table. Each `u8x8` half of `idxs` is looked up in that table and
/// the two partial results are joined back into a `u8x16`.
///
/// # Safety
/// This requires armv7 neon to work
#[cfg(all(
    target_arch = "arm",
    target_feature = "v7",
    target_feature = "neon",
    target_endian = "little"
))]
unsafe fn armv7_neon_swizzle_u8x16(bytes: Simd<u8, 16>, idxs: Simd<u8, 16>) -> Simd<u8, 16> {
    use core::arch::arm::{uint8x8x2_t, vcombine_u8, vget_high_u8, vget_low_u8, vtbl2_u8};
    // SAFETY: Caller promised arm neon support
    unsafe {
        // Build the 16-byte lookup table out of the two 8-byte halves of `bytes`.
        let table_lo = vget_low_u8(bytes.into());
        let table_hi = vget_high_u8(bytes.into());
        let table = uint8x8x2_t(table_lo, table_hi);

        // Look up each half of the index vector against the full table.
        let picked_lo = vtbl2_u8(table, vget_low_u8(idxs.into()));
        let picked_hi = vtbl2_u8(table, vget_high_u8(idxs.into()));

        vcombine_u8(picked_lo, picked_hi).into()
    }
}
117
/// AArch64 NEON supports swizzling 8, 16, 24, 32, 48 or 64 by stacking multiple TBL instructions.
///
/// `bytes` is reinterpreted as a lookup table of one to four NEON registers, and each
/// register-sized piece of `idxs` is looked up against the whole table.
///
/// # Panics
/// Panics (via `unreachable!`) if `N` is not one of the lane counts listed above.
///
/// # Safety
/// This requires AArch64 NEON to work
#[cfg(all(
    any(target_arch = "aarch64", target_arch = "arm64ec"),
    target_feature = "neon",
    target_endian = "little"
))]
unsafe fn aarch64_swizzle<const N: usize>(bytes: Simd<u8, N>, idxs: Simd<u8, N>) -> Simd<u8, N> {
    use core::arch::aarch64::*;
    use core::mem::transmute_copy;

    // SAFETY: Caller promised AArch64 NEON support
    unsafe {
        // `N` is a const generic, so only the arm matching the instantiated lane count
        // can run; `transmute_copy` lets the other arms type-check for that `N`.
        match N {
            8 => transmute_copy(&vtbl1_u8(transmute_copy(&bytes), transmute_copy(&idxs))),
            16 => transmute_copy(&vqtbl1q_u8(transmute_copy(&bytes), transmute_copy(&idxs))),
            24 => {
                let bytes: uint8x8x3_t = transmute_copy(&bytes);
                let idxs: uint8x8x3_t = transmute_copy(&idxs);

                let ret0 = vtbl3_u8(bytes, idxs.0);
                let ret1 = vtbl3_u8(bytes, idxs.1);
                let ret2 = vtbl3_u8(bytes, idxs.2);

                let ret = uint8x8x3_t(ret0, ret1, ret2);
                // `Simd<u8, 24>` may be larger than the 24-byte `uint8x8x3_t` (a vector
                // can carry padding after its lanes), and `transmute_copy` must never
                // read more bytes than its source holds. Copy into an exactly-sized
                // array and build the vector from that instead.
                Simd::from_array(transmute_copy(&ret))
            }
            32 => {
                let bytes: uint8x16x2_t = transmute_copy(&bytes);
                let idxs: uint8x16x2_t = transmute_copy(&idxs);

                let ret0 = vqtbl2q_u8(bytes, idxs.0);
                let ret1 = vqtbl2q_u8(bytes, idxs.1);

                let ret = uint8x16x2_t(ret0, ret1);
                transmute_copy(&ret)
            }
            48 => {
                let bytes: uint8x16x3_t = transmute_copy(&bytes);
                let idxs: uint8x16x3_t = transmute_copy(&idxs);

                let ret0 = vqtbl3q_u8(bytes, idxs.0);
                let ret1 = vqtbl3q_u8(bytes, idxs.1);
                let ret2 = vqtbl3q_u8(bytes, idxs.2);

                let ret = uint8x16x3_t(ret0, ret1, ret2);
                // Same reasoning as the 24-lane arm: `Simd<u8, 48>` may be larger than
                // the 48-byte `uint8x16x3_t`, so go through an exactly-sized array.
                Simd::from_array(transmute_copy(&ret))
            }
            64 => {
                let bytes: uint8x16x4_t = transmute_copy(&bytes);
                let idxs: uint8x16x4_t = transmute_copy(&idxs);

                let ret0 = vqtbl4q_u8(bytes, idxs.0);
                let ret1 = vqtbl4q_u8(bytes, idxs.1);
                let ret2 = vqtbl4q_u8(bytes, idxs.2);
                let ret3 = vqtbl4q_u8(bytes, idxs.3);

                let ret = uint8x16x4_t(ret0, ret1, ret2, ret3);
                transmute_copy(&ret)
            }
            _ => unreachable!(),
        }
    }
}
184
185/// "vpshufb like it was meant to be" on AVX2
186///
187/// # Safety
188/// This requires AVX2 to work
189#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
190#[target_feature(enable = "avx2")]
191#[allow(unused)]
192#[inline]
193#[allow(clippy::let_and_return)]
194unsafe fn avx2_pshufb(bytes: Simd<u8, 32>, idxs: Simd<u8, 32>) -> Simd<u8, 32> {
195    use crate::simd::{Select, cmp::SimdPartialOrd};
196    #[cfg(target_arch = "x86")]
197    use core::arch::x86;
198    #[cfg(target_arch = "x86_64")]
199    use core::arch::x86_64 as x86;
200    use x86::_mm256_permute2x128_si256 as avx2_cross_shuffle;
201    use x86::_mm256_shuffle_epi8 as avx2_half_pshufb;
202    let mid = Simd::splat(16u8);
203    let high = mid + mid;
204    // SAFETY: Caller promised AVX2
205    unsafe {
206        // This is ordering sensitive, and LLVM will order these how you put them.
207        // Most AVX2 impls use ~5 "ports", and only 1 or 2 are capable of permutes.
208        // But the "compose" step will lower to ops that can also use at least 1 other port.
209        // So this tries to break up permutes so composition flows through "open" ports.
210        // Comparative benches should be done on multiple AVX2 CPUs before reordering this
211
212        let hihi = avx2_cross_shuffle::<0x11>(bytes.into(), bytes.into());
213        let hi_shuf = Simd::from(avx2_half_pshufb(
214            hihi,        // duplicate the vector's top half
215            idxs.into(), // so that using only 4 bits of an index still picks bytes 16-31
216        ));
217        // A zero-fill during the compose step gives the "all-Neon-like" OOB-is-0 semantics
218        let compose = idxs.simd_lt(high).select(hi_shuf, Simd::splat(0));
219        let lolo = avx2_cross_shuffle::<0x00>(bytes.into(), bytes.into());
220        let lo_shuf = Simd::from(avx2_half_pshufb(lolo, idxs.into()));
221        // Repeat, then pick indices < 16, overwriting indices 0-15 from previous compose step
222        let compose = idxs.simd_lt(mid).select(lo_shuf, compose);
223        compose
224    }
225}
226
/// LoongArch64 LSX supports swizzling `u8x16`
///
/// The bytes are shuffled by `idxs` and the result is then ANDed with a per-lane mask
/// derived from `idxs` (immediate 15), so that lanes with out-of-bounds indices end up 0.
///
/// # Safety
/// This requires LoongArch LSX to work
#[cfg(all(target_arch = "loongarch64", target_feature = "lsx"))]
unsafe fn loong64_lsx_swizzle(bytes: Simd<u8, 16>, idxs: Simd<u8, 16>) -> Simd<u8, 16> {
    use core::arch::loongarch64::{lsx_vand_v, lsx_vshuf_b, lsx_vslei_bu};
    // SAFETY: Caller promised loongarch lsx support
    unsafe {
        lsx_vand_v(
            // Shuffle with the same 16 bytes supplied as both table operands.
            lsx_vshuf_b(bytes.into(), bytes.into(), idxs.into()),
            // Presumably all-ones in lanes whose index is <= 15 (i.e. in bounds).
            lsx_vslei_bu::<15>(idxs.into()),
        )
        .into()
    }
}
241
/// LoongArch64 LASX supports swizzling `u8x32`
///
/// Two permuted copies of `bytes` (immediates `0x00` and `0x11`) serve as the shuffle
/// tables. The shuffled result is ANDed with a per-lane mask derived from `idxs`
/// (immediate 31), so that lanes with out-of-bounds indices end up 0.
///
/// # Safety
/// This requires LoongArch LASX to work
#[cfg(all(target_arch = "loongarch64", target_feature = "lasx"))]
unsafe fn loong64_lasx_swizzle(bytes: Simd<u8, 32>, idxs: Simd<u8, 32>) -> Simd<u8, 32> {
    use core::arch::loongarch64::{lasx_xvand_v, lasx_xvpermi_q, lasx_xvshuf_b, lasx_xvslei_bu};
    // SAFETY: Caller promised loongarch lasx support
    unsafe {
        // Presumably: the low 128-bit half broadcast to both halves, then likewise the high.
        let lower_dup = lasx_xvpermi_q::<0x00>(bytes.into(), bytes.into());
        let upper_dup = lasx_xvpermi_q::<0x11>(bytes.into(), bytes.into());
        lasx_xvand_v(
            lasx_xvshuf_b(upper_dup, lower_dup, idxs.into()),
            // Presumably all-ones in lanes whose index is <= 31 (i.e. in bounds).
            lasx_xvslei_bu::<31>(idxs.into()),
        )
        .into()
    }
}
258
259/// This sets up a call to an architecture-specific function, and in doing so
260/// it persuades rustc that everything is the correct size. Which it is.
261/// This would not be needed if one could convince Rust that, by matching on N,
262/// N is that value, and thus it would be valid to substitute e.g. 16.
263///
264/// # Safety
265/// The correctness of this function hinges on the sizes agreeing in actuality.
266#[allow(dead_code)]
267#[inline(always)]
268unsafe fn transize<T, const N: usize>(
269    f: unsafe fn(T, T) -> T,
270    a: Simd<u8, N>,
271    b: Simd<u8, N>,
272) -> Simd<u8, N> {
273    // SAFETY: Same obligation to use this function as to use mem::transmute_copy.
274    unsafe { mem::transmute_copy(&f(mem::transmute_copy(&a), mem::transmute_copy(&b))) }
275}
276
277/// Make indices that yield 0 for x86
278#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
279#[allow(unused)]
280#[inline(always)]
281fn zeroing_idxs<const N: usize>(idxs: Simd<u8, N>) -> Simd<u8, N> {
282    use crate::simd::{Select, cmp::SimdPartialOrd};
283    idxs.simd_lt(Simd::splat(N as u8))
284        .select(idxs, Simd::splat(u8::MAX))
285}