core_simd/
masks.rs

//! Types and traits associated with masking elements of vectors.
//! The central type is [`Mask`], which represents boolean inclusion/exclusion for each element
//! of a vector.
3#![allow(non_camel_case_types)]
4
5#[cfg_attr(
6    not(all(target_arch = "x86_64", target_feature = "avx512f")),
7    path = "masks/full_masks.rs"
8)]
9#[cfg_attr(
10    all(target_arch = "x86_64", target_feature = "avx512f"),
11    path = "masks/bitmask.rs"
12)]
13mod mask_impl;
14
15use crate::simd::{LaneCount, Simd, SimdCast, SimdElement, SupportedLaneCount};
16use core::cmp::Ordering;
17use core::{fmt, mem};
18
19mod sealed {
20    use super::*;
21
22    /// Not only does this seal the `MaskElement` trait, but these functions prevent other traits
23    /// from bleeding into the parent bounds.
24    ///
25    /// For example, `eq` could be provided by requiring `MaskElement: PartialEq`, but that would
26    /// prevent us from ever removing that bound, or from implementing `MaskElement` on
27    /// non-`PartialEq` types in the future.
28    pub trait Sealed {
29        fn valid<const N: usize>(values: Simd<Self, N>) -> bool
30        where
31            LaneCount<N>: SupportedLaneCount,
32            Self: SimdElement;
33
34        fn eq(self, other: Self) -> bool;
35
36        fn to_usize(self) -> usize;
37        fn max_unsigned() -> u64;
38
39        type Unsigned: SimdElement;
40
41        const TRUE: Self;
42
43        const FALSE: Self;
44    }
45}
46use sealed::Sealed;
47
48/// Marker trait for types that may be used as SIMD mask elements.
49///
50/// # Safety
51/// Type must be a signed integer.
52pub unsafe trait MaskElement: SimdElement<Mask = Self> + SimdCast + Sealed {}
53
54macro_rules! impl_element {
55    { $ty:ty, $unsigned:ty } => {
56        impl Sealed for $ty {
57            #[inline]
58            fn valid<const N: usize>(value: Simd<Self, N>) -> bool
59            where
60                LaneCount<N>: SupportedLaneCount,
61            {
62                // We can't use `Simd` directly, because `Simd`'s functions call this function and
63                // we will end up with an infinite loop.
64                // Safety: `value` is an integer vector
65                unsafe {
66                    use core::intrinsics::simd;
67                    let falses: Simd<Self, N> = simd::simd_eq(value, Simd::splat(0 as _));
68                    let trues: Simd<Self, N> = simd::simd_eq(value, Simd::splat(-1 as _));
69                    let valid: Simd<Self, N> = simd::simd_or(falses, trues);
70                    simd::simd_reduce_all(valid)
71                }
72            }
73
74            #[inline]
75            fn eq(self, other: Self) -> bool { self == other }
76
77            #[inline]
78            fn to_usize(self) -> usize {
79                self as usize
80            }
81
82            #[inline]
83            fn max_unsigned() -> u64 {
84                <$unsigned>::MAX as u64
85            }
86
87            type Unsigned = $unsigned;
88
89            const TRUE: Self = -1;
90            const FALSE: Self = 0;
91        }
92
93        // Safety: this is a valid mask element type
94        unsafe impl MaskElement for $ty {}
95    }
96}
97
98impl_element! { i8, u8 }
99impl_element! { i16, u16 }
100impl_element! { i32, u32 }
101impl_element! { i64, u64 }
102impl_element! { isize, usize }
103
104/// A SIMD vector mask for `N` elements of width specified by `Element`.
105///
106/// Masks represent boolean inclusion/exclusion on a per-element basis.
107///
108/// The layout of this type is unspecified, and may change between platforms
109/// and/or Rust versions, and code should not assume that it is equivalent to
110/// `[T; N]`.
111#[repr(transparent)]
112pub struct Mask<T, const N: usize>(mask_impl::Mask<T, N>)
113where
114    T: MaskElement,
115    LaneCount<N>: SupportedLaneCount;
116
117impl<T, const N: usize> Copy for Mask<T, N>
118where
119    T: MaskElement,
120    LaneCount<N>: SupportedLaneCount,
121{
122}
123
124impl<T, const N: usize> Clone for Mask<T, N>
125where
126    T: MaskElement,
127    LaneCount<N>: SupportedLaneCount,
128{
129    #[inline]
130    fn clone(&self) -> Self {
131        *self
132    }
133}
134
135impl<T, const N: usize> Mask<T, N>
136where
137    T: MaskElement,
138    LaneCount<N>: SupportedLaneCount,
139{
140    /// Constructs a mask by setting all elements to the given value.
141    #[inline]
142    #[rustc_const_unstable(feature = "portable_simd", issue = "86656")]
143    pub const fn splat(value: bool) -> Self {
144        Self(mask_impl::Mask::splat(value))
145    }
146
147    /// Converts an array of bools to a SIMD mask.
148    #[inline]
149    pub fn from_array(array: [bool; N]) -> Self {
150        // SAFETY: Rust's bool has a layout of 1 byte (u8) with a value of
151        //     true:    0b_0000_0001
152        //     false:   0b_0000_0000
153        // Thus, an array of bools is also a valid array of bytes: [u8; N]
154        // This would be hypothetically valid as an "in-place" transmute,
155        // but these are "dependently-sized" types, so copy elision it is!
156        unsafe {
157            let bytes: [u8; N] = mem::transmute_copy(&array);
158            let bools: Simd<i8, N> =
159                core::intrinsics::simd::simd_ne(Simd::from_array(bytes), Simd::splat(0u8));
160            Mask::from_int_unchecked(core::intrinsics::simd::simd_cast(bools))
161        }
162    }
163
164    /// Converts a SIMD mask to an array of bools.
165    #[inline]
166    pub fn to_array(self) -> [bool; N] {
167        // This follows mostly the same logic as from_array.
168        // SAFETY: Rust's bool has a layout of 1 byte (u8) with a value of
169        //     true:    0b_0000_0001
170        //     false:   0b_0000_0000
171        // Thus, an array of bools is also a valid array of bytes: [u8; N]
172        // Since our masks are equal to integers where all bits are set,
173        // we can simply convert them to i8s, and then bitand them by the
174        // bitpattern for Rust's "true" bool.
175        // This would be hypothetically valid as an "in-place" transmute,
176        // but these are "dependently-sized" types, so copy elision it is!
177        unsafe {
178            let mut bytes: Simd<i8, N> = core::intrinsics::simd::simd_cast(self.to_int());
179            bytes &= Simd::splat(1i8);
180            mem::transmute_copy(&bytes)
181        }
182    }
183
184    /// Converts a vector of integers to a mask, where 0 represents `false` and -1
185    /// represents `true`.
186    ///
187    /// # Safety
188    /// All elements must be either 0 or -1.
189    #[inline]
190    #[must_use = "method returns a new mask and does not mutate the original value"]
191    pub unsafe fn from_int_unchecked(value: Simd<T, N>) -> Self {
192        // Safety: the caller must confirm this invariant
193        unsafe {
194            core::intrinsics::assume(<T as Sealed>::valid(value));
195            Self(mask_impl::Mask::from_int_unchecked(value))
196        }
197    }
198
199    /// Converts a vector of integers to a mask, where 0 represents `false` and -1
200    /// represents `true`.
201    ///
202    /// # Panics
203    /// Panics if any element is not 0 or -1.
204    #[inline]
205    #[must_use = "method returns a new mask and does not mutate the original value"]
206    #[track_caller]
207    pub fn from_int(value: Simd<T, N>) -> Self {
208        assert!(T::valid(value), "all values must be either 0 or -1",);
209        // Safety: the validity has been checked
210        unsafe { Self::from_int_unchecked(value) }
211    }
212
213    /// Converts the mask to a vector of integers, where 0 represents `false` and -1
214    /// represents `true`.
215    #[inline]
216    #[must_use = "method returns a new vector and does not mutate the original value"]
217    pub fn to_int(self) -> Simd<T, N> {
218        self.0.to_int()
219    }
220
221    /// Converts the mask to a mask of any other element size.
222    #[inline]
223    #[must_use = "method returns a new mask and does not mutate the original value"]
224    pub fn cast<U: MaskElement>(self) -> Mask<U, N> {
225        Mask(self.0.convert())
226    }
227
228    /// Tests the value of the specified element.
229    ///
230    /// # Safety
231    /// `index` must be less than `self.len()`.
232    #[inline]
233    #[must_use = "method returns a new bool and does not mutate the original value"]
234    pub unsafe fn test_unchecked(&self, index: usize) -> bool {
235        // Safety: the caller must confirm this invariant
236        unsafe { self.0.test_unchecked(index) }
237    }
238
239    /// Tests the value of the specified element.
240    ///
241    /// # Panics
242    /// Panics if `index` is greater than or equal to the number of elements in the vector.
243    #[inline]
244    #[must_use = "method returns a new bool and does not mutate the original value"]
245    #[track_caller]
246    pub fn test(&self, index: usize) -> bool {
247        assert!(index < N, "element index out of range");
248        // Safety: the element index has been checked
249        unsafe { self.test_unchecked(index) }
250    }
251
252    /// Sets the value of the specified element.
253    ///
254    /// # Safety
255    /// `index` must be less than `self.len()`.
256    #[inline]
257    pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) {
258        // Safety: the caller must confirm this invariant
259        unsafe {
260            self.0.set_unchecked(index, value);
261        }
262    }
263
264    /// Sets the value of the specified element.
265    ///
266    /// # Panics
267    /// Panics if `index` is greater than or equal to the number of elements in the vector.
268    #[inline]
269    #[track_caller]
270    pub fn set(&mut self, index: usize, value: bool) {
271        assert!(index < N, "element index out of range");
272        // Safety: the element index has been checked
273        unsafe {
274            self.set_unchecked(index, value);
275        }
276    }
277
278    /// Returns true if any element is set, or false otherwise.
279    #[inline]
280    #[must_use = "method returns a new bool and does not mutate the original value"]
281    pub fn any(self) -> bool {
282        self.0.any()
283    }
284
285    /// Returns true if all elements are set, or false otherwise.
286    #[inline]
287    #[must_use = "method returns a new bool and does not mutate the original value"]
288    pub fn all(self) -> bool {
289        self.0.all()
290    }
291
292    /// Creates a bitmask from a mask.
293    ///
294    /// Each bit is set if the corresponding element in the mask is `true`.
295    /// If the mask contains more than 64 elements, the bitmask is truncated to the first 64.
296    #[inline]
297    #[must_use = "method returns a new integer and does not mutate the original value"]
298    pub fn to_bitmask(self) -> u64 {
299        self.0.to_bitmask_integer()
300    }
301
302    /// Creates a mask from a bitmask.
303    ///
304    /// For each bit, if it is set, the corresponding element in the mask is set to `true`.
305    /// If the mask contains more than 64 elements, the remainder are set to `false`.
306    #[inline]
307    #[must_use = "method returns a new mask and does not mutate the original value"]
308    pub fn from_bitmask(bitmask: u64) -> Self {
309        Self(mask_impl::Mask::from_bitmask_integer(bitmask))
310    }
311
312    /// Finds the index of the first set element.
313    ///
314    /// ```
315    /// # #![feature(portable_simd)]
316    /// # #[cfg(feature = "as_crate")] use core_simd::simd;
317    /// # #[cfg(not(feature = "as_crate"))] use core::simd;
318    /// # use simd::mask32x8;
319    /// assert_eq!(mask32x8::splat(false).first_set(), None);
320    /// assert_eq!(mask32x8::splat(true).first_set(), Some(0));
321    ///
322    /// let mask = mask32x8::from_array([false, true, false, false, true, false, false, true]);
323    /// assert_eq!(mask.first_set(), Some(1));
324    /// ```
325    #[inline]
326    #[must_use = "method returns the index and does not mutate the original value"]
327    pub fn first_set(self) -> Option<usize> {
328        // If bitmasks are efficient, using them is better
329        if cfg!(target_feature = "sse") && N <= 64 {
330            let tz = self.to_bitmask().trailing_zeros();
331            return if tz == 64 { None } else { Some(tz as usize) };
332        }
333
334        // To find the first set index:
335        // * create a vector 0..N
336        // * replace unset mask elements in that vector with -1
337        // * perform _unsigned_ reduce-min
338        // * check if the result is -1 or an index
339
340        let index = Simd::from_array(
341            const {
342                let mut index = [0; N];
343                let mut i = 0;
344                while i < N {
345                    index[i] = i;
346                    i += 1;
347                }
348                index
349            },
350        );
351
352        // Safety: the input and output are integer vectors
353        let index: Simd<T, N> = unsafe { core::intrinsics::simd::simd_cast(index) };
354
355        let masked_index = self.select(index, Self::splat(true).to_int());
356
357        // Safety: the input and output are integer vectors
358        let masked_index: Simd<T::Unsigned, N> =
359            unsafe { core::intrinsics::simd::simd_cast(masked_index) };
360
361        // Safety: the input is an integer vector
362        let min_index: T::Unsigned =
363            unsafe { core::intrinsics::simd::simd_reduce_min(masked_index) };
364
365        // Safety: the return value is the unsigned version of T
366        let min_index: T = unsafe { core::mem::transmute_copy(&min_index) };
367
368        if min_index.eq(T::TRUE) {
369            None
370        } else {
371            Some(min_index.to_usize())
372        }
373    }
374}
375
376// vector/array conversion
377impl<T, const N: usize> From<[bool; N]> for Mask<T, N>
378where
379    T: MaskElement,
380    LaneCount<N>: SupportedLaneCount,
381{
382    #[inline]
383    fn from(array: [bool; N]) -> Self {
384        Self::from_array(array)
385    }
386}
387
388impl<T, const N: usize> From<Mask<T, N>> for [bool; N]
389where
390    T: MaskElement,
391    LaneCount<N>: SupportedLaneCount,
392{
393    #[inline]
394    fn from(vector: Mask<T, N>) -> Self {
395        vector.to_array()
396    }
397}
398
399impl<T, const N: usize> Default for Mask<T, N>
400where
401    T: MaskElement,
402    LaneCount<N>: SupportedLaneCount,
403{
404    #[inline]
405    fn default() -> Self {
406        Self::splat(false)
407    }
408}
409
410impl<T, const N: usize> PartialEq for Mask<T, N>
411where
412    T: MaskElement + PartialEq,
413    LaneCount<N>: SupportedLaneCount,
414{
415    #[inline]
416    fn eq(&self, other: &Self) -> bool {
417        self.0 == other.0
418    }
419}
420
421impl<T, const N: usize> PartialOrd for Mask<T, N>
422where
423    T: MaskElement + PartialOrd,
424    LaneCount<N>: SupportedLaneCount,
425{
426    #[inline]
427    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
428        self.0.partial_cmp(&other.0)
429    }
430}
431
432impl<T, const N: usize> fmt::Debug for Mask<T, N>
433where
434    T: MaskElement + fmt::Debug,
435    LaneCount<N>: SupportedLaneCount,
436{
437    #[inline]
438    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
439        f.debug_list()
440            .entries((0..N).map(|i| self.test(i)))
441            .finish()
442    }
443}
444
445impl<T, const N: usize> core::ops::BitAnd for Mask<T, N>
446where
447    T: MaskElement,
448    LaneCount<N>: SupportedLaneCount,
449{
450    type Output = Self;
451    #[inline]
452    fn bitand(self, rhs: Self) -> Self {
453        Self(self.0 & rhs.0)
454    }
455}
456
457impl<T, const N: usize> core::ops::BitAnd<bool> for Mask<T, N>
458where
459    T: MaskElement,
460    LaneCount<N>: SupportedLaneCount,
461{
462    type Output = Self;
463    #[inline]
464    fn bitand(self, rhs: bool) -> Self {
465        self & Self::splat(rhs)
466    }
467}
468
469impl<T, const N: usize> core::ops::BitAnd<Mask<T, N>> for bool
470where
471    T: MaskElement,
472    LaneCount<N>: SupportedLaneCount,
473{
474    type Output = Mask<T, N>;
475    #[inline]
476    fn bitand(self, rhs: Mask<T, N>) -> Mask<T, N> {
477        Mask::splat(self) & rhs
478    }
479}
480
481impl<T, const N: usize> core::ops::BitOr for Mask<T, N>
482where
483    T: MaskElement,
484    LaneCount<N>: SupportedLaneCount,
485{
486    type Output = Self;
487    #[inline]
488    fn bitor(self, rhs: Self) -> Self {
489        Self(self.0 | rhs.0)
490    }
491}
492
493impl<T, const N: usize> core::ops::BitOr<bool> for Mask<T, N>
494where
495    T: MaskElement,
496    LaneCount<N>: SupportedLaneCount,
497{
498    type Output = Self;
499    #[inline]
500    fn bitor(self, rhs: bool) -> Self {
501        self | Self::splat(rhs)
502    }
503}
504
505impl<T, const N: usize> core::ops::BitOr<Mask<T, N>> for bool
506where
507    T: MaskElement,
508    LaneCount<N>: SupportedLaneCount,
509{
510    type Output = Mask<T, N>;
511    #[inline]
512    fn bitor(self, rhs: Mask<T, N>) -> Mask<T, N> {
513        Mask::splat(self) | rhs
514    }
515}
516
517impl<T, const N: usize> core::ops::BitXor for Mask<T, N>
518where
519    T: MaskElement,
520    LaneCount<N>: SupportedLaneCount,
521{
522    type Output = Self;
523    #[inline]
524    fn bitxor(self, rhs: Self) -> Self::Output {
525        Self(self.0 ^ rhs.0)
526    }
527}
528
529impl<T, const N: usize> core::ops::BitXor<bool> for Mask<T, N>
530where
531    T: MaskElement,
532    LaneCount<N>: SupportedLaneCount,
533{
534    type Output = Self;
535    #[inline]
536    fn bitxor(self, rhs: bool) -> Self::Output {
537        self ^ Self::splat(rhs)
538    }
539}
540
541impl<T, const N: usize> core::ops::BitXor<Mask<T, N>> for bool
542where
543    T: MaskElement,
544    LaneCount<N>: SupportedLaneCount,
545{
546    type Output = Mask<T, N>;
547    #[inline]
548    fn bitxor(self, rhs: Mask<T, N>) -> Self::Output {
549        Mask::splat(self) ^ rhs
550    }
551}
552
553impl<T, const N: usize> core::ops::Not for Mask<T, N>
554where
555    T: MaskElement,
556    LaneCount<N>: SupportedLaneCount,
557{
558    type Output = Mask<T, N>;
559    #[inline]
560    fn not(self) -> Self::Output {
561        Self(!self.0)
562    }
563}
564
565impl<T, const N: usize> core::ops::BitAndAssign for Mask<T, N>
566where
567    T: MaskElement,
568    LaneCount<N>: SupportedLaneCount,
569{
570    #[inline]
571    fn bitand_assign(&mut self, rhs: Self) {
572        self.0 = self.0 & rhs.0;
573    }
574}
575
576impl<T, const N: usize> core::ops::BitAndAssign<bool> for Mask<T, N>
577where
578    T: MaskElement,
579    LaneCount<N>: SupportedLaneCount,
580{
581    #[inline]
582    fn bitand_assign(&mut self, rhs: bool) {
583        *self &= Self::splat(rhs);
584    }
585}
586
587impl<T, const N: usize> core::ops::BitOrAssign for Mask<T, N>
588where
589    T: MaskElement,
590    LaneCount<N>: SupportedLaneCount,
591{
592    #[inline]
593    fn bitor_assign(&mut self, rhs: Self) {
594        self.0 = self.0 | rhs.0;
595    }
596}
597
598impl<T, const N: usize> core::ops::BitOrAssign<bool> for Mask<T, N>
599where
600    T: MaskElement,
601    LaneCount<N>: SupportedLaneCount,
602{
603    #[inline]
604    fn bitor_assign(&mut self, rhs: bool) {
605        *self |= Self::splat(rhs);
606    }
607}
608
609impl<T, const N: usize> core::ops::BitXorAssign for Mask<T, N>
610where
611    T: MaskElement,
612    LaneCount<N>: SupportedLaneCount,
613{
614    #[inline]
615    fn bitxor_assign(&mut self, rhs: Self) {
616        self.0 = self.0 ^ rhs.0;
617    }
618}
619
620impl<T, const N: usize> core::ops::BitXorAssign<bool> for Mask<T, N>
621where
622    T: MaskElement,
623    LaneCount<N>: SupportedLaneCount,
624{
625    #[inline]
626    fn bitxor_assign(&mut self, rhs: bool) {
627        *self ^= Self::splat(rhs);
628    }
629}
630
631macro_rules! impl_from {
632    { $from:ty  => $($to:ty),* } => {
633        $(
634        impl<const N: usize> From<Mask<$from, N>> for Mask<$to, N>
635        where
636            LaneCount<N>: SupportedLaneCount,
637        {
638            #[inline]
639            fn from(value: Mask<$from, N>) -> Self {
640                value.cast()
641            }
642        }
643        )*
644    }
645}
646impl_from! { i8 => i16, i32, i64, isize }
647impl_from! { i16 => i32, i64, isize, i8 }
648impl_from! { i32 => i64, isize, i8, i16 }
649impl_from! { i64 => isize, i8, i16, i32 }
650impl_from! { isize => i8, i16, i32, i64 }