1#![allow(non_camel_case_types)]
4
5use crate::simd::{Select, Simd, SimdCast, SimdElement};
6use core::cmp::Ordering;
7use core::{fmt, mem};
8
/// Normalizes the bit order of an integer bitmask for the current target.
///
/// On big-endian targets the bits are reversed; on every other target the value is
/// returned unchanged.
pub(crate) trait FixEndianness {
    /// Returns `self` with its bits reversed on big-endian targets, or `self` as-is otherwise.
    fn fix_endianness(self) -> Self;
}

/// Implements [`FixEndianness`] for each listed unsigned integer type.
macro_rules! impl_fix_endianness {
    { $($int:ty),* } => {
        $(
            impl FixEndianness for $int {
                #[inline(always)]
                fn fix_endianness(self) -> Self {
                    // `cfg!` expands to a boolean literal, so the branch is fixed per target.
                    match cfg!(target_endian = "big") {
                        true => self.reverse_bits(),
                        false => self,
                    }
                }
            }
        )*
    }
}

impl_fix_endianness! { u8, u16, u32, u64 }
31
32mod sealed {
33 use super::*;
34
35 pub trait Sealed {
42 fn valid<const N: usize>(values: Simd<Self, N>) -> bool
43 where
44 Self: SimdElement;
45
46 fn eq(self, other: Self) -> bool;
47
48 fn to_usize(self) -> usize;
49 fn max_unsigned() -> u64;
50
51 type Unsigned: SimdElement;
52
53 const TRUE: Self;
54
55 const FALSE: Self;
56 }
57}
58use sealed::Sealed;
59
60pub unsafe trait MaskElement: SimdElement<Mask = Self> + SimdCast + Sealed {}
65
66macro_rules! impl_element {
67 { $ty:ty, $unsigned:ty } => {
68 impl Sealed for $ty {
69 #[inline]
70 fn valid<const N: usize>(value: Simd<Self, N>) -> bool
71 {
72 unsafe {
76 use core::intrinsics::simd;
77 let falses: Simd<Self, N> = simd::simd_eq(value, Simd::splat(0 as _));
78 let trues: Simd<Self, N> = simd::simd_eq(value, Simd::splat(-1 as _));
79 let valid: Simd<Self, N> = simd::simd_or(falses, trues);
80 simd::simd_reduce_all(valid)
81 }
82 }
83
84 #[inline]
85 fn eq(self, other: Self) -> bool { self == other }
86
87 #[inline]
88 fn to_usize(self) -> usize {
89 self as usize
90 }
91
92 #[inline]
93 fn max_unsigned() -> u64 {
94 <$unsigned>::MAX as u64
95 }
96
97 type Unsigned = $unsigned;
98
99 const TRUE: Self = -1;
100 const FALSE: Self = 0;
101 }
102
103 unsafe impl MaskElement for $ty {}
105 }
106}
107
108impl_element! { i8, u8 }
109impl_element! { i16, u16 }
110impl_element! { i32, u32 }
111impl_element! { i64, u64 }
112impl_element! { isize, usize }
113
114#[repr(transparent)]
125pub struct Mask<T, const N: usize>(Simd<T, N>)
126where
127 T: MaskElement;
128
129impl<T, const N: usize> Copy for Mask<T, N> where T: MaskElement {}
130
131impl<T, const N: usize> Clone for Mask<T, N>
132where
133 T: MaskElement,
134{
135 #[inline]
136 fn clone(&self) -> Self {
137 *self
138 }
139}
140
// Constructors, conversions, per-lane accessors and reductions for `Mask`.
// Invariant relied on throughout: every lane of `self.0` is `T::FALSE` (0) or `T::TRUE` (-1).
impl<T, const N: usize> Mask<T, N>
where
    T: MaskElement,
{
    /// Constructs a mask with every lane set to `value`.
    #[inline]
    #[rustc_const_unstable(feature = "portable_simd", issue = "86656")]
    pub const fn splat(value: bool) -> Self {
        Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
    }

    /// Converts an array of bools to a mask; lane `i` is set iff `array[i]` is `true`.
    #[inline]
    pub fn from_array(array: [bool; N]) -> Self {
        // SAFETY: a `bool` is one byte holding 0 (`false`) or 1 (`true`), so `[bool; N]` can be
        // bit-copied into `[u8; N]`. Comparing each byte against 0 yields lanes of 0 / -1, which
        // stay 0 / -1 after the integer cast to `T`, satisfying `from_simd_unchecked`.
        unsafe {
            let bytes: [u8; N] = mem::transmute_copy(&array);
            let bools: Simd<i8, N> =
                core::intrinsics::simd::simd_ne(Simd::from_array(bytes), Simd::splat(0u8));
            Mask::from_simd_unchecked(core::intrinsics::simd::simd_cast(bools))
        }
    }

    /// Converts the mask to an array of bools; element `i` is `true` iff lane `i` is set.
    #[inline]
    pub fn to_array(self) -> [bool; N] {
        // SAFETY: each lane is 0 or -1; the cast to `i8` keeps that value, and AND-ing with 1
        // leaves 0 or 1 -- exactly the byte patterns of `false` and `true`. `transmute_copy`
        // then reads the `[bool; N]` out of the `i8` vector's bytes.
        unsafe {
            let mut bytes: Simd<i8, N> = core::intrinsics::simd::simd_cast(self.to_simd());
            bytes &= Simd::splat(1i8);
            mem::transmute_copy(&bytes)
        }
    }

    /// Converts an integer vector to a mask without validating its lanes.
    ///
    /// # Safety
    ///
    /// Every lane of `value` must be `0` (`false`) or `-1` (`true`).
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub unsafe fn from_simd_unchecked(value: Simd<T, N>) -> Self {
        // SAFETY: the caller guarantees every lane is 0 or -1, so `valid` returns `true` and
        // the optimizer may rely on it.
        unsafe {
            core::intrinsics::assume(<T as Sealed>::valid(value));
        }
        Self(value)
    }

    /// Converts an integer vector to a mask.
    ///
    /// # Panics
    ///
    /// Panics if any lane of `value` is neither `0` nor `-1`.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    #[track_caller]
    pub fn from_simd(value: Simd<T, N>) -> Self {
        assert!(T::valid(value), "all values must be either 0 or -1",);
        // SAFETY: the assertion above just checked every lane.
        unsafe { Self::from_simd_unchecked(value) }
    }

    /// Returns the underlying integer vector (lanes are `0` or `-1`).
    #[inline]
    #[must_use = "method returns a new vector and does not mutate the original value"]
    pub fn to_simd(self) -> Simd<T, N> {
        self.0
    }

    /// Converts the mask to one of the same length with lane type `U`.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn cast<U: MaskElement>(self) -> Mask<U, N> {
        // SAFETY: `simd_as` numerically casts each integer lane; 0 and -1 are representable in
        // every signed `MaskElement` type, so the result is still a valid mask.
        unsafe { Mask(core::intrinsics::simd::simd_as(self.0)) }
    }

    /// Returns whether lane `index` is set, without bounds checking.
    ///
    /// # Safety
    ///
    /// `index` must be less than `N`.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub unsafe fn test_unchecked(&self, index: usize) -> bool {
        // SAFETY: the caller guarantees `index < N`, so `get_unchecked` stays in bounds.
        unsafe { T::eq(*self.0.as_array().get_unchecked(index), T::TRUE) }
    }

    /// Returns whether lane `index` is set.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    #[track_caller]
    pub fn test(&self, index: usize) -> bool {
        T::eq(self.0[index], T::TRUE)
    }

    /// Sets lane `index` to `value`, without bounds checking.
    ///
    /// # Safety
    ///
    /// `index` must be less than `N`.
    #[inline]
    pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) {
        // SAFETY: the caller guarantees `index < N`; the stored value is TRUE or FALSE, which
        // preserves the mask invariant.
        unsafe {
            *self.0.as_mut_array().get_unchecked_mut(index) = if value { T::TRUE } else { T::FALSE }
        }
    }

    /// Sets lane `index` to `value`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds.
    #[inline]
    #[track_caller]
    pub fn set(&mut self, index: usize, value: bool) {
        self.0[index] = if value { T::TRUE } else { T::FALSE }
    }

    /// Returns `true` if at least one lane is set.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn any(self) -> bool {
        // SAFETY: `self.0` is an integer vector whose lanes are all 0 or -1.
        unsafe { core::intrinsics::simd::simd_reduce_any(self.0) }
    }

    /// Returns `true` if every lane is set.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn all(self) -> bool {
        // SAFETY: `self.0` is an integer vector whose lanes are all 0 or -1.
        unsafe { core::intrinsics::simd::simd_reduce_all(self.0) }
    }

    /// Packs the mask into the low bits of a `u64`.
    ///
    /// Lane `i` maps to bit `i`, counting from the least significant bit (`first_set` relies on
    /// this via `trailing_zeros`). Bits at or above `N` are presumably zero because the mask is
    /// padded with `false` lanes -- NOTE(review): confirm against `resize`.
    ///
    /// Fails to compile (post-monomorphization const assertion) when `N > 64`.
    #[inline]
    #[must_use = "method returns a new integer and does not mutate the original value"]
    pub fn to_bitmask(self) -> u64 {
        const {
            assert!(N <= 64, "number of elements can't be greater than 64");
        }

        /// Pads `mask` to `M` lanes of `false` and extracts its bitmask as `U`.
        ///
        /// Callers must pick `U` with exactly `M` bits, as `simd_bitmask` requires.
        #[inline]
        unsafe fn to_bitmask_impl<T, U: FixEndianness, const M: usize, const N: usize>(
            mask: Mask<T, N>,
        ) -> U
        where
            T: MaskElement,
        {
            let resized = mask.resize::<M>(false);

            // SAFETY: `resized` is an integer mask vector of `M` lanes and the caller chose `U`
            // with `M` bits.
            let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized.0) };

            // The intrinsic's bit order follows target endianness; normalize it.
            bitmask.fix_endianness()
        }

        // Pick the narrowest integer that holds `N` bits, then zero-extend to `u64`.
        if N <= 8 {
            // SAFETY: `u8` has exactly 8 bits, matching M = 8.
            unsafe { to_bitmask_impl::<T, u8, 8, N>(self) as u64 }
        } else if N <= 16 {
            // SAFETY: `u16` has exactly 16 bits, matching M = 16.
            unsafe { to_bitmask_impl::<T, u16, 16, N>(self) as u64 }
        } else if N <= 32 {
            // SAFETY: `u32` has exactly 32 bits, matching M = 32.
            unsafe { to_bitmask_impl::<T, u32, 32, N>(self) as u64 }
        } else {
            // SAFETY: `u64` has exactly 64 bits, matching M = 64.
            unsafe { to_bitmask_impl::<T, u64, 64, N>(self) }
        }
    }

    /// Builds a mask from a bitmask, choosing `T::TRUE` or `T::FALSE` per lane via the
    /// `Select` impl for `u64` (defined elsewhere; presumably bit `i` drives lane `i` and
    /// bits at or above `N` are ignored -- NOTE(review): confirm).
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn from_bitmask(bitmask: u64) -> Self {
        Self(bitmask.select(Simd::splat(T::TRUE), Simd::splat(T::FALSE)))
    }

    /// Returns the index of the lowest set lane, or `None` if no lane is set.
    #[inline]
    #[must_use = "method returns the index and does not mutate the original value"]
    pub fn first_set(self) -> Option<usize> {
        // Fast path on SSE targets: lane `i` is bit `i` of the bitmask, so the trailing-zero
        // count is the first set lane; 64 trailing zeros means no lane is set.
        if cfg!(target_feature = "sse") && N <= 64 {
            let tz = self.to_bitmask().trailing_zeros();
            return if tz == 64 { None } else { Some(tz as usize) };
        }

        // Generic path:
        // * build the vector [0, 1, ..., N-1] in lane type `T`;
        // * OR it with `!self`, so unset lanes become -1 (all bits) and set lanes keep their index;
        // * take the *unsigned* minimum, which is the smallest set index, or all-ones if none set.
        // NOTE(review): this assumes N-1 never casts to -1 in `T` (e.g. lane 255 for `i8`);
        // presumably guaranteed by lane-count limits -- confirm.
        let index: Simd<T, N> = const {
            let mut index = [0; N];
            let mut i = 0;
            while i < N {
                index[i] = i;
                i += 1;
            }
            // SAFETY: both the input (`usize` lanes) and the output (`T` lanes) are integer vectors.
            unsafe { core::intrinsics::simd::simd_cast(Simd::from_array(index)) }
        };

        // SAFETY: both operands are integer vectors of the same type.
        let masked_index: Simd<T, N> =
            unsafe { core::intrinsics::simd::simd_or((!self).to_simd(), index) };

        // SAFETY: integer-to-integer cast to the same-width unsigned type.
        let masked_index: Simd<T::Unsigned, N> =
            unsafe { core::intrinsics::simd::simd_cast(masked_index) };

        // SAFETY: `masked_index` is an unsigned integer vector.
        let min_index: T::Unsigned =
            unsafe { core::intrinsics::simd::simd_reduce_min(masked_index) };

        // SAFETY: `T::Unsigned` is the same-width unsigned counterpart of `T`, so the bytes
        // can be reinterpreted as `T`.
        let min_index: T = unsafe { core::mem::transmute_copy(&min_index) };

        if min_index.eq(T::TRUE) {
            None
        } else {
            let min_index = min_index.to_usize();

            // SAFETY: the minimum is not all-ones, so it is the index of some set lane, which
            // came from [0, N).
            unsafe {
                core::hint::assert_unchecked(min_index < N);
            }

            Some(min_index)
        }
    }
}
415
416impl<T, const N: usize> From<[bool; N]> for Mask<T, N>
418where
419 T: MaskElement,
420{
421 #[inline]
422 fn from(array: [bool; N]) -> Self {
423 Self::from_array(array)
424 }
425}
426
427impl<T, const N: usize> From<Mask<T, N>> for [bool; N]
428where
429 T: MaskElement,
430{
431 #[inline]
432 fn from(vector: Mask<T, N>) -> Self {
433 vector.to_array()
434 }
435}
436
437impl<T, const N: usize> Default for Mask<T, N>
438where
439 T: MaskElement,
440{
441 #[inline]
442 fn default() -> Self {
443 Self::splat(false)
444 }
445}
446
447impl<T, const N: usize> PartialEq for Mask<T, N>
448where
449 T: MaskElement + PartialEq,
450{
451 #[inline]
452 fn eq(&self, other: &Self) -> bool {
453 self.0 == other.0
454 }
455}
456
457impl<T, const N: usize> PartialOrd for Mask<T, N>
458where
459 T: MaskElement + PartialOrd,
460{
461 #[inline]
462 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
463 self.0.partial_cmp(&other.0)
464 }
465}
466
467impl<T, const N: usize> fmt::Debug for Mask<T, N>
468where
469 T: MaskElement + fmt::Debug,
470{
471 #[inline]
472 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
473 f.debug_list()
474 .entries((0..N).map(|i| self.test(i)))
475 .finish()
476 }
477}
478
479impl<T, const N: usize> core::ops::BitAnd for Mask<T, N>
480where
481 T: MaskElement,
482{
483 type Output = Self;
484 #[inline]
485 fn bitand(self, rhs: Self) -> Self {
486 unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) }
488 }
489}
490
491impl<T, const N: usize> core::ops::BitAnd<bool> for Mask<T, N>
492where
493 T: MaskElement,
494{
495 type Output = Self;
496 #[inline]
497 fn bitand(self, rhs: bool) -> Self {
498 self & Self::splat(rhs)
499 }
500}
501
502impl<T, const N: usize> core::ops::BitAnd<Mask<T, N>> for bool
503where
504 T: MaskElement,
505{
506 type Output = Mask<T, N>;
507 #[inline]
508 fn bitand(self, rhs: Mask<T, N>) -> Mask<T, N> {
509 Mask::splat(self) & rhs
510 }
511}
512
513impl<T, const N: usize> core::ops::BitOr for Mask<T, N>
514where
515 T: MaskElement,
516{
517 type Output = Self;
518 #[inline]
519 fn bitor(self, rhs: Self) -> Self {
520 unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) }
522 }
523}
524
525impl<T, const N: usize> core::ops::BitOr<bool> for Mask<T, N>
526where
527 T: MaskElement,
528{
529 type Output = Self;
530 #[inline]
531 fn bitor(self, rhs: bool) -> Self {
532 self | Self::splat(rhs)
533 }
534}
535
536impl<T, const N: usize> core::ops::BitOr<Mask<T, N>> for bool
537where
538 T: MaskElement,
539{
540 type Output = Mask<T, N>;
541 #[inline]
542 fn bitor(self, rhs: Mask<T, N>) -> Mask<T, N> {
543 Mask::splat(self) | rhs
544 }
545}
546
547impl<T, const N: usize> core::ops::BitXor for Mask<T, N>
548where
549 T: MaskElement,
550{
551 type Output = Self;
552 #[inline]
553 fn bitxor(self, rhs: Self) -> Self::Output {
554 unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) }
556 }
557}
558
559impl<T, const N: usize> core::ops::BitXor<bool> for Mask<T, N>
560where
561 T: MaskElement,
562{
563 type Output = Self;
564 #[inline]
565 fn bitxor(self, rhs: bool) -> Self::Output {
566 self ^ Self::splat(rhs)
567 }
568}
569
570impl<T, const N: usize> core::ops::BitXor<Mask<T, N>> for bool
571where
572 T: MaskElement,
573{
574 type Output = Mask<T, N>;
575 #[inline]
576 fn bitxor(self, rhs: Mask<T, N>) -> Self::Output {
577 Mask::splat(self) ^ rhs
578 }
579}
580
581impl<T, const N: usize> core::ops::Not for Mask<T, N>
582where
583 T: MaskElement,
584{
585 type Output = Mask<T, N>;
586 #[inline]
587 fn not(self) -> Self::Output {
588 Self::splat(true) ^ self
589 }
590}
591
592impl<T, const N: usize> core::ops::BitAndAssign for Mask<T, N>
593where
594 T: MaskElement,
595{
596 #[inline]
597 fn bitand_assign(&mut self, rhs: Self) {
598 *self = *self & rhs;
599 }
600}
601
602impl<T, const N: usize> core::ops::BitAndAssign<bool> for Mask<T, N>
603where
604 T: MaskElement,
605{
606 #[inline]
607 fn bitand_assign(&mut self, rhs: bool) {
608 *self &= Self::splat(rhs);
609 }
610}
611
612impl<T, const N: usize> core::ops::BitOrAssign for Mask<T, N>
613where
614 T: MaskElement,
615{
616 #[inline]
617 fn bitor_assign(&mut self, rhs: Self) {
618 *self = *self | rhs;
619 }
620}
621
622impl<T, const N: usize> core::ops::BitOrAssign<bool> for Mask<T, N>
623where
624 T: MaskElement,
625{
626 #[inline]
627 fn bitor_assign(&mut self, rhs: bool) {
628 *self |= Self::splat(rhs);
629 }
630}
631
632impl<T, const N: usize> core::ops::BitXorAssign for Mask<T, N>
633where
634 T: MaskElement,
635{
636 #[inline]
637 fn bitxor_assign(&mut self, rhs: Self) {
638 *self = *self ^ rhs;
639 }
640}
641
642impl<T, const N: usize> core::ops::BitXorAssign<bool> for Mask<T, N>
643where
644 T: MaskElement,
645{
646 #[inline]
647 fn bitxor_assign(&mut self, rhs: bool) {
648 *self ^= Self::splat(rhs);
649 }
650}
651
652macro_rules! impl_from {
653 { $from:ty => $($to:ty),* } => {
654 $(
655 impl<const N: usize> From<Mask<$from, N>> for Mask<$to, N>
656 {
657 #[inline]
658 fn from(value: Mask<$from, N>) -> Self {
659 value.cast()
660 }
661 }
662 )*
663 }
664}
665impl_from! { i8 => i16, i32, i64, isize }
666impl_from! { i16 => i32, i64, isize, i8 }
667impl_from! { i32 => i64, isize, i8, i16 }
668impl_from! { i64 => isize, i8, i16, i32 }
669impl_from! { isize => i8, i16, i32, i64 }