core_simd/swizzle_dyn.rs

use crate::simd::Simd;
use core::mem;

impl<const N: usize> Simd<u8, N> {
    /// Swizzle a vector of bytes according to the index vector.
    /// Indices within range select the appropriate byte.
    /// Indices "out of bounds" instead select 0.
    ///
    /// Note that the current implementation is selected during build-time
    /// of the standard library, so `cargo build -Zbuild-std` may be necessary
    /// to unlock better performance, especially for larger vectors.
    /// A planned compiler improvement will enable using `#[target_feature]` instead.
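    ///
    /// # Examples
    ///
    /// A minimal sketch of the semantics (the values are illustrative):
    /// ```
    /// # #![feature(portable_simd)]
    /// # use core::simd::Simd;
    /// let bytes = Simd::from_array([10u8, 20, 30, 40, 50, 60, 70, 80]);
    /// let idxs = Simd::from_array([7u8, 0, 0, 2, 255, 1, 1, 3]);
    /// // Index 255 is out of bounds, so that lane becomes 0.
    /// assert_eq!(bytes.swizzle_dyn(idxs).to_array(), [80, 10, 10, 30, 0, 20, 20, 40]);
    /// ```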
    #[inline]
    pub fn swizzle_dyn(self, idxs: Simd<u8, N>) -> Self {
        #![allow(unused_imports, unused_unsafe)]
        #[cfg(all(
            any(target_arch = "aarch64", target_arch = "arm64ec"),
            target_endian = "little"
        ))]
        use core::arch::aarch64::{uint8x8_t, vqtbl1q_u8, vtbl1_u8};
        #[cfg(all(
            target_arch = "arm",
            target_feature = "v7",
            target_feature = "neon",
            target_endian = "little"
        ))]
        use core::arch::arm::{uint8x8_t, vtbl1_u8};
        #[cfg(target_arch = "wasm32")]
        use core::arch::wasm32 as wasm;
        #[cfg(target_arch = "wasm64")]
        use core::arch::wasm64 as wasm;
        #[cfg(target_arch = "x86")]
        use core::arch::x86;
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64 as x86;
        // SAFETY: Intrinsics covered by cfg
        unsafe {
            match N {
                #[cfg(all(
                    any(
                        target_arch = "aarch64",
                        target_arch = "arm64ec",
                        all(target_arch = "arm", target_feature = "v7")
                    ),
                    target_feature = "neon",
                    target_endian = "little"
                ))]
                8 => transize(vtbl1_u8, self, idxs),
                #[cfg(target_feature = "ssse3")]
                16 => transize(x86::_mm_shuffle_epi8, self, zeroing_idxs(idxs)),
                #[cfg(target_feature = "simd128")]
                16 => transize(wasm::i8x16_swizzle, self, idxs),
                #[cfg(all(
                    any(target_arch = "aarch64", target_arch = "arm64ec"),
                    target_feature = "neon",
                    target_endian = "little"
                ))]
                16 => transize(vqtbl1q_u8, self, idxs),
                #[cfg(all(
                    target_arch = "arm",
                    target_feature = "v7",
                    target_feature = "neon",
                    target_endian = "little"
                ))]
                16 => transize(armv7_neon_swizzle_u8x16, self, idxs),
                #[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi")))]
                32 => transize(avx2_pshufb, self, idxs),
                #[cfg(all(target_feature = "avx512vl", target_feature = "avx512vbmi"))]
                32 => {
                    // Unlike vpshufb, vpermb doesn't zero out values in the result based on the index high bit
                    let swizzler = |bytes, idxs| {
                        let mask = x86::_mm256_cmp_epu8_mask::<{ x86::_MM_CMPINT_LT }>(
                            idxs,
                            Simd::<u8, 32>::splat(N as u8).into(),
                        );
                        x86::_mm256_maskz_permutexvar_epi8(mask, idxs, bytes)
                    };
                    transize(swizzler, self, idxs)
                }
                // Notable absence: avx512bw pshufb shuffle
                #[cfg(all(target_feature = "avx512vl", target_feature = "avx512vbmi"))]
                64 => {
                    // Unlike vpshufb, vpermb doesn't zero out values in the result based on the index high bit
                    let swizzler = |bytes, idxs| {
                        let mask = x86::_mm512_cmp_epu8_mask::<{ x86::_MM_CMPINT_LT }>(
                            idxs,
                            Simd::<u8, 64>::splat(N as u8).into(),
                        );
                        x86::_mm512_maskz_permutexvar_epi8(mask, idxs, bytes)
                    };
                    transize(swizzler, self, idxs)
                }
                _ => {
                    let mut array = [0; N];
                    for (i, k) in idxs.to_array().into_iter().enumerate() {
                        if (k as usize) < N {
                            array[i] = self[k as usize];
                        }
                    }
                    array.into()
                }
            }
        }
    }
}

/// armv7 neon supports swizzling `u8x16` by swizzling two u8x8 blocks
/// with a u8x8x2 lookup table.
///
/// # Safety
/// This requires armv7 neon to work
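///
/// `vtbl2_u8` returns 0 for any out-of-range index, so splitting the indices
/// into two halves preserves the OOB-is-0 semantics of `swizzle_dyn` without
/// an extra masking step.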
#[cfg(all(
    target_arch = "arm",
    target_feature = "v7",
    target_feature = "neon",
    target_endian = "little"
))]
unsafe fn armv7_neon_swizzle_u8x16(bytes: Simd<u8, 16>, idxs: Simd<u8, 16>) -> Simd<u8, 16> {
    use core::arch::arm::{uint8x8x2_t, vcombine_u8, vget_high_u8, vget_low_u8, vtbl2_u8};
    // SAFETY: Caller promised arm neon support
    unsafe {
        let bytes = uint8x8x2_t(vget_low_u8(bytes.into()), vget_high_u8(bytes.into()));
        let lo = vtbl2_u8(bytes, vget_low_u8(idxs.into()));
        let hi = vtbl2_u8(bytes, vget_high_u8(idxs.into()));
        vcombine_u8(lo, hi).into()
    }
}

/// "vpshufb like it was meant to be" on AVX2
///
/// # Safety
/// This requires AVX2 to work
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[allow(unused)]
#[inline]
#[allow(clippy::let_and_return)]
unsafe fn avx2_pshufb(bytes: Simd<u8, 32>, idxs: Simd<u8, 32>) -> Simd<u8, 32> {
    use crate::simd::{Select, cmp::SimdPartialOrd};
    #[cfg(target_arch = "x86")]
    use core::arch::x86;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64 as x86;
    use x86::_mm256_permute2x128_si256 as avx2_cross_shuffle;
    use x86::_mm256_shuffle_epi8 as avx2_half_pshufb;
    let mid = Simd::splat(16u8);
    let high = mid + mid;
    // SAFETY: Caller promised AVX2
    unsafe {
        // This is ordering sensitive, and LLVM will order these how you put them.
        // Most AVX2 impls use ~5 "ports", and only 1 or 2 are capable of permutes.
        // But the "compose" step will lower to ops that can also use at least 1 other port.
        // So this tries to break up permutes so composition flows through "open" ports.
        // Comparative benches should be done on multiple AVX2 CPUs before reordering this

        let hihi = avx2_cross_shuffle::<0x11>(bytes.into(), bytes.into());
        let hi_shuf = Simd::from(avx2_half_pshufb(
            hihi,        // duplicate the vector's top half
            idxs.into(), // so that using only 4 bits of an index still picks bytes 16-31
        ));
        // A zero-fill during the compose step gives the "all-Neon-like" OOB-is-0 semantics
        let compose = idxs.simd_lt(high).select(hi_shuf, Simd::splat(0));
        let lolo = avx2_cross_shuffle::<0x00>(bytes.into(), bytes.into());
        let lo_shuf = Simd::from(avx2_half_pshufb(lolo, idxs.into()));
        // Repeat, then pick indices < 16, overwriting indices 0-15 from previous compose step
        let compose = idxs.simd_lt(mid).select(lo_shuf, compose);
        compose
    }
}

/// This sets up a call to an architecture-specific function, and in doing so
/// it persuades rustc that everything is the correct size. Which it is.
/// This would not be needed if one could convince Rust that, by matching on N,
/// N is that value, and thus it would be valid to substitute e.g. 16.
///
/// # Safety
/// The correctness of this function hinges on the sizes agreeing in actuality.
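///
/// For example, `transize(vqtbl1q_u8, self, idxs)` above reinterprets both
/// `Simd<u8, 16>` arguments as `uint8x16_t` for the intrinsic and the result
/// back again, which works out because the types agree in size (16 bytes).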
#[allow(dead_code)]
#[inline(always)]
unsafe fn transize<T, const N: usize>(
    f: unsafe fn(T, T) -> T,
    a: Simd<u8, N>,
    b: Simd<u8, N>,
) -> Simd<u8, N> {
    // SAFETY: Same obligation to use this function as to use mem::transmute_copy.
    unsafe { mem::transmute_copy(&f(mem::transmute_copy(&a), mem::transmute_copy(&b))) }
}

/// Make indices that yield 0 for x86's `pshufb`, which zeroes a lane when the
/// index byte's high bit is set: any out-of-bounds index (>= N) is replaced
/// with `u8::MAX`, so its high bit is guaranteed to be set.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unused)]
#[inline(always)]
fn zeroing_idxs<const N: usize>(idxs: Simd<u8, N>) -> Simd<u8, N> {
    use crate::simd::{Select, cmp::SimdPartialOrd};
    idxs.simd_lt(Simd::splat(N as u8))
        .select(idxs, Simd::splat(u8::MAX))
}
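
// A minimal sanity-check sketch (an addition, not part of the upstream file):
// it compares `swizzle_dyn` against the scalar fallback semantics documented
// above, assuming the crate's usual test setup.
#[cfg(test)]
mod swizzle_dyn_sketch_tests {
    use crate::simd::Simd;

    #[test]
    fn matches_scalar_fallback_semantics() {
        const N: usize = 16;
        let bytes = Simd::<u8, N>::from_array(core::array::from_fn(|i| i as u8 + 1));
        // Mix of in-range (< 16) and out-of-range (>= 16) indices.
        let idxs = Simd::<u8, N>::from_array(core::array::from_fn(|i| (i as u8) * 5));
        // Scalar model: in-range indices select a byte, out-of-range indices select 0.
        let mut expected = [0u8; N];
        for (i, k) in idxs.to_array().into_iter().enumerate() {
            if (k as usize) < N {
                expected[i] = bytes[k as usize];
            }
        }
        assert_eq!(bytes.swizzle_dyn(idxs).to_array(), expected);
    }
}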