Skip to main content

core_simd/
masks.rs

//! Types and traits associated with masking elements of vectors.
//! Types representing masks store every element as either `0` (`false`) or `-1`, i.e. all
//! bits set (`true`).
3#![allow(non_camel_case_types)]
4
5use crate::simd::{Select, Simd, SimdCast, SimdElement};
6use core::cmp::Ordering;
7use core::{fmt, mem};
8
/// Normalizes the bit order of an integer produced by the `simd_bitmask` intrinsic.
///
/// On big-endian targets the bits are mirrored; on every other target the value is returned
/// untouched.
pub(crate) trait FixEndianness {
    /// Returns `self` with its bits reversed on big-endian targets, unchanged otherwise.
    fn fix_endianness(self) -> Self;
}

// Implements `FixEndianness` for each listed unsigned integer type.
macro_rules! impl_fix_endianness {
    { $($int:ty),* } => {
        $(
            impl FixEndianness for $int {
                #[inline(always)]
                fn fix_endianness(self) -> Self {
                    // LLVM lays out bitmask bits according to the target's endianness, so only
                    // big-endian targets need the bits mirrored.
                    let big_endian = cfg!(target_endian = "big");
                    if !big_endian {
                        return self;
                    }
                    self.reverse_bits()
                }
            }
        )*
    }
}

