1#![allow(non_camel_case_types)]
4
5use crate::simd::{Select, Simd, SimdCast, SimdElement};
6use core::cmp::Ordering;
7use core::{fmt, mem};
8
/// Target-endianness correction for the integer produced by the bitmask intrinsic.
///
/// `Mask::to_bitmask` passes the intrinsic's output through this trait: on little-endian
/// targets the value is returned untouched, on big-endian targets its bits are reversed.
pub(crate) trait FixEndianness {
    /// Returns `self` on little-endian targets, `self` with all bits reversed otherwise.
    fn fix_endianness(self) -> Self;
}

/// Implements [`FixEndianness`] for each listed primitive unsigned integer type.
macro_rules! impl_fix_endianness {
    { $($int:ty),* } => {
        $(
            impl FixEndianness for $int {
                #[inline(always)]
                fn fix_endianness(self) -> Self {
                    // `target_endian` is either "little" or "big": only big-endian needs work.
                    if cfg!(target_endian = "little") {
                        self
                    } else {
                        self.reverse_bits()
                    }
                }
            }
        )*
    }
}

impl_fix_endianness! { u8, u16, u32, u64 }
31
32mod private_methods {
33 use super::*;
34
35 pub impl(super) trait PrivateMethods {
42 fn valid<const N: usize>(values: Simd<Self, N>) -> bool
43 where
44 Self: SimdElement;
45
46 fn eq(self, other: Self) -> bool;
47
48 fn to_usize(self) -> usize;
49 fn max_unsigned() -> u64;
50
51 type Unsigned: SimdElement;
52
53 const TRUE: Self;
54
55 const FALSE: Self;
56 }
57}
58use private_methods::PrivateMethods;
59
60pub impl(self) unsafe trait MaskElement:
65 SimdElement<Mask = Self> + SimdCast + PrivateMethods
66{
67}
68
69macro_rules! impl_element {
70 { $ty:ty, $unsigned:ty } => {
71 impl PrivateMethods for $ty {
72 #[inline]
73 fn valid<const N: usize>(value: Simd<Self, N>) -> bool
74 {
75 unsafe {
79 use core::intrinsics::simd;
80 let falses: Simd<Self, N> = simd::simd_eq(value, Simd::splat(0 as _));
81 let trues: Simd<Self, N> = simd::simd_eq(value, Simd::splat(-1 as _));
82 let valid: Simd<Self, N> = simd::simd_or(falses, trues);
83 simd::simd_reduce_all(valid)
84 }
85 }
86
87 #[inline]
88 fn eq(self, other: Self) -> bool { self == other }
89
90 #[inline]
91 fn to_usize(self) -> usize {
92 self as usize
93 }
94
95 #[inline]
96 fn max_unsigned() -> u64 {
97 <$unsigned>::MAX as u64
98 }
99
100 type Unsigned = $unsigned;
101
102 const TRUE: Self = -1;
103 const FALSE: Self = 0;
104 }
105
106 unsafe impl MaskElement for $ty {}
108 }
109}
110
111impl_element! { i8, u8 }
112impl_element! { i16, u16 }
113impl_element! { i32, u32 }
114impl_element! { i64, u64 }
115impl_element! { isize, usize }
116
117#[repr(transparent)]
128pub struct Mask<T, const N: usize>(Simd<T, N>)
129where
130 T: MaskElement;
131
132impl<T, const N: usize> Copy for Mask<T, N> where T: MaskElement {}
133
134impl<T, const N: usize> Clone for Mask<T, N>
135where
136 T: MaskElement,
137{
138 #[inline]
139 fn clone(&self) -> Self {
140 *self
141 }
142}
143
/// Inherent API of [`Mask`].
///
/// Every lane of the wrapped vector holds either `T::TRUE` (-1, all bits set) or
/// `T::FALSE` (0). The safe constructors below only ever write those two values, and
/// several methods (`test`, `first_set`, `last_set`, `from_simd_unchecked`) depend on it.
impl<T, const N: usize> Mask<T, N>
where
    T: MaskElement,
{
    /// Constructs a mask with every lane set to `value`.
    #[inline]
    #[rustc_const_unstable(feature = "portable_simd", issue = "86656")]
    pub const fn splat(value: bool) -> Self {
        Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
    }

    /// Converts an array of bools into a mask; lane `i` takes the value of `array[i]`.
    #[inline]
    pub fn from_array(array: [bool; N]) -> Self {
        // SAFETY: a `bool` is one byte holding 0 or 1, so `[bool; N]` is a valid `[u8; N]`
        // (`transmute_copy` is used because the array size depends on `N`). Comparing each
        // byte with 0 yields an `i8` vector of -1/0 lanes, and a signed integer cast keeps
        // -1/0 as -1/0, satisfying `from_simd_unchecked`'s precondition.
        unsafe {
            let bytes: [u8; N] = mem::transmute_copy(&array);
            let bools: Simd<i8, N> =
                core::intrinsics::simd::simd_ne(Simd::from_array(bytes), Simd::splat(0u8));
            Mask::from_simd_unchecked(core::intrinsics::simd::simd_cast(bools))
        }
    }

    /// Converts the mask into an array of bools, `true` for every set lane.
    #[inline]
    pub fn to_array(self) -> [bool; N] {
        // SAFETY: narrowing every lane to `i8` yields -1 or 0; masking with 1 turns these
        // into 1 or 0, exactly the valid byte values of `bool`. `transmute_copy` then reads
        // the first `N` bytes of the vector as `[bool; N]`.
        unsafe {
            let mut bytes: Simd<i8, N> = core::intrinsics::simd::simd_cast(self.to_simd());
            bytes &= Simd::splat(1i8);
            mem::transmute_copy(&bytes)
        }
    }

    /// Wraps an integer vector as a mask without validating its lanes.
    ///
    /// # Safety
    /// Every lane of `value` must be either 0 (`false`) or -1 (`true`).
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub unsafe fn from_simd_unchecked(value: Simd<T, N>) -> Self {
        // SAFETY: the caller guarantees every lane is 0 or -1. `assume` hands that fact to
        // the optimizer, so violating it is undefined behavior.
        unsafe {
            core::intrinsics::assume(<T as PrivateMethods>::valid(value));
        }
        Self(value)
    }

    /// Wraps an integer vector as a mask after checking every lane.
    ///
    /// # Panics
    /// Panics if any lane of `value` is neither 0 nor -1.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    #[track_caller]
    pub fn from_simd(value: Simd<T, N>) -> Self {
        assert!(T::valid(value), "all values must be either 0 or -1",);
        // SAFETY: the assertion above established the precondition.
        unsafe { Self::from_simd_unchecked(value) }
    }

    /// Returns the underlying integer vector (-1 for set lanes, 0 for clear lanes).
    #[inline]
    #[must_use = "method returns a new vector and does not mutate the original value"]
    pub fn to_simd(self) -> Simd<T, N> {
        self.0
    }

    /// Converts the mask to element type `U`, keeping each lane's truth value.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn cast<U: MaskElement>(self) -> Mask<U, N> {
        // SAFETY: both element types are signed integers (the only `MaskElement`
        // implementors in this module) and the lane count is unchanged. An integer cast
        // maps -1 to -1 and 0 to 0, so the result is still a valid mask.
        unsafe { Mask(core::intrinsics::simd::simd_as(self.0)) }
    }

    /// Returns whether lane `index` is set, without bounds checking.
    ///
    /// # Safety
    /// `index` must be less than `N`.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub unsafe fn test_unchecked(&self, index: usize) -> bool {
        // SAFETY: the caller guarantees `index < N`.
        unsafe { T::eq(*self.0.as_array().get_unchecked(index), T::TRUE) }
    }

    /// Returns whether lane `index` is set.
    ///
    /// # Panics
    /// Panics if `index` is out of bounds (`index >= N`).
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    #[track_caller]
    pub fn test(&self, index: usize) -> bool {
        T::eq(self.0[index], T::TRUE)
    }

    /// Sets lane `index` to `value`, without bounds checking.
    ///
    /// # Safety
    /// `index` must be less than `N`.
    #[inline]
    pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) {
        // SAFETY: the caller guarantees `index < N`. Only TRUE/FALSE are written, so the
        // mask stays valid.
        unsafe {
            *self.0.as_mut_array().get_unchecked_mut(index) = if value { T::TRUE } else { T::FALSE }
        }
    }

    /// Sets lane `index` to `value`.
    ///
    /// # Panics
    /// Panics if `index` is out of bounds (`index >= N`).
    #[inline]
    #[track_caller]
    pub fn set(&mut self, index: usize, value: bool) {
        self.0[index] = if value { T::TRUE } else { T::FALSE }
    }

    /// Returns `true` if at least one lane is set.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn any(self) -> bool {
        // SAFETY: `self.0` is an integer vector whose lanes are all 0 or -1.
        unsafe { core::intrinsics::simd::simd_reduce_any(self.0) }
    }

    /// Returns `true` if every lane is set.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn all(self) -> bool {
        // SAFETY: `self.0` is an integer vector whose lanes are all 0 or -1.
        unsafe { core::intrinsics::simd::simd_reduce_all(self.0) }
    }

    /// Packs the mask into a `u64` with one bit per lane, intended so that lane `i` maps
    /// to bit `i`. On big-endian targets the intrinsic's bit order is reversed through
    /// `FixEndianness`.
    ///
    /// Before conversion the mask is resized with `false` to 8, 16, 32 or 64 lanes.
    /// Assuming `resize` pads the new lanes with that `false`, the bits at positions `N`
    /// and above are zero.
    ///
    /// Compilation fails (post-monomorphization) if `N > 64`.
    #[inline]
    #[must_use = "method returns a new integer and does not mutate the original value"]
    pub fn to_bitmask(self) -> u64 {
        const {
            assert!(N <= 64, "number of elements can't be greater than 64");
        }

        // Resizes `mask` to `M` lanes, converts it with `simd_bitmask`, and corrects the bit
        // order for the target. Every caller below passes `M` equal to the bit width of `U`.
        #[inline]
        unsafe fn to_bitmask_impl<T, U: FixEndianness, const M: usize, const N: usize>(
            mask: Mask<T, N>,
        ) -> U
        where
            T: MaskElement,
        {
            let resized = mask.resize::<M>(false);

            // SAFETY: `resized.0` is an integer vector with `M` lanes and `U` has `M` bits.
            let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized.0) };

            bitmask.fix_endianness()
        }

        // Use the narrowest unsigned integer that holds `N` bits. The `as u64` casts
        // zero-extend the unsigned result.
        if N <= 8 {
            // SAFETY: `M` = 8 matches the width of `u8`.
            unsafe { to_bitmask_impl::<T, u8, 8, N>(self) as u64 }
        } else if N <= 16 {
            // SAFETY: `M` = 16 matches the width of `u16`.
            unsafe { to_bitmask_impl::<T, u16, 16, N>(self) as u64 }
        } else if N <= 32 {
            // SAFETY: `M` = 32 matches the width of `u32`.
            unsafe { to_bitmask_impl::<T, u32, 32, N>(self) as u64 }
        } else {
            // SAFETY: `M` = 64 matches the width of `u64`.
            unsafe { to_bitmask_impl::<T, u64, 64, N>(self) }
        }
    }

    /// Builds a mask from a bitmask, selecting `T::TRUE` or `T::FALSE` for each lane.
    ///
    /// NOTE(review): the bit-to-lane mapping and the handling of bits at positions `>= N`
    /// come from the `Select` impl for `u64` (defined elsewhere); presumably bit `i`
    /// controls lane `i` — confirm there.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn from_bitmask(bitmask: u64) -> Self {
        Self(bitmask.select(Simd::splat(T::TRUE), Simd::splat(T::FALSE)))
    }

    /// Returns the index of the lowest set lane, or `None` if no lane is set.
    #[inline]
    #[must_use = "method returns the index and does not mutate the original value"]
    pub fn first_set(self) -> Option<usize> {
        // Fast path when `sse` is enabled: the lowest set bit of the bitmask is the answer.
        if cfg!(target_feature = "sse") && N <= 64 {
            return self.to_bitmask().lowest_one().map(|i| i as usize);
        }

        // Vector path:
        // * `index` holds 0, 1, ..., N-1;
        // * OR-ing with `!self` makes every unset lane -1 and leaves set lanes as their index;
        // * reinterpreted as unsigned, -1 is the maximum value, so the unsigned minimum is
        //   the lowest set index, or all-ones when no lane is set.
        let index: Simd<T, N> = const {
            // Evaluated at compile time.
            // NOTE(review): assumes every index in 0..N is representable in `T`
            // (e.g. N <= 128 for i8); confirm against the supported lane-count limit.
            let mut index = [0; N];
            let mut i = 0;
            while i < N {
                index[i] = i;
                i += 1;
            }
            // SAFETY: integer-to-integer cast of a vector with `N` lanes.
            unsafe { core::intrinsics::simd::simd_cast(Simd::from_array(index)) }
        };

        // SAFETY: both operands are integer vectors of element type `T` with `N` lanes.
        let masked_index: Simd<T, N> =
            unsafe { core::intrinsics::simd::simd_or((!self).to_simd(), index) };

        // SAFETY: same-width signed-to-unsigned integer cast; -1 becomes the unsigned max.
        let masked_index: Simd<T::Unsigned, N> =
            unsafe { core::intrinsics::simd::simd_cast(masked_index) };

        // SAFETY: `masked_index` is an integer vector.
        let min_index: T::Unsigned =
            unsafe { core::intrinsics::simd::simd_reduce_min(masked_index) };

        // SAFETY: `T::Unsigned` is the same-width unsigned counterpart of `T`, so the bytes
        // can be read back as `T`.
        let min_index: T = unsafe { core::mem::transmute_copy(&min_index) };

        // An all-ones result (`T::TRUE`) means no lane was set.
        if min_index.eq(T::TRUE) {
            None
        } else {
            let min_index = min_index.to_usize();

            // SAFETY: the result is not the -1 sentinel, so it is one of the indices 0..N.
            unsafe {
                core::hint::assert_unchecked(min_index < N);
            }

            Some(min_index)
        }
    }

    /// Returns the index of the highest set lane, or `None` if no lane is set.
    #[inline]
    #[must_use = "method returns the index and does not mutate the original value"]
    pub fn last_set(self) -> Option<usize> {
        // Fast path when `sse` is enabled: the highest set bit of the bitmask is the answer.
        if cfg!(target_feature = "sse") && N <= 64 {
            return self.to_bitmask().highest_one().map(|i| i as usize);
        }

        // Vector path:
        // * `index` holds 0, 1, ..., N-1;
        // * OR-ing with `!self` makes every unset lane -1 and leaves set lanes as their index;
        // * under signed comparison -1 is below every index, so the signed maximum is the
        //   highest set index, or -1 when no lane is set.
        let index: Simd<T, N> = const {
            // Evaluated at compile time.
            // NOTE(review): assumes every index in 0..N is representable in `T`
            // (e.g. N <= 128 for i8); confirm against the supported lane-count limit.
            let mut index = [0; N];
            let mut i = 0;
            while i < N {
                index[i] = i;
                i += 1;
            }
            // SAFETY: integer-to-integer cast of a vector with `N` lanes.
            unsafe { core::intrinsics::simd::simd_cast(Simd::from_array(index)) }
        };

        // SAFETY: both operands are integer vectors of element type `T` with `N` lanes.
        let masked_index: Simd<T, N> =
            unsafe { core::intrinsics::simd::simd_or((!self).to_simd(), index) };

        // SAFETY: `masked_index` is an integer vector; `T` is signed, so the max is signed.
        let max_index: T = unsafe { core::intrinsics::simd::simd_reduce_max(masked_index) };

        // A result of -1 (`T::TRUE`) means no lane was set.
        if max_index.eq(T::TRUE) {
            None
        } else {
            let max_index = max_index.to_usize();

            // SAFETY: the result is not the -1 sentinel, so it is one of the indices 0..N.
            unsafe {
                core::hint::assert_unchecked(max_index < N);
            }

            Some(max_index)
        }
    }
}
477
478impl<T, const N: usize> From<[bool; N]> for Mask<T, N>
480where
481 T: MaskElement,
482{
483 #[inline]
484 fn from(array: [bool; N]) -> Self {
485 Self::from_array(array)
486 }
487}
488
489impl<T, const N: usize> From<Mask<T, N>> for [bool; N]
490where
491 T: MaskElement,
492{
493 #[inline]
494 fn from(vector: Mask<T, N>) -> Self {
495 vector.to_array()
496 }
497}
498
499impl<T, const N: usize> Default for Mask<T, N>
500where
501 T: MaskElement,
502{
503 #[inline]
504 fn default() -> Self {
505 Self::splat(false)
506 }
507}
508
509impl<T, const N: usize> PartialEq for Mask<T, N>
510where
511 T: MaskElement + PartialEq,
512{
513 #[inline]
514 fn eq(&self, other: &Self) -> bool {
515 self.0 == other.0
516 }
517}
518
519impl<T, const N: usize> PartialOrd for Mask<T, N>
520where
521 T: MaskElement + PartialOrd,
522{
523 #[inline]
524 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
525 self.0.partial_cmp(&other.0)
526 }
527}
528
529impl<T, const N: usize> fmt::Debug for Mask<T, N>
530where
531 T: MaskElement + fmt::Debug,
532{
533 #[inline]
534 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
535 f.debug_list()
536 .entries((0..N).map(|i| self.test(i)))
537 .finish()
538 }
539}
540
541impl<T, const N: usize> core::ops::BitAnd for Mask<T, N>
542where
543 T: MaskElement,
544{
545 type Output = Self;
546 #[inline]
547 fn bitand(self, rhs: Self) -> Self {
548 unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) }
550 }
551}
552
553impl<T, const N: usize> core::ops::BitAnd<bool> for Mask<T, N>
554where
555 T: MaskElement,
556{
557 type Output = Self;
558 #[inline]
559 fn bitand(self, rhs: bool) -> Self {
560 self & Self::splat(rhs)
561 }
562}
563
564impl<T, const N: usize> core::ops::BitAnd<Mask<T, N>> for bool
565where
566 T: MaskElement,
567{
568 type Output = Mask<T, N>;
569 #[inline]
570 fn bitand(self, rhs: Mask<T, N>) -> Mask<T, N> {
571 Mask::splat(self) & rhs
572 }
573}
574
575impl<T, const N: usize> core::ops::BitOr for Mask<T, N>
576where
577 T: MaskElement,
578{
579 type Output = Self;
580 #[inline]
581 fn bitor(self, rhs: Self) -> Self {
582 unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) }
584 }
585}
586
587impl<T, const N: usize> core::ops::BitOr<bool> for Mask<T, N>
588where
589 T: MaskElement,
590{
591 type Output = Self;
592 #[inline]
593 fn bitor(self, rhs: bool) -> Self {
594 self | Self::splat(rhs)
595 }
596}
597
598impl<T, const N: usize> core::ops::BitOr<Mask<T, N>> for bool
599where
600 T: MaskElement,
601{
602 type Output = Mask<T, N>;
603 #[inline]
604 fn bitor(self, rhs: Mask<T, N>) -> Mask<T, N> {
605 Mask::splat(self) | rhs
606 }
607}
608
609impl<T, const N: usize> core::ops::BitXor for Mask<T, N>
610where
611 T: MaskElement,
612{
613 type Output = Self;
614 #[inline]
615 fn bitxor(self, rhs: Self) -> Self::Output {
616 unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) }
618 }
619}
620
621impl<T, const N: usize> core::ops::BitXor<bool> for Mask<T, N>
622where
623 T: MaskElement,
624{
625 type Output = Self;
626 #[inline]
627 fn bitxor(self, rhs: bool) -> Self::Output {
628 self ^ Self::splat(rhs)
629 }
630}
631
632impl<T, const N: usize> core::ops::BitXor<Mask<T, N>> for bool
633where
634 T: MaskElement,
635{
636 type Output = Mask<T, N>;
637 #[inline]
638 fn bitxor(self, rhs: Mask<T, N>) -> Self::Output {
639 Mask::splat(self) ^ rhs
640 }
641}
642
643impl<T, const N: usize> core::ops::Not for Mask<T, N>
644where
645 T: MaskElement,
646{
647 type Output = Mask<T, N>;
648 #[inline]
649 fn not(self) -> Self::Output {
650 Self::splat(true) ^ self
651 }
652}
653
654impl<T, const N: usize> core::ops::BitAndAssign for Mask<T, N>
655where
656 T: MaskElement,
657{
658 #[inline]
659 fn bitand_assign(&mut self, rhs: Self) {
660 *self = *self & rhs;
661 }
662}
663
664impl<T, const N: usize> core::ops::BitAndAssign<bool> for Mask<T, N>
665where
666 T: MaskElement,
667{
668 #[inline]
669 fn bitand_assign(&mut self, rhs: bool) {
670 *self &= Self::splat(rhs);
671 }
672}
673
674impl<T, const N: usize> core::ops::BitOrAssign for Mask<T, N>
675where
676 T: MaskElement,
677{
678 #[inline]
679 fn bitor_assign(&mut self, rhs: Self) {
680 *self = *self | rhs;
681 }
682}
683
684impl<T, const N: usize> core::ops::BitOrAssign<bool> for Mask<T, N>
685where
686 T: MaskElement,
687{
688 #[inline]
689 fn bitor_assign(&mut self, rhs: bool) {
690 *self |= Self::splat(rhs);
691 }
692}
693
694impl<T, const N: usize> core::ops::BitXorAssign for Mask<T, N>
695where
696 T: MaskElement,
697{
698 #[inline]
699 fn bitxor_assign(&mut self, rhs: Self) {
700 *self = *self ^ rhs;
701 }
702}
703
704impl<T, const N: usize> core::ops::BitXorAssign<bool> for Mask<T, N>
705where
706 T: MaskElement,
707{
708 #[inline]
709 fn bitxor_assign(&mut self, rhs: bool) {
710 *self ^= Self::splat(rhs);
711 }
712}
713
714macro_rules! impl_from {
715 { $from:ty => $($to:ty),* } => {
716 $(
717 impl<const N: usize> From<Mask<$from, N>> for Mask<$to, N>
718 {
719 #[inline]
720 fn from(value: Mask<$from, N>) -> Self {
721 value.cast()
722 }
723 }
724 )*
725 }
726}
727impl_from! { i8 => i16, i32, i64, isize }
728impl_from! { i16 => i32, i64, isize, i8 }
729impl_from! { i32 => i64, isize, i8, i16 }
730impl_from! { i64 => isize, i8, i16, i32 }
731impl_from! { isize => i8, i16, i32, i64 }