core_simd/
select.rs

1use crate::simd::{FixEndianness, Mask, MaskElement, Simd, SimdElement};
2
/// Lane-wise choice between two values, driven by a mask.
///
/// Each lane of the result is copied from `true_values` where the corresponding lane of the
/// mask is set, and from `false_values` where it is clear.
///
/// A `u64` can also serve as the mask. It is then read as a bitmask: bit `i`, counting from
/// the least significant bit, controls lane `i`, and bits beyond the lane count are ignored.
///
/// # Examples
///
/// ## Selecting lanes of a `Simd`
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{Simd, Mask, Select};
/// let a = Simd::from_array([0, 1, 2, 3]);
/// let b = Simd::from_array([4, 5, 6, 7]);
/// let mask = Mask::<i32, 4>::from_array([true, false, false, true]);
/// let c = mask.select(a, b);
/// assert_eq!(c.to_array(), [0, 5, 6, 3]);
/// ```
///
/// ## Selecting lanes of a `Mask`
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{Mask, Select};
/// let a = Mask::<i32, 4>::from_array([true, true, false, false]);
/// let b = Mask::<i32, 4>::from_array([false, false, true, true]);
/// let mask = Mask::<i32, 4>::from_array([true, false, false, true]);
/// let c = mask.select(a, b);
/// assert_eq!(c.to_array(), [true, false, true, false]);
/// ```
///
/// ## Selecting with a `u64` bitmask
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{Mask, Select};
/// let a = Mask::<i32, 4>::from_array([true, true, false, false]);
/// let b = Mask::<i32, 4>::from_array([false, false, true, true]);
/// // Bits 0 and 3 are set, so lanes 0 and 3 come from `a`.
/// let mask: u64 = 0b1001;
/// let c = mask.select(a, b);
/// assert_eq!(c.to_array(), [true, false, true, false]);
/// ```
pub trait Select<T> {
    /// Builds a new value whose lanes come from `true_values` where `self` is set and from
    /// `false_values` where it is not.
    fn select(self, true_values: T, false_values: T) -> T;
}
55
56impl<T, U, const N: usize> Select<Simd<T, N>> for Mask<U, N>
57where
58    T: SimdElement,
59    U: MaskElement,
60{
61    #[inline]
62    fn select(self, true_values: Simd<T, N>, false_values: Simd<T, N>) -> Simd<T, N> {
63        // Safety:
64        // simd_as between masks is always safe (they're vectors of ints).
65        // simd_select uses a mask that matches the width and number of elements
66        unsafe {
67            let mask: Simd<T::Mask, N> = core::intrinsics::simd::simd_as(self.to_simd());
68            core::intrinsics::simd::simd_select(mask, true_values, false_values)
69        }
70    }
71}
72
73impl<T, const N: usize> Select<Simd<T, N>> for u64
74where
75    T: SimdElement,
76{
77    #[inline]
78    fn select(self, true_values: Simd<T, N>, false_values: Simd<T, N>) -> Simd<T, N> {
79        const {
80            assert!(N <= 64, "number of elements can't be greater than 64");
81        }
82
83        #[inline]
84        unsafe fn select_impl<T, U: FixEndianness, const M: usize, const N: usize>(
85            bitmask: U,
86            true_values: Simd<T, N>,
87            false_values: Simd<T, N>,
88        ) -> Simd<T, N>
89        where
90            T: SimdElement,
91        {
92            let default = true_values[0];
93            let true_values = true_values.resize::<M>(default);
94            let false_values = false_values.resize::<M>(default);
95
96            // LLVM assumes bit order should match endianness
97            let bitmask = bitmask.fix_endianness();
98
99            // Safety: the caller guarantees that the size of U matches M
100            let selected = unsafe {
101                core::intrinsics::simd::simd_select_bitmask(bitmask, true_values, false_values)
102            };
103
104            selected.resize::<N>(default)
105        }
106
107        // TODO modify simd_bitmask_select to truncate input, making this unnecessary
108        if N <= 8 {
109            let bitmask = self as u8;
110            // Safety: bitmask matches length
111            unsafe { select_impl::<T, u8, 8, N>(bitmask, true_values, false_values) }
112        } else if N <= 16 {
113            let bitmask = self as u16;
114            // Safety: bitmask matches length
115            unsafe { select_impl::<T, u16, 16, N>(bitmask, true_values, false_values) }
116        } else if N <= 32 {
117            let bitmask = self as u32;
118            // Safety: bitmask matches length
119            unsafe { select_impl::<T, u32, 32, N>(bitmask, true_values, false_values) }
120        } else {
121            let bitmask = self;
122            // Safety: bitmask matches length
123            unsafe { select_impl::<T, u64, 64, N>(bitmask, true_values, false_values) }
124        }
125    }
126}
127
128impl<T, U, const N: usize> Select<Mask<T, N>> for Mask<U, N>
129where
130    T: MaskElement,
131    U: MaskElement,
132{
133    #[inline]
134    fn select(self, true_values: Mask<T, N>, false_values: Mask<T, N>) -> Mask<T, N> {
135        let selected: Simd<T, N> =
136            Select::select(self, true_values.to_simd(), false_values.to_simd());
137
138        // Safety: all values come from masks
139        unsafe { Mask::from_simd_unchecked(selected) }
140    }
141}
142
143impl<T, const N: usize> Select<Mask<T, N>> for u64
144where
145    T: MaskElement,
146{
147    #[inline]
148    fn select(self, true_values: Mask<T, N>, false_values: Mask<T, N>) -> Mask<T, N> {
149        let selected: Simd<T, N> =
150            Select::select(self, true_values.to_simd(), false_values.to_simd());
151
152        // Safety: all values come from masks
153        unsafe { Mask::from_simd_unchecked(selected) }
154    }
155}