Skip to main content

core_simd/
masks.rs

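//! Types and traits associated with masking elements of vectors.
//! Types representing per-element boolean selections over SIMD vectors ([`Mask`]),
//! and the integer element types that back them ([`MaskElement`]).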
3#![allow(non_camel_case_types)]
4
5use crate::simd::{Select, Simd, SimdCast, SimdElement};
6use core::cmp::Ordering;
7use core::{fmt, mem};
8
/// Adjusts the bit order of an integer bitmask for the target's endianness.
///
/// Used on the raw output of `simd_bitmask`: on little-endian targets the value
/// is returned unchanged, on big-endian targets its bits are reversed.
pub(crate) trait FixEndianness {
    /// Returns `self` with its bit order normalized for the current target.
    fn fix_endianness(self) -> Self;
}

macro_rules! impl_fix_endianness {
    { $($int:ty),* } => {
        $(
            impl FixEndianness for $int {
                #[inline(always)]
                fn fix_endianness(self) -> Self {
                    // `target_endian` is either "little" or "big"; only big-endian
                    // targets need the bit order flipped.
                    if cfg!(target_endian = "little") {
                        self
                    } else {
                        self.reverse_bits()
                    }
                }
            }
        )*
    }
}

impl_fix_endianness! { u8, u16, u32, u64 }
31
32mod private_methods {
33    use super::*;
34
35    /// Not only does this seal the `MaskElement` trait, but these functions prevent other traits
36    /// from bleeding into the parent bounds.
37    ///
38    /// For example, `eq` could be provided by requiring `MaskElement: PartialEq`, but that would
39    /// prevent us from ever removing that bound, or from implementing `MaskElement` on
40    /// non-`PartialEq` types in the future.
41    pub impl(super) trait PrivateMethods {
42        fn valid<const N: usize>(values: Simd<Self, N>) -> bool
43        where
44            Self: SimdElement;
45
46        fn eq(self, other: Self) -> bool;
47
48        fn to_usize(self) -> usize;
49        fn max_unsigned() -> u64;
50
51        type Unsigned: SimdElement;
52
53        const TRUE: Self;
54
55        const FALSE: Self;
56    }
57}
58use private_methods::PrivateMethods;
59
60/// Marker trait for types that may be used as SIMD mask elements.
61///
62/// # Safety
63/// Type must be a signed integer.
64pub impl(self) unsafe trait MaskElement:
65    SimdElement<Mask = Self> + SimdCast + PrivateMethods
66{
67}
68
69macro_rules! impl_element {
70    { $ty:ty, $unsigned:ty } => {
71        impl PrivateMethods for $ty {
72            #[inline]
73            fn valid<const N: usize>(value: Simd<Self, N>) -> bool
74            {
75                // We can't use `Simd` directly, because `Simd`'s functions call this function and
76                // we will end up with an infinite loop.
77                // Safety: `value` is an integer vector
78                unsafe {
79                    use core::intrinsics::simd;
80                    let falses: Simd<Self, N> = simd::simd_eq(value, Simd::splat(0 as _));
81                    let trues: Simd<Self, N> = simd::simd_eq(value, Simd::splat(-1 as _));
82                    let valid: Simd<Self, N> = simd::simd_or(falses, trues);
83                    simd::simd_reduce_all(valid)
84                }
85            }
86
87            #[inline]
88            fn eq(self, other: Self) -> bool { self == other }
89
90            #[inline]
91            fn to_usize(self) -> usize {
92                self as usize
93            }
94
95            #[inline]
96            fn max_unsigned() -> u64 {
97                <$unsigned>::MAX as u64
98            }
99
100            type Unsigned = $unsigned;
101
102            const TRUE: Self = -1;
103            const FALSE: Self = 0;
104        }
105
106        // Safety: this is a valid mask element type
107        unsafe impl MaskElement for $ty {}
108    }
109}
110
111impl_element! { i8, u8 }
112impl_element! { i16, u16 }
113impl_element! { i32, u32 }
114impl_element! { i64, u64 }
115impl_element! { isize, usize }
116
/// A SIMD vector mask for `N` elements of width specified by `T`.
///
/// Masks represent boolean inclusion/exclusion on a per-element basis.
///
/// The layout of this type is unspecified, and may change between platforms
/// and/or Rust versions, and code should not assume that it is equivalent to
/// `[T; N]`.
///
/// `N` cannot be 0 and may be at most 64. This limit may be increased in
/// the future.
// Invariant: every element of the wrapped vector is either `T::FALSE` (0) or
// `T::TRUE` (-1). `from_simd` checks this and `from_simd_unchecked` requires it.
#[repr(transparent)]
pub struct Mask<T, const N: usize>(Simd<T, N>)
where
    T: MaskElement;
131
132impl<T, const N: usize> Copy for Mask<T, N> where T: MaskElement {}
133
134impl<T, const N: usize> Clone for Mask<T, N>
135where
136    T: MaskElement,
137{
138    #[inline]
139    fn clone(&self) -> Self {
140        *self
141    }
142}
143
impl<T, const N: usize> Mask<T, N>
where
    T: MaskElement,
{
    /// Constructs a mask by setting all elements to the given value.
    ///
    /// `true` is stored as `T::TRUE` (-1) and `false` as `T::FALSE` (0).
    #[inline]
    #[rustc_const_unstable(feature = "portable_simd", issue = "86656")]
    pub const fn splat(value: bool) -> Self {
        Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
    }

    /// Converts an array of bools to a SIMD mask.
    ///
    /// Element `i` of the mask is set exactly when `array[i]` is `true`.
    #[inline]
    pub fn from_array(array: [bool; N]) -> Self {
        // SAFETY: Rust's bool has a layout of 1 byte (u8) with a value of
        //     true:    0b_0000_0001
        //     false:   0b_0000_0000
        // Thus, an array of bools is also a valid array of bytes: [u8; N]
        // This would be hypothetically valid as an "in-place" transmute,
        // but these are "dependently-sized" types, so copy elision it is!
        unsafe {
            let bytes: [u8; N] = mem::transmute_copy(&array);
            // Lane-wise `!= 0` yields an i8 mask vector: -1 for `true` bytes, 0 for `false`.
            let bools: Simd<i8, N> =
                core::intrinsics::simd::simd_ne(Simd::from_array(bytes), Simd::splat(0u8));
            // Convert the i8 lanes to `T` element-wise; the result only holds 0 or -1,
            // which is the invariant `from_simd_unchecked` requires.
            Mask::from_simd_unchecked(core::intrinsics::simd::simd_cast(bools))
        }
    }

    /// Converts a SIMD mask to an array of bools.
    ///
    /// `array[i]` is `true` exactly when element `i` of the mask is set.
    #[inline]
    pub fn to_array(self) -> [bool; N] {
        // This follows mostly the same logic as from_array.
        // SAFETY: Rust's bool has a layout of 1 byte (u8) with a value of
        //     true:    0b_0000_0001
        //     false:   0b_0000_0000
        // Thus, an array of bools is also a valid array of bytes: [u8; N]
        // Since our masks are equal to integers where all bits are set,
        // we can simply convert them to i8s, and then bitand them by the
        // bitpattern for Rust's "true" bool.
        // This would be hypothetically valid as an "in-place" transmute,
        // but these are "dependently-sized" types, so copy elision it is!
        unsafe {
            let mut bytes: Simd<i8, N> = core::intrinsics::simd::simd_cast(self.to_simd());
            // -1 & 1 == 1 (`true`), 0 & 1 == 0 (`false`).
            bytes &= Simd::splat(1i8);
            mem::transmute_copy(&bytes)
        }
    }

    /// Converts a vector of integers to a mask, where 0 represents `false` and -1
    /// represents `true`.
    ///
    /// # Safety
    /// All elements must be either 0 or -1.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub unsafe fn from_simd_unchecked(value: Simd<T, N>) -> Self {
        // Safety: the caller must confirm this invariant
        // `assume` tells the optimizer the invariant holds; violating it is undefined behavior.
        unsafe {
            core::intrinsics::assume(<T as PrivateMethods>::valid(value));
        }
        Self(value)
    }

    /// Converts a vector of integers to a mask, where 0 represents `false` and -1
    /// represents `true`.
    ///
    /// # Panics
    /// Panics if any element is not 0 or -1.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    #[track_caller]
    pub fn from_simd(value: Simd<T, N>) -> Self {
        assert!(T::valid(value), "all values must be either 0 or -1",);
        // Safety: the validity has been checked
        unsafe { Self::from_simd_unchecked(value) }
    }

    /// Converts the mask to a vector of integers, where 0 represents `false` and -1
    /// represents `true`.
    #[inline]
    #[must_use = "method returns a new vector and does not mutate the original value"]
    pub fn to_simd(self) -> Simd<T, N> {
        self.0
    }

    /// Converts the mask to a mask of any other element size.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn cast<U: MaskElement>(self) -> Mask<U, N> {
        // Safety: mask elements are integers
        // `simd_as` converts each element as `as` would, so 0 and -1 keep their values
        // in the new element width.
        unsafe { Mask(core::intrinsics::simd::simd_as(self.0)) }
    }

    /// Tests the value of the specified element.
    ///
    /// # Safety
    /// `index` must be less than `self.len()`.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub unsafe fn test_unchecked(&self, index: usize) -> bool {
        // Safety: the caller must confirm this invariant
        unsafe { T::eq(*self.0.as_array().get_unchecked(index), T::TRUE) }
    }

    /// Tests the value of the specified element.
    ///
    /// # Panics
    /// Panics if `index` is greater than or equal to the number of elements in the vector.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    #[track_caller]
    pub fn test(&self, index: usize) -> bool {
        T::eq(self.0[index], T::TRUE)
    }

    /// Sets the value of the specified element.
    ///
    /// # Safety
    /// `index` must be less than `self.len()`.
    #[inline]
    pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) {
        // Safety: the caller must confirm this invariant
        unsafe {
            *self.0.as_mut_array().get_unchecked_mut(index) = if value { T::TRUE } else { T::FALSE }
        }
    }

    /// Sets the value of the specified element.
    ///
    /// # Panics
    /// Panics if `index` is greater than or equal to the number of elements in the vector.
    #[inline]
    #[track_caller]
    pub fn set(&mut self, index: usize, value: bool) {
        self.0[index] = if value { T::TRUE } else { T::FALSE }
    }

    /// Returns true if any element is set, or false otherwise.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn any(self) -> bool {
        // Safety: `self` is a mask vector
        unsafe { core::intrinsics::simd::simd_reduce_any(self.0) }
    }

    /// Returns true if all elements are set, or false otherwise.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn all(self) -> bool {
        // Safety: `self` is a mask vector
        unsafe { core::intrinsics::simd::simd_reduce_all(self.0) }
    }

    /// Creates a bitmask from a mask.
    ///
    /// Each bit is set if the corresponding element in the mask is `true`.
    /// Bit `i` (counting from the least significant bit) corresponds to element `i`.
    ///
    /// Using this with `N > 64` is rejected at compile time.
    #[inline]
    #[must_use = "method returns a new integer and does not mutate the original value"]
    pub fn to_bitmask(self) -> u64 {
        // Compile-time check: the result must fit in a `u64`.
        const {
            assert!(N <= 64, "number of elements can't be greater than 64");
        }

        // Pads the mask with `false` elements to `M` lanes, then packs one bit per
        // lane into `U` (an unsigned integer with exactly `M` bits).
        #[inline]
        unsafe fn to_bitmask_impl<T, U: FixEndianness, const M: usize, const N: usize>(
            mask: Mask<T, N>,
        ) -> U
        where
            T: MaskElement,
        {
            let resized = mask.resize::<M>(false);

            // Safety: `resized` is an integer vector with length M, which must match T
            let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized.0) };

            // LLVM assumes bit order should match endianness
            bitmask.fix_endianness()
        }

        // Pick the smallest integer width that holds `N` bits, then widen to `u64`.
        // TODO modify simd_bitmask to zero-extend output, making this unnecessary
        if N <= 8 {
            // Safety: bitmask matches length
            unsafe { to_bitmask_impl::<T, u8, 8, N>(self) as u64 }
        } else if N <= 16 {
            // Safety: bitmask matches length
            unsafe { to_bitmask_impl::<T, u16, 16, N>(self) as u64 }
        } else if N <= 32 {
            // Safety: bitmask matches length
            unsafe { to_bitmask_impl::<T, u32, 32, N>(self) as u64 }
        } else {
            // Safety: bitmask matches length
            unsafe { to_bitmask_impl::<T, u64, 64, N>(self) }
        }
    }

    /// Creates a mask from a bitmask.
    ///
    /// For each bit, if it is set, the corresponding element in the mask is set to `true`.
    /// If the mask contains more than 64 elements, the remainder are set to `false`.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn from_bitmask(bitmask: u64) -> Self {
        Self(bitmask.select(Simd::splat(T::TRUE), Simd::splat(T::FALSE)))
    }

    /// Finds the index of the first set element.
    ///
    /// ```
    /// # #![feature(portable_simd)]
    /// # #[cfg(feature = "as_crate")] use core_simd::simd;
    /// # #[cfg(not(feature = "as_crate"))] use core::simd;
    /// # use simd::mask32x8;
    /// assert_eq!(mask32x8::splat(false).first_set(), None);
    /// assert_eq!(mask32x8::splat(true).first_set(), Some(0));
    ///
    /// let mask = mask32x8::from_array([false, true, false, false, true, false, false, true]);
    /// assert_eq!(mask.first_set(), Some(1));
    /// ```
    #[inline]
    #[must_use = "method returns the index and does not mutate the original value"]
    pub fn first_set(self) -> Option<usize> {
        // If bitmasks are efficient, using them is better
        if cfg!(target_feature = "sse") && N <= 64 {
            return self.to_bitmask().lowest_one().map(|i| i as usize);
        }

        // To find the first set index:
        // * create a vector 0..N
        // * replace unset mask elements in that vector with -1
        // * perform _unsigned_ reduce-min
        // * check if the result is -1 or an index

        // `[0, 1, ..., N - 1]` as a vector of `T`, built at compile time.
        let index: Simd<T, N> = const {
            let mut index = [0; N];
            let mut i = 0;
            while i < N {
                index[i] = i;
                i += 1;
            }
            // Safety: the input and output are integer vectors
            unsafe { core::intrinsics::simd::simd_cast(Simd::from_array(index)) }
        };

        // `!self` is -1 in unset lanes and 0 in set lanes, so the OR keeps the index
        // in set lanes and forces unset lanes to -1.
        // Safety: the input and output are integer vectors
        let masked_index: Simd<T, N> =
            unsafe { core::intrinsics::simd::simd_or((!self).to_simd(), index) };

        // Viewed as unsigned, -1 is the largest value, so the minimum below is the
        // smallest set index, or -1 when no element is set.
        // Safety: the input and output are integer vectors
        let masked_index: Simd<T::Unsigned, N> =
            unsafe { core::intrinsics::simd::simd_cast(masked_index) };

        // Safety: the input is an integer vector
        let min_index: T::Unsigned =
            unsafe { core::intrinsics::simd::simd_reduce_min(masked_index) };

        // Safety: the return value is the unsigned version of T
        let min_index: T = unsafe { core::mem::transmute_copy(&min_index) };

        // -1 (`T::TRUE`) means every lane was unset.
        if min_index.eq(T::TRUE) {
            None
        } else {
            let min_index = min_index.to_usize();

            // Allow eliminating bounds checks when using the index
            // Safety: the index can't exceed the number of elements in the vector
            unsafe {
                core::hint::assert_unchecked(min_index < N);
            }

            Some(min_index)
        }
    }

    /// Finds the index of the last set element.
    ///
    /// ```
    /// # #![feature(portable_simd)]
    /// # #[cfg(feature = "as_crate")] use core_simd::simd;
    /// # #[cfg(not(feature = "as_crate"))] use core::simd;
    /// # use simd::mask32x8;
    /// assert_eq!(mask32x8::splat(false).last_set(), None);
    /// assert_eq!(mask32x8::splat(true).last_set(), Some(7));
    ///
    /// let mask = mask32x8::from_array([false, true, false, false, true, false, true, false]);
    /// assert_eq!(mask.last_set(), Some(6));
    /// ```
    #[inline]
    #[must_use = "method returns the index and does not mutate the original value"]
    pub fn last_set(self) -> Option<usize> {
        // If bitmasks are efficient, using them is better
        if cfg!(target_feature = "sse") && N <= 64 {
            return self.to_bitmask().highest_one().map(|i| i as usize);
        }

        // To find the last set index:
        // * create a vector 0..N
        // * replace unset mask elements in that vector with -1
        // * perform _signed_ reduce-max
        // * check if the result is -1 or an index

        // `[0, 1, ..., N - 1]` as a vector of `T`, built at compile time.
        let index: Simd<T, N> = const {
            let mut index = [0; N];
            let mut i = 0;
            while i < N {
                index[i] = i;
                i += 1;
            }
            // Safety: the input and output are integer vectors
            unsafe { core::intrinsics::simd::simd_cast(Simd::from_array(index)) }
        };

        // Set lanes hold their (non-negative) index, unset lanes hold -1.
        // Safety: the input and output are integer vectors
        let masked_index: Simd<T, N> =
            unsafe { core::intrinsics::simd::simd_or((!self).to_simd(), index) };

        // The signed maximum is the largest set index, or -1 when no element is set.
        // Safety: the input is an integer vector
        let max_index: T = unsafe { core::intrinsics::simd::simd_reduce_max(masked_index) };

        if max_index.eq(T::TRUE) {
            None
        } else {
            let max_index = max_index.to_usize();

            // Allow eliminating bounds checks when using the index
            // Safety: the index can't exceed the number of elements in the vector
            unsafe {
                core::hint::assert_unchecked(max_index < N);
            }

            Some(max_index)
        }
    }
}
477
478// vector/array conversion
479impl<T, const N: usize> From<[bool; N]> for Mask<T, N>
480where
481    T: MaskElement,
482{
483    #[inline]
484    fn from(array: [bool; N]) -> Self {
485        Self::from_array(array)
486    }
487}
488
489impl<T, const N: usize> From<Mask<T, N>> for [bool; N]
490where
491    T: MaskElement,
492{
493    #[inline]
494    fn from(vector: Mask<T, N>) -> Self {
495        vector.to_array()
496    }
497}
498
499impl<T, const N: usize> Default for Mask<T, N>
500where
501    T: MaskElement,
502{
503    #[inline]
504    fn default() -> Self {
505        Self::splat(false)
506    }
507}
508
509impl<T, const N: usize> PartialEq for Mask<T, N>
510where
511    T: MaskElement + PartialEq,
512{
513    #[inline]
514    fn eq(&self, other: &Self) -> bool {
515        self.0 == other.0
516    }
517}
518
519impl<T, const N: usize> PartialOrd for Mask<T, N>
520where
521    T: MaskElement + PartialOrd,
522{
523    #[inline]
524    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
525        self.0.partial_cmp(&other.0)
526    }
527}
528
529impl<T, const N: usize> fmt::Debug for Mask<T, N>
530where
531    T: MaskElement + fmt::Debug,
532{
533    #[inline]
534    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
535        f.debug_list()
536            .entries((0..N).map(|i| self.test(i)))
537            .finish()
538    }
539}
540
541impl<T, const N: usize> core::ops::BitAnd for Mask<T, N>
542where
543    T: MaskElement,
544{
545    type Output = Self;
546    #[inline]
547    fn bitand(self, rhs: Self) -> Self {
548        // Safety: `self` is an integer vector
549        unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) }
550    }
551}
552
553impl<T, const N: usize> core::ops::BitAnd<bool> for Mask<T, N>
554where
555    T: MaskElement,
556{
557    type Output = Self;
558    #[inline]
559    fn bitand(self, rhs: bool) -> Self {
560        self & Self::splat(rhs)
561    }
562}
563
564impl<T, const N: usize> core::ops::BitAnd<Mask<T, N>> for bool
565where
566    T: MaskElement,
567{
568    type Output = Mask<T, N>;
569    #[inline]
570    fn bitand(self, rhs: Mask<T, N>) -> Mask<T, N> {
571        Mask::splat(self) & rhs
572    }
573}
574
575impl<T, const N: usize> core::ops::BitOr for Mask<T, N>
576where
577    T: MaskElement,
578{
579    type Output = Self;
580    #[inline]
581    fn bitor(self, rhs: Self) -> Self {
582        // Safety: `self` is an integer vector
583        unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) }
584    }
585}
586
587impl<T, const N: usize> core::ops::BitOr<bool> for Mask<T, N>
588where
589    T: MaskElement,
590{
591    type Output = Self;
592    #[inline]
593    fn bitor(self, rhs: bool) -> Self {
594        self | Self::splat(rhs)
595    }
596}
597
598impl<T, const N: usize> core::ops::BitOr<Mask<T, N>> for bool
599where
600    T: MaskElement,
601{
602    type Output = Mask<T, N>;
603    #[inline]
604    fn bitor(self, rhs: Mask<T, N>) -> Mask<T, N> {
605        Mask::splat(self) | rhs
606    }
607}
608
609impl<T, const N: usize> core::ops::BitXor for Mask<T, N>
610where
611    T: MaskElement,
612{
613    type Output = Self;
614    #[inline]
615    fn bitxor(self, rhs: Self) -> Self::Output {
616        // Safety: `self` is an integer vector
617        unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) }
618    }
619}
620
621impl<T, const N: usize> core::ops::BitXor<bool> for Mask<T, N>
622where
623    T: MaskElement,
624{
625    type Output = Self;
626    #[inline]
627    fn bitxor(self, rhs: bool) -> Self::Output {
628        self ^ Self::splat(rhs)
629    }
630}
631
632impl<T, const N: usize> core::ops::BitXor<Mask<T, N>> for bool
633where
634    T: MaskElement,
635{
636    type Output = Mask<T, N>;
637    #[inline]
638    fn bitxor(self, rhs: Mask<T, N>) -> Self::Output {
639        Mask::splat(self) ^ rhs
640    }
641}
642
643impl<T, const N: usize> core::ops::Not for Mask<T, N>
644where
645    T: MaskElement,
646{
647    type Output = Mask<T, N>;
648    #[inline]
649    fn not(self) -> Self::Output {
650        Self::splat(true) ^ self
651    }
652}
653
654impl<T, const N: usize> core::ops::BitAndAssign for Mask<T, N>
655where
656    T: MaskElement,
657{
658    #[inline]
659    fn bitand_assign(&mut self, rhs: Self) {
660        *self = *self & rhs;
661    }
662}
663
664impl<T, const N: usize> core::ops::BitAndAssign<bool> for Mask<T, N>
665where
666    T: MaskElement,
667{
668    #[inline]
669    fn bitand_assign(&mut self, rhs: bool) {
670        *self &= Self::splat(rhs);
671    }
672}
673
674impl<T, const N: usize> core::ops::BitOrAssign for Mask<T, N>
675where
676    T: MaskElement,
677{
678    #[inline]
679    fn bitor_assign(&mut self, rhs: Self) {
680        *self = *self | rhs;
681    }
682}
683
684impl<T, const N: usize> core::ops::BitOrAssign<bool> for Mask<T, N>
685where
686    T: MaskElement,
687{
688    #[inline]
689    fn bitor_assign(&mut self, rhs: bool) {
690        *self |= Self::splat(rhs);
691    }
692}
693
694impl<T, const N: usize> core::ops::BitXorAssign for Mask<T, N>
695where
696    T: MaskElement,
697{
698    #[inline]
699    fn bitxor_assign(&mut self, rhs: Self) {
700        *self = *self ^ rhs;
701    }
702}
703
704impl<T, const N: usize> core::ops::BitXorAssign<bool> for Mask<T, N>
705where
706    T: MaskElement,
707{
708    #[inline]
709    fn bitxor_assign(&mut self, rhs: bool) {
710        *self ^= Self::splat(rhs);
711    }
712}
713
714macro_rules! impl_from {
715    { $from:ty  => $($to:ty),* } => {
716        $(
717        impl<const N: usize> From<Mask<$from, N>> for Mask<$to, N>
718        {
719            #[inline]
720            fn from(value: Mask<$from, N>) -> Self {
721                value.cast()
722            }
723        }
724        )*
725    }
726}
727impl_from! { i8 => i16, i32, i64, isize }
728impl_from! { i16 => i32, i64, isize, i8 }
729impl_from! { i32 => i64, isize, i8, i16 }
730impl_from! { i64 => isize, i8, i16, i32 }
731impl_from! { isize => i8, i16, i32, i64 }