// core_simd/masks.rs

//! Types and traits associated with masking elements of vectors.
//! Types representing per-element boolean masks over SIMD vectors.
3#![allow(non_camel_case_types)]
4
5use crate::simd::{Select, Simd, SimdCast, SimdElement};
6use core::cmp::Ordering;
7use core::{fmt, mem};
8
/// Normalizes the bit order of an integer bitmask for the target's endianness.
///
/// The implementations in this module mirror the bits on big-endian targets and
/// return the value unchanged on little-endian targets.
pub(crate) trait FixEndianness {
    /// Returns `self` with its bit order adjusted for the target endianness.
    fn fix_endianness(self) -> Self;
}
12
13macro_rules! impl_fix_endianness {
14    { $($int:ty),* } => {
15        $(
16        impl FixEndianness for $int {
17            #[inline(always)]
18            fn fix_endianness(self) -> Self {
19                if cfg!(target_endian = "big") {
20                    <$int>::reverse_bits(self)
21                } else {
22                    self
23                }
24            }
25        }
26        )*
27    }
28}
29
30impl_fix_endianness! { u8, u16, u32, u64 }
31
32mod sealed {
33    use super::*;
34
35    /// Not only does this seal the `MaskElement` trait, but these functions prevent other traits
36    /// from bleeding into the parent bounds.
37    ///
38    /// For example, `eq` could be provided by requiring `MaskElement: PartialEq`, but that would
39    /// prevent us from ever removing that bound, or from implementing `MaskElement` on
40    /// non-`PartialEq` types in the future.
41    pub trait Sealed {
42        fn valid<const N: usize>(values: Simd<Self, N>) -> bool
43        where
44            Self: SimdElement;
45
46        fn eq(self, other: Self) -> bool;
47
48        fn to_usize(self) -> usize;
49        fn max_unsigned() -> u64;
50
51        type Unsigned: SimdElement;
52
53        const TRUE: Self;
54
55        const FALSE: Self;
56    }
57}
58use sealed::Sealed;
59
60/// Marker trait for types that may be used as SIMD mask elements.
61///
62/// # Safety
63/// Type must be a signed integer.
64pub unsafe trait MaskElement: SimdElement<Mask = Self> + SimdCast + Sealed {}
65
66macro_rules! impl_element {
67    { $ty:ty, $unsigned:ty } => {
68        impl Sealed for $ty {
69            #[inline]
70            fn valid<const N: usize>(value: Simd<Self, N>) -> bool
71            {
72                // We can't use `Simd` directly, because `Simd`'s functions call this function and
73                // we will end up with an infinite loop.
74                // Safety: `value` is an integer vector
75                unsafe {
76                    use core::intrinsics::simd;
77                    let falses: Simd<Self, N> = simd::simd_eq(value, Simd::splat(0 as _));
78                    let trues: Simd<Self, N> = simd::simd_eq(value, Simd::splat(-1 as _));
79                    let valid: Simd<Self, N> = simd::simd_or(falses, trues);
80                    simd::simd_reduce_all(valid)
81                }
82            }
83
84            #[inline]
85            fn eq(self, other: Self) -> bool { self == other }
86
87            #[inline]
88            fn to_usize(self) -> usize {
89                self as usize
90            }
91
92            #[inline]
93            fn max_unsigned() -> u64 {
94                <$unsigned>::MAX as u64
95            }
96
97            type Unsigned = $unsigned;
98
99            const TRUE: Self = -1;
100            const FALSE: Self = 0;
101        }
102
103        // Safety: this is a valid mask element type
104        unsafe impl MaskElement for $ty {}
105    }
106}
107
108impl_element! { i8, u8 }
109impl_element! { i16, u16 }
110impl_element! { i32, u32 }
111impl_element! { i64, u64 }
112impl_element! { isize, usize }
113
114/// A SIMD vector mask for `N` elements of width specified by `Element`.
115///
116/// Masks represent boolean inclusion/exclusion on a per-element basis.
117///
118/// The layout of this type is unspecified, and may change between platforms
119/// and/or Rust versions, and code should not assume that it is equivalent to
120/// `[T; N]`.
121///
122/// `N` cannot be 0 and may be at most 64. This limit may be increased in
123/// the future.
124#[repr(transparent)]
125pub struct Mask<T, const N: usize>(Simd<T, N>)
126where
127    T: MaskElement;
128
129impl<T, const N: usize> Copy for Mask<T, N> where T: MaskElement {}
130
131impl<T, const N: usize> Clone for Mask<T, N>
132where
133    T: MaskElement,
134{
135    #[inline]
136    fn clone(&self) -> Self {
137        *self
138    }
139}
140
141impl<T, const N: usize> Mask<T, N>
142where
143    T: MaskElement,
144{
145    /// Constructs a mask by setting all elements to the given value.
146    #[inline]
147    #[rustc_const_unstable(feature = "portable_simd", issue = "86656")]
148    pub const fn splat(value: bool) -> Self {
149        Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
150    }
151
152    /// Converts an array of bools to a SIMD mask.
153    #[inline]
154    pub fn from_array(array: [bool; N]) -> Self {
155        // SAFETY: Rust's bool has a layout of 1 byte (u8) with a value of
156        //     true:    0b_0000_0001
157        //     false:   0b_0000_0000
158        // Thus, an array of bools is also a valid array of bytes: [u8; N]
159        // This would be hypothetically valid as an "in-place" transmute,
160        // but these are "dependently-sized" types, so copy elision it is!
161        unsafe {
162            let bytes: [u8; N] = mem::transmute_copy(&array);
163            let bools: Simd<i8, N> =
164                core::intrinsics::simd::simd_ne(Simd::from_array(bytes), Simd::splat(0u8));
165            Mask::from_simd_unchecked(core::intrinsics::simd::simd_cast(bools))
166        }
167    }
168
169    /// Converts a SIMD mask to an array of bools.
170    #[inline]
171    pub fn to_array(self) -> [bool; N] {
172        // This follows mostly the same logic as from_array.
173        // SAFETY: Rust's bool has a layout of 1 byte (u8) with a value of
174        //     true:    0b_0000_0001
175        //     false:   0b_0000_0000
176        // Thus, an array of bools is also a valid array of bytes: [u8; N]
177        // Since our masks are equal to integers where all bits are set,
178        // we can simply convert them to i8s, and then bitand them by the
179        // bitpattern for Rust's "true" bool.
180        // This would be hypothetically valid as an "in-place" transmute,
181        // but these are "dependently-sized" types, so copy elision it is!
182        unsafe {
183            let mut bytes: Simd<i8, N> = core::intrinsics::simd::simd_cast(self.to_simd());
184            bytes &= Simd::splat(1i8);
185            mem::transmute_copy(&bytes)
186        }
187    }
188
189    /// Converts a vector of integers to a mask, where 0 represents `false` and -1
190    /// represents `true`.
191    ///
192    /// # Safety
193    /// All elements must be either 0 or -1.
194    #[inline]
195    #[must_use = "method returns a new mask and does not mutate the original value"]
196    pub unsafe fn from_simd_unchecked(value: Simd<T, N>) -> Self {
197        // Safety: the caller must confirm this invariant
198        unsafe {
199            core::intrinsics::assume(<T as Sealed>::valid(value));
200        }
201        Self(value)
202    }
203
204    /// Converts a vector of integers to a mask, where 0 represents `false` and -1
205    /// represents `true`.
206    ///
207    /// # Panics
208    /// Panics if any element is not 0 or -1.
209    #[inline]
210    #[must_use = "method returns a new mask and does not mutate the original value"]
211    #[track_caller]
212    pub fn from_simd(value: Simd<T, N>) -> Self {
213        assert!(T::valid(value), "all values must be either 0 or -1",);
214        // Safety: the validity has been checked
215        unsafe { Self::from_simd_unchecked(value) }
216    }
217
218    /// Converts the mask to a vector of integers, where 0 represents `false` and -1
219    /// represents `true`.
220    #[inline]
221    #[must_use = "method returns a new vector and does not mutate the original value"]
222    pub fn to_simd(self) -> Simd<T, N> {
223        self.0
224    }
225
226    /// Converts the mask to a mask of any other element size.
227    #[inline]
228    #[must_use = "method returns a new mask and does not mutate the original value"]
229    pub fn cast<U: MaskElement>(self) -> Mask<U, N> {
230        // Safety: mask elements are integers
231        unsafe { Mask(core::intrinsics::simd::simd_as(self.0)) }
232    }
233
234    /// Tests the value of the specified element.
235    ///
236    /// # Safety
237    /// `index` must be less than `self.len()`.
238    #[inline]
239    #[must_use = "method returns a new bool and does not mutate the original value"]
240    pub unsafe fn test_unchecked(&self, index: usize) -> bool {
241        // Safety: the caller must confirm this invariant
242        unsafe { T::eq(*self.0.as_array().get_unchecked(index), T::TRUE) }
243    }
244
245    /// Tests the value of the specified element.
246    ///
247    /// # Panics
248    /// Panics if `index` is greater than or equal to the number of elements in the vector.
249    #[inline]
250    #[must_use = "method returns a new bool and does not mutate the original value"]
251    #[track_caller]
252    pub fn test(&self, index: usize) -> bool {
253        T::eq(self.0[index], T::TRUE)
254    }
255
256    /// Sets the value of the specified element.
257    ///
258    /// # Safety
259    /// `index` must be less than `self.len()`.
260    #[inline]
261    pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) {
262        // Safety: the caller must confirm this invariant
263        unsafe {
264            *self.0.as_mut_array().get_unchecked_mut(index) = if value { T::TRUE } else { T::FALSE }
265        }
266    }
267
268    /// Sets the value of the specified element.
269    ///
270    /// # Panics
271    /// Panics if `index` is greater than or equal to the number of elements in the vector.
272    #[inline]
273    #[track_caller]
274    pub fn set(&mut self, index: usize, value: bool) {
275        self.0[index] = if value { T::TRUE } else { T::FALSE }
276    }
277
278    /// Returns true if any element is set, or false otherwise.
279    #[inline]
280    #[must_use = "method returns a new bool and does not mutate the original value"]
281    pub fn any(self) -> bool {
282        // Safety: `self` is a mask vector
283        unsafe { core::intrinsics::simd::simd_reduce_any(self.0) }
284    }
285
286    /// Returns true if all elements are set, or false otherwise.
287    #[inline]
288    #[must_use = "method returns a new bool and does not mutate the original value"]
289    pub fn all(self) -> bool {
290        // Safety: `self` is a mask vector
291        unsafe { core::intrinsics::simd::simd_reduce_all(self.0) }
292    }
293
294    /// Creates a bitmask from a mask.
295    ///
296    /// Each bit is set if the corresponding element in the mask is `true`.
297    #[inline]
298    #[must_use = "method returns a new integer and does not mutate the original value"]
299    pub fn to_bitmask(self) -> u64 {
300        const {
301            assert!(N <= 64, "number of elements can't be greater than 64");
302        }
303
304        #[inline]
305        unsafe fn to_bitmask_impl<T, U: FixEndianness, const M: usize, const N: usize>(
306            mask: Mask<T, N>,
307        ) -> U
308        where
309            T: MaskElement,
310        {
311            let resized = mask.resize::<M>(false);
312
313            // Safety: `resized` is an integer vector with length M, which must match T
314            let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized.0) };
315
316            // LLVM assumes bit order should match endianness
317            bitmask.fix_endianness()
318        }
319
320        // TODO modify simd_bitmask to zero-extend output, making this unnecessary
321        if N <= 8 {
322            // Safety: bitmask matches length
323            unsafe { to_bitmask_impl::<T, u8, 8, N>(self) as u64 }
324        } else if N <= 16 {
325            // Safety: bitmask matches length
326            unsafe { to_bitmask_impl::<T, u16, 16, N>(self) as u64 }
327        } else if N <= 32 {
328            // Safety: bitmask matches length
329            unsafe { to_bitmask_impl::<T, u32, 32, N>(self) as u64 }
330        } else {
331            // Safety: bitmask matches length
332            unsafe { to_bitmask_impl::<T, u64, 64, N>(self) }
333        }
334    }
335
336    /// Creates a mask from a bitmask.
337    ///
338    /// For each bit, if it is set, the corresponding element in the mask is set to `true`.
339    /// If the mask contains more than 64 elements, the remainder are set to `false`.
340    #[inline]
341    #[must_use = "method returns a new mask and does not mutate the original value"]
342    pub fn from_bitmask(bitmask: u64) -> Self {
343        Self(bitmask.select(Simd::splat(T::TRUE), Simd::splat(T::FALSE)))
344    }
345
346    /// Finds the index of the first set element.
347    ///
348    /// ```
349    /// # #![feature(portable_simd)]
350    /// # #[cfg(feature = "as_crate")] use core_simd::simd;
351    /// # #[cfg(not(feature = "as_crate"))] use core::simd;
352    /// # use simd::mask32x8;
353    /// assert_eq!(mask32x8::splat(false).first_set(), None);
354    /// assert_eq!(mask32x8::splat(true).first_set(), Some(0));
355    ///
356    /// let mask = mask32x8::from_array([false, true, false, false, true, false, false, true]);
357    /// assert_eq!(mask.first_set(), Some(1));
358    /// ```
359    #[inline]
360    #[must_use = "method returns the index and does not mutate the original value"]
361    pub fn first_set(self) -> Option<usize> {
362        // If bitmasks are efficient, using them is better
363        if cfg!(target_feature = "sse") && N <= 64 {
364            let tz = self.to_bitmask().trailing_zeros();
365            return if tz == 64 { None } else { Some(tz as usize) };
366        }
367
368        // To find the first set index:
369        // * create a vector 0..N
370        // * replace unset mask elements in that vector with -1
371        // * perform _unsigned_ reduce-min
372        // * check if the result is -1 or an index
373
374        let index = Simd::from_array(
375            const {
376                let mut index = [0; N];
377                let mut i = 0;
378                while i < N {
379                    index[i] = i;
380                    i += 1;
381                }
382                index
383            },
384        );
385
386        // Safety: the input and output are integer vectors
387        let index: Simd<T, N> = unsafe { core::intrinsics::simd::simd_cast(index) };
388
389        let masked_index = self.select(index, Self::splat(true).to_simd());
390
391        // Safety: the input and output are integer vectors
392        let masked_index: Simd<T::Unsigned, N> =
393            unsafe { core::intrinsics::simd::simd_cast(masked_index) };
394
395        // Safety: the input is an integer vector
396        let min_index: T::Unsigned =
397            unsafe { core::intrinsics::simd::simd_reduce_min(masked_index) };
398
399        // Safety: the return value is the unsigned version of T
400        let min_index: T = unsafe { core::mem::transmute_copy(&min_index) };
401
402        if min_index.eq(T::TRUE) {
403            None
404        } else {
405            Some(min_index.to_usize())
406        }
407    }
408}
409
410// vector/array conversion
411impl<T, const N: usize> From<[bool; N]> for Mask<T, N>
412where
413    T: MaskElement,
414{
415    #[inline]
416    fn from(array: [bool; N]) -> Self {
417        Self::from_array(array)
418    }
419}
420
421impl<T, const N: usize> From<Mask<T, N>> for [bool; N]
422where
423    T: MaskElement,
424{
425    #[inline]
426    fn from(vector: Mask<T, N>) -> Self {
427        vector.to_array()
428    }
429}
430
431impl<T, const N: usize> Default for Mask<T, N>
432where
433    T: MaskElement,
434{
435    #[inline]
436    fn default() -> Self {
437        Self::splat(false)
438    }
439}
440
441impl<T, const N: usize> PartialEq for Mask<T, N>
442where
443    T: MaskElement + PartialEq,
444{
445    #[inline]
446    fn eq(&self, other: &Self) -> bool {
447        self.0 == other.0
448    }
449}
450
451impl<T, const N: usize> PartialOrd for Mask<T, N>
452where
453    T: MaskElement + PartialOrd,
454{
455    #[inline]
456    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
457        self.0.partial_cmp(&other.0)
458    }
459}
460
461impl<T, const N: usize> fmt::Debug for Mask<T, N>
462where
463    T: MaskElement + fmt::Debug,
464{
465    #[inline]
466    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
467        f.debug_list()
468            .entries((0..N).map(|i| self.test(i)))
469            .finish()
470    }
471}
472
473impl<T, const N: usize> core::ops::BitAnd for Mask<T, N>
474where
475    T: MaskElement,
476{
477    type Output = Self;
478    #[inline]
479    fn bitand(self, rhs: Self) -> Self {
480        // Safety: `self` is an integer vector
481        unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) }
482    }
483}
484
485impl<T, const N: usize> core::ops::BitAnd<bool> for Mask<T, N>
486where
487    T: MaskElement,
488{
489    type Output = Self;
490    #[inline]
491    fn bitand(self, rhs: bool) -> Self {
492        self & Self::splat(rhs)
493    }
494}
495
496impl<T, const N: usize> core::ops::BitAnd<Mask<T, N>> for bool
497where
498    T: MaskElement,
499{
500    type Output = Mask<T, N>;
501    #[inline]
502    fn bitand(self, rhs: Mask<T, N>) -> Mask<T, N> {
503        Mask::splat(self) & rhs
504    }
505}
506
507impl<T, const N: usize> core::ops::BitOr for Mask<T, N>
508where
509    T: MaskElement,
510{
511    type Output = Self;
512    #[inline]
513    fn bitor(self, rhs: Self) -> Self {
514        // Safety: `self` is an integer vector
515        unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) }
516    }
517}
518
519impl<T, const N: usize> core::ops::BitOr<bool> for Mask<T, N>
520where
521    T: MaskElement,
522{
523    type Output = Self;
524    #[inline]
525    fn bitor(self, rhs: bool) -> Self {
526        self | Self::splat(rhs)
527    }
528}
529
530impl<T, const N: usize> core::ops::BitOr<Mask<T, N>> for bool
531where
532    T: MaskElement,
533{
534    type Output = Mask<T, N>;
535    #[inline]
536    fn bitor(self, rhs: Mask<T, N>) -> Mask<T, N> {
537        Mask::splat(self) | rhs
538    }
539}
540
541impl<T, const N: usize> core::ops::BitXor for Mask<T, N>
542where
543    T: MaskElement,
544{
545    type Output = Self;
546    #[inline]
547    fn bitxor(self, rhs: Self) -> Self::Output {
548        // Safety: `self` is an integer vector
549        unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) }
550    }
551}
552
553impl<T, const N: usize> core::ops::BitXor<bool> for Mask<T, N>
554where
555    T: MaskElement,
556{
557    type Output = Self;
558    #[inline]
559    fn bitxor(self, rhs: bool) -> Self::Output {
560        self ^ Self::splat(rhs)
561    }
562}
563
564impl<T, const N: usize> core::ops::BitXor<Mask<T, N>> for bool
565where
566    T: MaskElement,
567{
568    type Output = Mask<T, N>;
569    #[inline]
570    fn bitxor(self, rhs: Mask<T, N>) -> Self::Output {
571        Mask::splat(self) ^ rhs
572    }
573}
574
575impl<T, const N: usize> core::ops::Not for Mask<T, N>
576where
577    T: MaskElement,
578{
579    type Output = Mask<T, N>;
580    #[inline]
581    fn not(self) -> Self::Output {
582        Self::splat(true) ^ self
583    }
584}
585
586impl<T, const N: usize> core::ops::BitAndAssign for Mask<T, N>
587where
588    T: MaskElement,
589{
590    #[inline]
591    fn bitand_assign(&mut self, rhs: Self) {
592        *self = *self & rhs;
593    }
594}
595
596impl<T, const N: usize> core::ops::BitAndAssign<bool> for Mask<T, N>
597where
598    T: MaskElement,
599{
600    #[inline]
601    fn bitand_assign(&mut self, rhs: bool) {
602        *self &= Self::splat(rhs);
603    }
604}
605
606impl<T, const N: usize> core::ops::BitOrAssign for Mask<T, N>
607where
608    T: MaskElement,
609{
610    #[inline]
611    fn bitor_assign(&mut self, rhs: Self) {
612        *self = *self | rhs;
613    }
614}
615
616impl<T, const N: usize> core::ops::BitOrAssign<bool> for Mask<T, N>
617where
618    T: MaskElement,
619{
620    #[inline]
621    fn bitor_assign(&mut self, rhs: bool) {
622        *self |= Self::splat(rhs);
623    }
624}
625
626impl<T, const N: usize> core::ops::BitXorAssign for Mask<T, N>
627where
628    T: MaskElement,
629{
630    #[inline]
631    fn bitxor_assign(&mut self, rhs: Self) {
632        *self = *self ^ rhs;
633    }
634}
635
636impl<T, const N: usize> core::ops::BitXorAssign<bool> for Mask<T, N>
637where
638    T: MaskElement,
639{
640    #[inline]
641    fn bitxor_assign(&mut self, rhs: bool) {
642        *self ^= Self::splat(rhs);
643    }
644}
645
646macro_rules! impl_from {
647    { $from:ty  => $($to:ty),* } => {
648        $(
649        impl<const N: usize> From<Mask<$from, N>> for Mask<$to, N>
650        {
651            #[inline]
652            fn from(value: Mask<$from, N>) -> Self {
653                value.cast()
654            }
655        }
656        )*
657    }
658}
659impl_from! { i8 => i16, i32, i64, isize }
660impl_from! { i16 => i32, i64, isize, i8 }
661impl_from! { i32 => i64, isize, i8, i16 }
662impl_from! { i64 => isize, i8, i16, i32 }
663impl_from! { isize => i8, i16, i32, i64 }