impl_fix_endianness! { u8, u16, u32, u64 }
31
32mod sealed {
33    use super::*;
34
35    /// Not only does this seal the `MaskElement` trait, but these functions prevent other traits
36    /// from bleeding into the parent bounds.
37    ///
38    /// For example, `eq` could be provided by requiring `MaskElement: PartialEq`, but that would
39    /// prevent us from ever removing that bound, or from implementing `MaskElement` on
40    /// non-`PartialEq` types in the future.
41    pub trait Sealed {
42        fn valid<const N: usize>(values: Simd<Self, N>) -> bool
43        where
44            Self: SimdElement;
45
46        fn eq(self, other: Self) -> bool;
47
48        fn to_usize(self) -> usize;
49        fn max_unsigned() -> u64;
50
51        type Unsigned: SimdElement;
52
53        const TRUE: Self;
54
55        const FALSE: Self;
56    }
57}
58use sealed::Sealed;
59
60/// Marker trait for types that may be used as SIMD mask elements.
61///
62/// # Safety
63/// Type must be a signed integer.
64pub unsafe trait MaskElement: SimdElement<Mask = Self> + SimdCast + Sealed {}
65
66macro_rules! impl_element {
67    { $ty:ty, $unsigned:ty } => {
68        impl Sealed for $ty {
69            #[inline]
70            fn valid<const N: usize>(value: Simd<Self, N>) -> bool
71            {
72                // We can't use `Simd` directly, because `Simd`'s functions call this function and
73                // we will end up with an infinite loop.
74                // Safety: `value` is an integer vector
75                unsafe {
76                    use core::intrinsics::simd;
77                    let falses: Simd<Self, N> = simd::simd_eq(value, Simd::splat(0 as _));
78                    let trues: Simd<Self, N> = simd::simd_eq(value, Simd::splat(-1 as _));
79                    let valid: Simd<Self, N> = simd::simd_or(falses, trues);
80                    simd::simd_reduce_all(valid)
81                }
82            }
83
84            #[inline]
85            fn eq(self, other: Self) -> bool { self == other }
86
87            #[inline]
88            fn to_usize(self) -> usize {
89                self as usize
90            }
91
92            #[inline]
93            fn max_unsigned() -> u64 {
94                <$unsigned>::MAX as u64
95            }
96
97            type Unsigned = $unsigned;
98
99            const TRUE: Self = -1;
100            const FALSE: Self = 0;
101        }
102
103        // Safety: this is a valid mask element type
104        unsafe impl MaskElement for $ty {}
105    }
106}
107
108impl_element! { i8, u8 }
109impl_element! { i16, u16 }
110impl_element! { i32, u32 }
111impl_element! { i64, u64 }
112impl_element! { isize, usize }
113
114/// A SIMD vector mask for `N` elements of width specified by `Element`.
115///
116/// Masks represent boolean inclusion/exclusion on a per-element basis.
117///
118/// The layout of this type is unspecified, and may change between platforms
119/// and/or Rust versions, and code should not assume that it is equivalent to
120/// `[T; N]`.
121///
122/// `N` cannot be 0 and may be at most 64. This limit may be increased in
123/// the future.
124#[repr(transparent)]
125pub struct Mask<T, const N: usize>(Simd<T, N>)
126where
127    T: MaskElement;
128
129impl<T, const N: usize> Copy for Mask<T, N> where T: MaskElement {}
130
131impl<T, const N: usize> Clone for Mask<T, N>
132where
133    T: MaskElement,
134{
135    #[inline]
136    fn clone(&self) -> Self {
137        *self
138    }
139}
140
impl<T, const N: usize> Mask<T, N>
where
    T: MaskElement,
{
    /// Constructs a mask by setting all elements to the given value.
    ///
    /// Every element is stored as `T::TRUE` when `value` is `true` and as `T::FALSE` otherwise.
    #[inline]
    #[rustc_const_unstable(feature = "portable_simd", issue = "86656")]
    pub const fn splat(value: bool) -> Self {
        Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
    }

    /// Converts an array of bools to a SIMD mask.
    ///
    /// Element `i` of the returned mask is set exactly when `array[i]` is `true`.
    #[inline]
    pub fn from_array(array: [bool; N]) -> Self {
        // SAFETY: Rust's bool has a layout of 1 byte (u8) with a value of
        //     true:    0b_0000_0001
        //     false:   0b_0000_0000
        // Thus, an array of bools is also a valid array of bytes: [u8; N]
        // This would be hypothetically valid as an "in-place" transmute,
        // but these are "dependently-sized" types, so copy elision it is!
        unsafe {
            let bytes: [u8; N] = mem::transmute_copy(&array);
            // Lane-wise `byte != 0`. The comparison result is used directly as an `i8` mask
            // (0 / -1 per lane), which is the form `from_simd_unchecked` below requires.
            let bools: Simd<i8, N> =
                core::intrinsics::simd::simd_ne(Simd::from_array(bytes), Simd::splat(0u8));
            // Convert the `i8` lanes to `T`; the integer cast keeps 0 as 0 and -1 as -1.
            Mask::from_simd_unchecked(core::intrinsics::simd::simd_cast(bools))
        }
    }

    /// Converts a SIMD mask to an array of bools.
    ///
    /// Element `i` of the returned array is `true` exactly when element `i` of the mask is set.
    #[inline]
    pub fn to_array(self) -> [bool; N] {
        // This follows mostly the same logic as from_array.
        // SAFETY: Rust's bool has a layout of 1 byte (u8) with a value of
        //     true:    0b_0000_0001
        //     false:   0b_0000_0000
        // Thus, an array of bools is also a valid array of bytes: [u8; N]
        // Since our masks are equal to integers where all bits are set,
        // we can simply convert them to i8s, and then bitand them by the
        // bitpattern for Rust's "true" bool.
        // This would be hypothetically valid as an "in-place" transmute,
        // but these are "dependently-sized" types, so copy elision it is!
        unsafe {
            let mut bytes: Simd<i8, N> = core::intrinsics::simd::simd_cast(self.to_simd());
            bytes &= Simd::splat(1i8);
            mem::transmute_copy(&bytes)
        }
    }

    /// Converts a vector of integers to a mask, where 0 represents `false` and -1
    /// represents `true`.
    ///
    /// # Safety
    /// All elements must be either 0 or -1.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub unsafe fn from_simd_unchecked(value: Simd<T, N>) -> Self {
        // Safety: the caller must confirm this invariant
        // `assume` performs no runtime check; it only lets the optimizer rely on every
        // element being 0 or -1, so violating the invariant is undefined behavior.
        unsafe {
            core::intrinsics::assume(<T as Sealed>::valid(value));
        }
        Self(value)
    }

    /// Converts a vector of integers to a mask, where 0 represents `false` and -1
    /// represents `true`.
    ///
    /// # Panics
    /// Panics if any element is not 0 or -1.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    #[track_caller]
    pub fn from_simd(value: Simd<T, N>) -> Self {
        assert!(T::valid(value), "all values must be either 0 or -1",);
        // Safety: the validity has been checked
        unsafe { Self::from_simd_unchecked(value) }
    }

    /// Converts the mask to a vector of integers, where 0 represents `false` and -1
    /// represents `true`.
    ///
    /// The underlying vector is returned as-is; no conversion is performed.
    #[inline]
    #[must_use = "method returns a new vector and does not mutate the original value"]
    pub fn to_simd(self) -> Simd<T, N> {
        self.0
    }

    /// Converts the mask to a mask of any other element size.
    ///
    /// Each element keeps its truth value.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn cast<U: MaskElement>(self) -> Mask<U, N> {
        // Safety: mask elements are integers
        // `simd_as` converts lane by lane like an `as` cast between signed integers, which
        // maps 0 to 0 and -1 to -1 at any width, so the result is still a valid mask.
        unsafe { Mask(core::intrinsics::simd::simd_as(self.0)) }
    }

    /// Tests the value of the specified element.
    ///
    /// An element counts as set when it equals `T::TRUE`.
    ///
    /// # Safety
    /// `index` must be less than `self.len()`.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub unsafe fn test_unchecked(&self, index: usize) -> bool {
        // Safety: the caller must confirm this invariant
        unsafe { T::eq(*self.0.as_array().get_unchecked(index), T::TRUE) }
    }

    /// Tests the value of the specified element.
    ///
    /// An element counts as set when it equals `T::TRUE`.
    ///
    /// # Panics
    /// Panics if `index` is greater than or equal to the number of elements in the vector.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    #[track_caller]
    pub fn test(&self, index: usize) -> bool {
        T::eq(self.0[index], T::TRUE)
    }

    /// Sets the value of the specified element.
    ///
    /// # Safety
    /// `index` must be less than `self.len()`.
    #[inline]
    pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) {
        // Safety: the caller must confirm this invariant
        unsafe {
            *self.0.as_mut_array().get_unchecked_mut(index) = if value { T::TRUE } else { T::FALSE }
        }
    }

    /// Sets the value of the specified element.
    ///
    /// # Panics
    /// Panics if `index` is greater than or equal to the number of elements in the vector.
    #[inline]
    #[track_caller]
    pub fn set(&mut self, index: usize, value: bool) {
        self.0[index] = if value { T::TRUE } else { T::FALSE }
    }

    /// Returns true if any element is set, or false otherwise.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn any(self) -> bool {
        // Safety: `self` is a mask vector
        unsafe { core::intrinsics::simd::simd_reduce_any(self.0) }
    }

    /// Returns true if all elements are set, or false otherwise.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn all(self) -> bool {
        // Safety: `self` is a mask vector
        unsafe { core::intrinsics::simd::simd_reduce_all(self.0) }
    }

    /// Creates a bitmask from a mask.
    ///
    /// Each bit is set if the corresponding element in the mask is `true`.
    /// Bit `i` of the result corresponds to element `i`; bits at positions `N` and above
    /// are always zero.
    ///
    /// A mask type with more than 64 elements fails to compile when calling this method.
    #[inline]
    #[must_use = "method returns a new integer and does not mutate the original value"]
    pub fn to_bitmask(self) -> u64 {
        const {
            assert!(N <= 64, "number of elements can't be greater than 64");
        }

        /// Computes the bitmask of `mask` as an integer `U`, after padding the mask to `M`
        /// elements with `false`.
        ///
        /// # Safety
        /// `U` must be the unsigned integer type with exactly `M` bits. The callers below
        /// pick the smallest `M` in {8, 16, 32, 64} with `N <= M`.
        #[inline]
        unsafe fn to_bitmask_impl<T, U: FixEndianness, const M: usize, const N: usize>(
            mask: Mask<T, N>,
        ) -> U
        where
            T: MaskElement,
        {
            // Padding lanes are `false`, so the bits above position `N` come out as zero.
            let resized = mask.resize::<M>(false);

            // Safety: `resized` is an integer vector with length M, which must match the bit
            // width of the output integer `U`
            let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized.0) };

            // LLVM assumes bit order should match endianness
            bitmask.fix_endianness()
        }

        // TODO modify simd_bitmask to zero-extend output, making this unnecessary
        if N <= 8 {
            // Safety: bitmask matches length
            unsafe { to_bitmask_impl::<T, u8, 8, N>(self) as u64 }
        } else if N <= 16 {
            // Safety: bitmask matches length
            unsafe { to_bitmask_impl::<T, u16, 16, N>(self) as u64 }
        } else if N <= 32 {
            // Safety: bitmask matches length
            unsafe { to_bitmask_impl::<T, u32, 32, N>(self) as u64 }
        } else {
            // Safety: bitmask matches length
            unsafe { to_bitmask_impl::<T, u64, 64, N>(self) }
        }
    }

    /// Creates a mask from a bitmask.
    ///
    /// For each bit, if it is set, the corresponding element in the mask is set to `true`.
    /// If the mask contains more than 64 elements, the remainder are set to `false`.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn from_bitmask(bitmask: u64) -> Self {
        Self(bitmask.select(Simd::splat(T::TRUE), Simd::splat(T::FALSE)))
    }

    /// Finds the index of the first set element.
    ///
    /// Returns `None` if no element is set.
    ///
    /// ```
    /// # #![feature(portable_simd)]
    /// # #[cfg(feature = "as_crate")] use core_simd::simd;
    /// # #[cfg(not(feature = "as_crate"))] use core::simd;
    /// # use simd::mask32x8;
    /// assert_eq!(mask32x8::splat(false).first_set(), None);
    /// assert_eq!(mask32x8::splat(true).first_set(), Some(0));
    ///
    /// let mask = mask32x8::from_array([false, true, false, false, true, false, false, true]);
    /// assert_eq!(mask.first_set(), Some(1));
    /// ```
    #[inline]
    #[must_use = "method returns the index and does not mutate the original value"]
    pub fn first_set(self) -> Option<usize> {
        // If bitmasks are efficient, using them is better
        // Bit `i` of the bitmask is element `i`. `trailing_zeros` returns 64 only for an
        // all-zero bitmask, i.e. when no element is set.
        if cfg!(target_feature = "sse") && N <= 64 {
            let tz = self.to_bitmask().trailing_zeros();
            return if tz == 64 { None } else { Some(tz as usize) };
        }

        // To find the first set index:
        // * create a vector 0..N
        // * replace unset mask elements in that vector with -1
        // * perform _unsigned_ reduce-min
        // * check if the result is -1 or an index

        // The index vector [0, 1, ..., N - 1] is built in a const block, so it is evaluated
        // at compile time.
        let index: Simd<T, N> = const {
            let mut index = [0; N];
            let mut i = 0;
            while i < N {
                index[i] = i;
                i += 1;
            }
            // Safety: the input and output are integer vectors
            unsafe { core::intrinsics::simd::simd_cast(Simd::from_array(index)) }
        };

        // `!self` is -1 (all bits set) in unset lanes and 0 in set lanes, so OR-ing leaves
        // set lanes holding their index and forces unset lanes to -1.
        // Safety: the input and output are integer vectors
        let masked_index: Simd<T, N> =
            unsafe { core::intrinsics::simd::simd_or((!self).to_simd(), index) };

        // Reinterpreted as unsigned, -1 is the largest value, so the minimum below picks the
        // smallest index of a set lane when one exists.
        // Safety: the input and output are integer vectors
        let masked_index: Simd<T::Unsigned, N> =
            unsafe { core::intrinsics::simd::simd_cast(masked_index) };

        // Safety: the input is an integer vector
        let min_index: T::Unsigned =
            unsafe { core::intrinsics::simd::simd_reduce_min(masked_index) };

        // Safety: the return value is the unsigned version of T
        let min_index: T = unsafe { core::mem::transmute_copy(&min_index) };

        // A minimum of -1 (`T::TRUE`) means every lane was unset.
        if min_index.eq(T::TRUE) {
            None
        } else {
            Some(min_index.to_usize())
        }
    }
}
407
408// vector/array conversion
409impl<T, const N: usize> From<[bool; N]> for Mask<T, N>
410where
411    T: MaskElement,
412{
413    #[inline]
414    fn from(array: [bool; N]) -> Self {
415        Self::from_array(array)
416    }
417}
418
419impl<T, const N: usize> From<Mask<T, N>> for [bool; N]
420where
421    T: MaskElement,
422{
423    #[inline]
424    fn from(vector: Mask<T, N>) -> Self {
425        vector.to_array()
426    }
427}
428
429impl<T, const N: usize> Default for Mask<T, N>
430where
431    T: MaskElement,
432{
433    #[inline]
434    fn default() -> Self {
435        Self::splat(false)
436    }
437}
438
439impl<T, const N: usize> PartialEq for Mask<T, N>
440where
441    T: MaskElement + PartialEq,
442{
443    #[inline]
444    fn eq(&self, other: &Self) -> bool {
445        self.0 == other.0
446    }
447}
448
449impl<T, const N: usize> PartialOrd for Mask<T, N>
450where
451    T: MaskElement + PartialOrd,
452{
453    #[inline]
454    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
455        self.0.partial_cmp(&other.0)
456    }
457}
458
459impl<T, const N: usize> fmt::Debug for Mask<T, N>
460where
461    T: MaskElement + fmt::Debug,
462{
463    #[inline]
464    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
465        f.debug_list()
466            .entries((0..N).map(|i| self.test(i)))
467            .finish()
468    }
469}
470
471impl<T, const N: usize> core::ops::BitAnd for Mask<T, N>
472where
473    T: MaskElement,
474{
475    type Output = Self;
476    #[inline]
477    fn bitand(self, rhs: Self) -> Self {
478        // Safety: `self` is an integer vector
479        unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) }
480    }
481}
482
483impl<T, const N: usize> core::ops::BitAnd<bool> for Mask<T, N>
484where
485    T: MaskElement,
486{
487    type Output = Self;
488    #[inline]
489    fn bitand(self, rhs: bool) -> Self {
490        self & Self::splat(rhs)
491    }
492}
493
494impl<T, const N: usize> core::ops::BitAnd<Mask<T, N>> for bool
495where
496    T: MaskElement,
497{
498    type Output = Mask<T, N>;
499    #[inline]
500    fn bitand(self, rhs: Mask<T, N>) -> Mask<T, N> {
501        Mask::splat(self) & rhs
502    }
503}
504
505impl<T, const N: usize> core::ops::BitOr for Mask<T, N>
506where
507    T: MaskElement,
508{
509    type Output = Self;
510    #[inline]
511    fn bitor(self, rhs: Self) -> Self {
512        // Safety: `self` is an integer vector
513        unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) }
514    }
515}
516
517impl<T, const N: usize> core::ops::BitOr<bool> for Mask<T, N>
518where
519    T: MaskElement,
520{
521    type Output = Self;
522    #[inline]
523    fn bitor(self, rhs: bool) -> Self {
524        self | Self::splat(rhs)
525    }
526}
527
528impl<T, const N: usize> core::ops::BitOr<Mask<T, N>> for bool
529where
530    T: MaskElement,
531{
532    type Output = Mask<T, N>;
533    #[inline]
534    fn bitor(self, rhs: Mask<T, N>) -> Mask<T, N> {
535        Mask::splat(self) | rhs
536    }
537}
538
539impl<T, const N: usize> core::ops::BitXor for Mask<T, N>
540where
541    T: MaskElement,
542{
543    type Output = Self;
544    #[inline]
545    fn bitxor(self, rhs: Self) -> Self::Output {
546        // Safety: `self` is an integer vector
547        unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) }
548    }
549}
550
551impl<T, const N: usize> core::ops::BitXor<bool> for Mask<T, N>
552where
553    T: MaskElement,
554{
555    type Output = Self;
556    #[inline]
557    fn bitxor(self, rhs: bool) -> Self::Output {
558        self ^ Self::splat(rhs)
559    }
560}
561
562impl<T, const N: usize> core::ops::BitXor<Mask<T, N>> for bool
563where
564    T: MaskElement,
565{
566    type Output = Mask<T, N>;
567    #[inline]
568    fn bitxor(self, rhs: Mask<T, N>) -> Self::Output {
569        Mask::splat(self) ^ rhs
570    }
571}
572
573impl<T, const N: usize> core::ops::Not for Mask<T, N>
574where
575    T: MaskElement,
576{
577    type Output = Mask<T, N>;
578    #[inline]
579    fn not(self) -> Self::Output {
580        Self::splat(true) ^ self
581    }
582}
583
584impl<T, const N: usize> core::ops::BitAndAssign for Mask<T, N>
585where
586    T: MaskElement,
587{
588    #[inline]
589    fn bitand_assign(&mut self, rhs: Self) {
590        *self = *self & rhs;
591    }
592}
593
594impl<T, const N: usize> core::ops::BitAndAssign<bool> for Mask<T, N>
595where
596    T: MaskElement,
597{
598    #[inline]
599    fn bitand_assign(&mut self, rhs: bool) {
600        *self &= Self::splat(rhs);
601    }
602}
603
604impl<T, const N: usize> core::ops::BitOrAssign for Mask<T, N>
605where
606    T: MaskElement,
607{
608    #[inline]
609    fn bitor_assign(&mut self, rhs: Self) {
610        *self = *self | rhs;
611    }
612}
613
614impl<T, const N: usize> core::ops::BitOrAssign<bool> for Mask<T, N>
615where
616    T: MaskElement,
617{
618    #[inline]
619    fn bitor_assign(&mut self, rhs: bool) {
620        *self |= Self::splat(rhs);
621    }
622}
623
624impl<T, const N: usize> core::ops::BitXorAssign for Mask<T, N>
625where
626    T: MaskElement,
627{
628    #[inline]
629    fn bitxor_assign(&mut self, rhs: Self) {
630        *self = *self ^ rhs;
631    }
632}
633
634impl<T, const N: usize> core::ops::BitXorAssign<bool> for Mask<T, N>
635where
636    T: MaskElement,
637{
638    #[inline]
639    fn bitxor_assign(&mut self, rhs: bool) {
640        *self ^= Self::splat(rhs);
641    }
642}
643
644macro_rules! impl_from {
645    { $from:ty  => $($to:ty),* } => {
646        $(
647        impl<const N: usize> From<Mask<$from, N>> for Mask<$to, N>
648        {
649            #[inline]
650            fn from(value: Mask<$from, N>) -> Self {
651                value.cast()
652            }
653        }
654        )*
655    }
656}
657impl_from! { i8 => i16, i32, i64, isize }
658impl_from! { i16 => i32, i64, isize, i8 }
659impl_from! { i32 => i64, isize, i8, i16 }
660impl_from! { i64 => isize, i8, i16, i32 }
661impl_from! { isize => i8, i16, i32, i64 }