core_simd/select.rs
1use crate::simd::{FixEndianness, Mask, MaskElement, Simd, SimdElement};
2
/// Choose elements from two vectors using a mask.
///
/// For each element in the mask, the corresponding element is taken from `true_values` if
/// that mask element is true, and from `false_values` if that mask element is false.
///
/// If the mask is a `u64`, it's treated as a bitmask with the least significant bit
/// corresponding to the first element. Only the low `N` bits are consulted (where `N` is
/// the number of elements), and using a `u64` bitmask with more than 64 elements is
/// rejected at compile time.
///
/// # Examples
///
/// ## Selecting values from `Simd`
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{Simd, Mask, Select};
/// let a = Simd::from_array([0, 1, 2, 3]);
/// let b = Simd::from_array([4, 5, 6, 7]);
/// let mask = Mask::<i32, 4>::from_array([true, false, false, true]);
/// let c = mask.select(a, b);
/// assert_eq!(c.to_array(), [0, 5, 6, 3]);
/// ```
///
/// ## Selecting values from `Mask`
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{Mask, Select};
/// let a = Mask::<i32, 4>::from_array([true, true, false, false]);
/// let b = Mask::<i32, 4>::from_array([false, false, true, true]);
/// let mask = Mask::<i32, 4>::from_array([true, false, false, true]);
/// let c = mask.select(a, b);
/// assert_eq!(c.to_array(), [true, false, true, false]);
/// ```
///
/// ## Selecting with a bitmask
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{Mask, Select};
/// let a = Mask::<i32, 4>::from_array([true, true, false, false]);
/// let b = Mask::<i32, 4>::from_array([false, false, true, true]);
/// let mask = 0b1001;
/// let c = mask.select(a, b);
/// assert_eq!(c.to_array(), [true, false, true, false]);
/// ```
pub trait Select<T> {
    /// Choose each element of the result from `true_values` or `false_values`.
    ///
    /// The choice is driven by the corresponding element of `self`, or, for a `u64`
    /// bitmask, by the corresponding bit of `self`.
    fn select(self, true_values: T, false_values: T) -> T;
}
55
56impl<T, U, const N: usize> Select<Simd<T, N>> for Mask<U, N>
57where
58 T: SimdElement,
59 U: MaskElement,
60{
61 #[inline]
62 fn select(self, true_values: Simd<T, N>, false_values: Simd<T, N>) -> Simd<T, N> {
63 // Safety:
64 // simd_as between masks is always safe (they're vectors of ints).
65 // simd_select uses a mask that matches the width and number of elements
66 unsafe {
67 let mask: Simd<T::Mask, N> = core::intrinsics::simd::simd_as(self.to_simd());
68 core::intrinsics::simd::simd_select(mask, true_values, false_values)
69 }
70 }
71}
72
73impl<T, const N: usize> Select<Simd<T, N>> for u64
74where
75 T: SimdElement,
76{
77 #[inline]
78 fn select(self, true_values: Simd<T, N>, false_values: Simd<T, N>) -> Simd<T, N> {
79 const {
80 assert!(N <= 64, "number of elements can't be greater than 64");
81 }
82
83 #[inline]
84 unsafe fn select_impl<T, U: FixEndianness, const M: usize, const N: usize>(
85 bitmask: U,
86 true_values: Simd<T, N>,
87 false_values: Simd<T, N>,
88 ) -> Simd<T, N>
89 where
90 T: SimdElement,
91 {
92 let default = true_values[0];
93 let true_values = true_values.resize::<M>(default);
94 let false_values = false_values.resize::<M>(default);
95
96 // LLVM assumes bit order should match endianness
97 let bitmask = bitmask.fix_endianness();
98
99 // Safety: the caller guarantees that the size of U matches M
100 let selected = unsafe {
101 core::intrinsics::simd::simd_select_bitmask(bitmask, true_values, false_values)
102 };
103
104 selected.resize::<N>(default)
105 }
106
107 // TODO modify simd_bitmask_select to truncate input, making this unnecessary
108 if N <= 8 {
109 let bitmask = self as u8;
110 // Safety: bitmask matches length
111 unsafe { select_impl::<T, u8, 8, N>(bitmask, true_values, false_values) }
112 } else if N <= 16 {
113 let bitmask = self as u16;
114 // Safety: bitmask matches length
115 unsafe { select_impl::<T, u16, 16, N>(bitmask, true_values, false_values) }
116 } else if N <= 32 {
117 let bitmask = self as u32;
118 // Safety: bitmask matches length
119 unsafe { select_impl::<T, u32, 32, N>(bitmask, true_values, false_values) }
120 } else {
121 let bitmask = self;
122 // Safety: bitmask matches length
123 unsafe { select_impl::<T, u64, 64, N>(bitmask, true_values, false_values) }
124 }
125 }
126}
127
128impl<T, U, const N: usize> Select<Mask<T, N>> for Mask<U, N>
129where
130 T: MaskElement,
131 U: MaskElement,
132{
133 #[inline]
134 fn select(self, true_values: Mask<T, N>, false_values: Mask<T, N>) -> Mask<T, N> {
135 let selected: Simd<T, N> =
136 Select::select(self, true_values.to_simd(), false_values.to_simd());
137
138 // Safety: all values come from masks
139 unsafe { Mask::from_simd_unchecked(selected) }
140 }
141}
142
143impl<T, const N: usize> Select<Mask<T, N>> for u64
144where
145 T: MaskElement,
146{
147 #[inline]
148 fn select(self, true_values: Mask<T, N>, false_values: Mask<T, N>) -> Mask<T, N> {
149 let selected: Simd<T, N> =
150 Select::select(self, true_values.to_simd(), false_values.to_simd());
151
152 // Safety: all values come from masks
153 unsafe { Mask::from_simd_unchecked(selected) }
154 }
155}