1#![allow(non_camel_case_types)]
4
5use crate::simd::{Select, Simd, SimdCast, SimdElement};
6use core::cmp::Ordering;
7use core::{fmt, mem};
8
/// Normalizes the bit order of an integer bitmask produced by `simd_bitmask`.
///
/// On little-endian targets this is the identity; on big-endian targets every bit of the
/// integer is reversed.
pub(crate) trait FixEndianness {
    /// Returns `self` unchanged on little-endian targets and bit-reversed otherwise.
    fn fix_endianness(self) -> Self;
}

/// Implements [`FixEndianness`] for each listed unsigned integer type.
macro_rules! impl_fix_endianness {
    { $($uint:ty),* } => {
        $(
            impl FixEndianness for $uint {
                #[inline(always)]
                fn fix_endianness(self) -> Self {
                    // `cfg!` expands to a `true`/`false` literal, so both arms must
                    // type-check on every target.
                    if cfg!(target_endian = "little") {
                        self
                    } else {
                        self.reverse_bits()
                    }
                }
            }
        )*
    }
}

impl_fix_endianness! { u8, u16, u32, u64 }
31
32mod sealed {
33 use super::*;
34
35 pub trait Sealed {
42 fn valid<const N: usize>(values: Simd<Self, N>) -> bool
43 where
44 Self: SimdElement;
45
46 fn eq(self, other: Self) -> bool;
47
48 fn to_usize(self) -> usize;
49 fn max_unsigned() -> u64;
50
51 type Unsigned: SimdElement;
52
53 const TRUE: Self;
54
55 const FALSE: Self;
56 }
57}
58use sealed::Sealed;
59
60pub unsafe trait MaskElement: SimdElement<Mask = Self> + SimdCast + Sealed {}
65
66macro_rules! impl_element {
67 { $ty:ty, $unsigned:ty } => {
68 impl Sealed for $ty {
69 #[inline]
70 fn valid<const N: usize>(value: Simd<Self, N>) -> bool
71 {
72 unsafe {
76 use core::intrinsics::simd;
77 let falses: Simd<Self, N> = simd::simd_eq(value, Simd::splat(0 as _));
78 let trues: Simd<Self, N> = simd::simd_eq(value, Simd::splat(-1 as _));
79 let valid: Simd<Self, N> = simd::simd_or(falses, trues);
80 simd::simd_reduce_all(valid)
81 }
82 }
83
84 #[inline]
85 fn eq(self, other: Self) -> bool { self == other }
86
87 #[inline]
88 fn to_usize(self) -> usize {
89 self as usize
90 }
91
92 #[inline]
93 fn max_unsigned() -> u64 {
94 <$unsigned>::MAX as u64
95 }
96
97 type Unsigned = $unsigned;
98
99 const TRUE: Self = -1;
100 const FALSE: Self = 0;
101 }
102
103 unsafe impl MaskElement for $ty {}
105 }
106}
107
108impl_element! { i8, u8 }
109impl_element! { i16, u16 }
110impl_element! { i32, u32 }
111impl_element! { i64, u64 }
112impl_element! { isize, usize }
113
114#[repr(transparent)]
125pub struct Mask<T, const N: usize>(Simd<T, N>)
126where
127 T: MaskElement;
128
129impl<T, const N: usize> Copy for Mask<T, N> where T: MaskElement {}
130
131impl<T, const N: usize> Clone for Mask<T, N>
132where
133 T: MaskElement,
134{
135 #[inline]
136 fn clone(&self) -> Self {
137 *self
138 }
139}
140
/// Inherent constructors, conversions, element access and reductions for [`Mask`].
///
/// All methods rely on the invariant that every lane holds `T::TRUE` (-1) or
/// `T::FALSE` (0).
impl<T, const N: usize> Mask<T, N>
where
    T: MaskElement,
{
    /// Constructs a mask with every element set to `value`.
    ///
    /// Callable in `const` contexts (unstable, feature `portable_simd`).
    #[inline]
    #[rustc_const_unstable(feature = "portable_simd", issue = "86656")]
    pub const fn splat(value: bool) -> Self {
        Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
    }

    /// Converts an array of `bool`s into a mask; element `i` is set iff `array[i]` is `true`.
    #[inline]
    pub fn from_array(array: [bool; N]) -> Self {
        // SAFETY: a `bool` is one byte holding 0 (false) or 1 (true), so `[bool; N]` may be
        // read as `[u8; N]`. `transmute_copy` is used because the size depends on the
        // generic `N`, which `transmute` cannot check. `simd_ne(bytes, 0)` produces 0 / -1
        // lanes, which are exactly the valid mask values. The lane-wise integer `simd_cast`
        // to `T` keeps 0 as 0 and -1 as -1, satisfying `from_simd_unchecked`'s precondition.
        unsafe {
            let bytes: [u8; N] = mem::transmute_copy(&array);
            let bools: Simd<i8, N> =
                core::intrinsics::simd::simd_ne(Simd::from_array(bytes), Simd::splat(0u8));
            Mask::from_simd_unchecked(core::intrinsics::simd::simd_cast(bools))
        }
    }

    /// Converts the mask into an array of `bool`s; `result[i]` is `true` iff element `i` is set.
    #[inline]
    pub fn to_array(self) -> [bool; N] {
        // SAFETY: each lane is 0 or -1, and casting to `i8` keeps it 0 or -1. ANDing with 1
        // then yields 0 or 1, the byte values of `false` and `true`. `transmute_copy` then
        // reads those `N` bytes as `[bool; N]`; the size is `N`-dependent, so `transmute`
        // cannot be used.
        unsafe {
            let mut bytes: Simd<i8, N> = core::intrinsics::simd::simd_cast(self.to_simd());
            bytes &= Simd::splat(1i8);
            mem::transmute_copy(&bytes)
        }
    }

    /// Wraps an integer vector as a mask without checking its lanes.
    ///
    /// # Safety
    ///
    /// Every lane of `value` must be `T::TRUE` or `T::FALSE`; for the built-in element
    /// types that means -1 or 0. This fact is handed to the optimizer through `assume`, so
    /// passing any other value is undefined behavior.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub unsafe fn from_simd_unchecked(value: Simd<T, N>) -> Self {
        // SAFETY: the caller guarantees every lane is valid, so the assumed condition holds.
        unsafe {
            core::intrinsics::assume(<T as Sealed>::valid(value));
        }
        Self(value)
    }

    /// Wraps an integer vector as a mask, checking that every lane is 0 or -1.
    ///
    /// # Panics
    ///
    /// Panics if any lane of `value` is neither 0 nor -1. Because of `#[track_caller]`,
    /// the panic is reported at the caller's location.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    #[track_caller]
    pub fn from_simd(value: Simd<T, N>) -> Self {
        assert!(T::valid(value), "all values must be either 0 or -1",);
        // SAFETY: lane validity was just asserted.
        unsafe { Self::from_simd_unchecked(value) }
    }

    /// Returns the underlying integer vector; each lane is `T::TRUE` or `T::FALSE`.
    #[inline]
    #[must_use = "method returns a new vector and does not mutate the original value"]
    pub fn to_simd(self) -> Simd<T, N> {
        self.0
    }

    /// Converts the mask to one backed by a different lane integer type `U`.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn cast<U: MaskElement>(self) -> Mask<U, N> {
        // SAFETY: both sides are integer vectors of `N` lanes. `simd_as` converts lane-wise,
        // and between the signed integer types implementing `MaskElement` it maps 0 to 0 and
        // -1 to -1, so the result still satisfies the mask invariant.
        unsafe { Mask(core::intrinsics::simd::simd_as(self.0)) }
    }

    /// Returns whether the element at `index` is set, without bounds checking.
    ///
    /// # Safety
    ///
    /// `index` must be less than `N`.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub unsafe fn test_unchecked(&self, index: usize) -> bool {
        // SAFETY: the caller guarantees `index < N`, so `get_unchecked` stays in bounds.
        unsafe { T::eq(*self.0.as_array().get_unchecked(index), T::TRUE) }
    }

    /// Returns whether the element at `index` is set.
    ///
    /// # Panics
    ///
    /// Presumably panics when `index >= N`, via `Simd` indexing; confirm against `Simd`'s
    /// `Index` impl. `#[track_caller]` attributes any panic to the caller.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    #[track_caller]
    pub fn test(&self, index: usize) -> bool {
        T::eq(self.0[index], T::TRUE)
    }

    /// Sets or clears the element at `index`, without bounds checking.
    ///
    /// # Safety
    ///
    /// `index` must be less than `N`.
    #[inline]
    pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) {
        // SAFETY: the caller guarantees `index < N`. The stored value is always TRUE or
        // FALSE, which preserves the mask invariant.
        unsafe {
            *self.0.as_mut_array().get_unchecked_mut(index) = if value { T::TRUE } else { T::FALSE }
        }
    }

    /// Sets or clears the element at `index`.
    ///
    /// # Panics
    ///
    /// Presumably panics when `index >= N`, via `Simd` indexing; confirm.
    /// `#[track_caller]` attributes any panic to the caller.
    #[inline]
    #[track_caller]
    pub fn set(&mut self, index: usize, value: bool) {
        self.0[index] = if value { T::TRUE } else { T::FALSE }
    }

    /// Returns `true` if at least one element is set.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn any(self) -> bool {
        // SAFETY: the reduction operates on an integer vector whose lanes are all 0 or -1
        // (the mask invariant).
        unsafe { core::intrinsics::simd::simd_reduce_any(self.0) }
    }

    /// Returns `true` if every element is set.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn all(self) -> bool {
        // SAFETY: the reduction operates on an integer vector whose lanes are all 0 or -1
        // (the mask invariant).
        unsafe { core::intrinsics::simd::simd_reduce_all(self.0) }
    }

    /// Packs the mask into an integer bitmask, returned as `u64`.
    ///
    /// The mask is packed into the narrowest of `u8`/`u16`/`u32`/`u64` that holds `N` bits
    /// and then widened. Bit order is normalized with `fix_endianness`. NOTE(review): this
    /// presumably makes bit `i` correspond to element `i` on every target — confirm against
    /// the `simd_bitmask` docs. A const assertion rejects `N > 64` at compile time.
    #[inline]
    #[must_use = "method returns a new integer and does not mutate the original value"]
    pub fn to_bitmask(self) -> u64 {
        const {
            assert!(N <= 64, "number of elements can't be greater than 64");
        }

        /// Pads `mask` to `M` lanes and packs it into the integer `U` with `simd_bitmask`.
        ///
        /// # Safety
        ///
        /// NOTE(review): presumably `U` must have exactly `M` bits, as in the (u8, 8),
        /// (u16, 16), (u32, 32), (u64, 64) pairs used below. Confirm against the
        /// `simd_bitmask` intrinsic's requirements.
        #[inline]
        unsafe fn to_bitmask_impl<T, U: FixEndianness, const M: usize, const N: usize>(
            mask: Mask<T, N>,
        ) -> U
        where
            T: MaskElement,
        {
            // `false` is the fill value for any lanes beyond `N`. NOTE(review): assumes
            // `resize` (defined elsewhere) keeps the first `N` lanes and fills the rest with
            // the given value, so the padding bits come out as zero.
            let resized = mask.resize::<M>(false);

            // SAFETY: `resized` is an integer vector of `M` lanes, each 0 or -1. The caller
            // pairs `M` with an integer `U` of `M` bits.
            let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized.0) };

            // Bit-reversed on big-endian targets, unchanged on little-endian.
            bitmask.fix_endianness()
        }

        // `N` is a constant, so exactly one branch is live for each instantiation. Each
        // branch picks the narrowest integer that holds `N` bits.
        if N <= 8 {
            // SAFETY: `u8` has 8 bits, matching `M = 8`.
            unsafe { to_bitmask_impl::<T, u8, 8, N>(self) as u64 }
        } else if N <= 16 {
            // SAFETY: `u16` has 16 bits, matching `M = 16`.
            unsafe { to_bitmask_impl::<T, u16, 16, N>(self) as u64 }
        } else if N <= 32 {
            // SAFETY: `u32` has 32 bits, matching `M = 32`.
            unsafe { to_bitmask_impl::<T, u32, 32, N>(self) as u64 }
        } else {
            // SAFETY: `u64` has 64 bits, matching `M = 64`. `N <= 64` is asserted above.
            unsafe { to_bitmask_impl::<T, u64, 64, N>(self) }
        }
    }

    /// Builds a mask from an integer bitmask by choosing `T::TRUE` or `T::FALSE` per lane.
    ///
    /// NOTE(review): the bit-to-lane mapping comes from the `Select` impl for `u64`, which is
    /// defined elsewhere. Presumably bit `i` controls element `i` and bits at or above `N`
    /// are ignored — confirm.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn from_bitmask(bitmask: u64) -> Self {
        Self(bitmask.select(Simd::splat(T::TRUE), Simd::splat(T::FALSE)))
    }

    /// Returns the index of the first set element, or `None` if no element is set.
    #[inline]
    #[must_use = "method returns the index and does not mutate the original value"]
    pub fn first_set(self) -> Option<usize> {
        // Fast path when SSE is enabled: count trailing zeros of the bitmask. For a `u64`,
        // `trailing_zeros` is 64 only when the value is zero, which means no element is set.
        if cfg!(target_feature = "sse") && N <= 64 {
            let tz = self.to_bitmask().trailing_zeros();
            return if tz == 64 { None } else { Some(tz as usize) };
        }

        // Portable path:
        // * build the index vector [0, 1, ..., N-1] at compile time;
        // * replace lanes of unset elements with `T::TRUE` (-1);
        // * reinterpret as unsigned, so -1 becomes the maximum value and never wins a min
        //   against a real index;
        // * reduce-min, giving either the first set index or -1 (nothing set).
        let index = Simd::from_array(
            // `const` block: the `[0, 1, ..., N-1]` array is computed at compile time.
            const {
                let mut index = [0; N];
                let mut i = 0;
                while i < N {
                    index[i] = i;
                    i += 1;
                }
                index
            },
        );

        // SAFETY: the input and output are integer vectors with `N` lanes.
        let index: Simd<T, N> = unsafe { core::intrinsics::simd::simd_cast(index) };

        // Set elements keep their index; unset elements become `T::TRUE` (-1). This relies
        // on `select` (defined elsewhere) taking the first argument where the mask is set.
        let masked_index = self.select(index, Self::splat(true).to_simd());

        // SAFETY: the input and output are integer vectors with `N` lanes.
        let masked_index: Simd<T::Unsigned, N> =
            unsafe { core::intrinsics::simd::simd_cast(masked_index) };

        // SAFETY: the input is an integer vector.
        let min_index: T::Unsigned =
            unsafe { core::intrinsics::simd::simd_reduce_min(masked_index) };

        // SAFETY: `T::Unsigned` is the same-width unsigned counterpart of `T`, so the bytes
        // reinterpret directly.
        let min_index: T = unsafe { core::mem::transmute_copy(&min_index) };

        // -1 (`T::TRUE`) survives the min only when every lane was unset.
        if min_index.eq(T::TRUE) {
            None
        } else {
            Some(min_index.to_usize())
        }
    }
}
409
410impl<T, const N: usize> From<[bool; N]> for Mask<T, N>
412where
413 T: MaskElement,
414{
415 #[inline]
416 fn from(array: [bool; N]) -> Self {
417 Self::from_array(array)
418 }
419}
420
421impl<T, const N: usize> From<Mask<T, N>> for [bool; N]
422where
423 T: MaskElement,
424{
425 #[inline]
426 fn from(vector: Mask<T, N>) -> Self {
427 vector.to_array()
428 }
429}
430
431impl<T, const N: usize> Default for Mask<T, N>
432where
433 T: MaskElement,
434{
435 #[inline]
436 fn default() -> Self {
437 Self::splat(false)
438 }
439}
440
441impl<T, const N: usize> PartialEq for Mask<T, N>
442where
443 T: MaskElement + PartialEq,
444{
445 #[inline]
446 fn eq(&self, other: &Self) -> bool {
447 self.0 == other.0
448 }
449}
450
451impl<T, const N: usize> PartialOrd for Mask<T, N>
452where
453 T: MaskElement + PartialOrd,
454{
455 #[inline]
456 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
457 self.0.partial_cmp(&other.0)
458 }
459}
460
461impl<T, const N: usize> fmt::Debug for Mask<T, N>
462where
463 T: MaskElement + fmt::Debug,
464{
465 #[inline]
466 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
467 f.debug_list()
468 .entries((0..N).map(|i| self.test(i)))
469 .finish()
470 }
471}
472
473impl<T, const N: usize> core::ops::BitAnd for Mask<T, N>
474where
475 T: MaskElement,
476{
477 type Output = Self;
478 #[inline]
479 fn bitand(self, rhs: Self) -> Self {
480 unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) }
482 }
483}
484
485impl<T, const N: usize> core::ops::BitAnd<bool> for Mask<T, N>
486where
487 T: MaskElement,
488{
489 type Output = Self;
490 #[inline]
491 fn bitand(self, rhs: bool) -> Self {
492 self & Self::splat(rhs)
493 }
494}
495
496impl<T, const N: usize> core::ops::BitAnd<Mask<T, N>> for bool
497where
498 T: MaskElement,
499{
500 type Output = Mask<T, N>;
501 #[inline]
502 fn bitand(self, rhs: Mask<T, N>) -> Mask<T, N> {
503 Mask::splat(self) & rhs
504 }
505}
506
507impl<T, const N: usize> core::ops::BitOr for Mask<T, N>
508where
509 T: MaskElement,
510{
511 type Output = Self;
512 #[inline]
513 fn bitor(self, rhs: Self) -> Self {
514 unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) }
516 }
517}
518
519impl<T, const N: usize> core::ops::BitOr<bool> for Mask<T, N>
520where
521 T: MaskElement,
522{
523 type Output = Self;
524 #[inline]
525 fn bitor(self, rhs: bool) -> Self {
526 self | Self::splat(rhs)
527 }
528}
529
530impl<T, const N: usize> core::ops::BitOr<Mask<T, N>> for bool
531where
532 T: MaskElement,
533{
534 type Output = Mask<T, N>;
535 #[inline]
536 fn bitor(self, rhs: Mask<T, N>) -> Mask<T, N> {
537 Mask::splat(self) | rhs
538 }
539}
540
541impl<T, const N: usize> core::ops::BitXor for Mask<T, N>
542where
543 T: MaskElement,
544{
545 type Output = Self;
546 #[inline]
547 fn bitxor(self, rhs: Self) -> Self::Output {
548 unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) }
550 }
551}
552
553impl<T, const N: usize> core::ops::BitXor<bool> for Mask<T, N>
554where
555 T: MaskElement,
556{
557 type Output = Self;
558 #[inline]
559 fn bitxor(self, rhs: bool) -> Self::Output {
560 self ^ Self::splat(rhs)
561 }
562}
563
564impl<T, const N: usize> core::ops::BitXor<Mask<T, N>> for bool
565where
566 T: MaskElement,
567{
568 type Output = Mask<T, N>;
569 #[inline]
570 fn bitxor(self, rhs: Mask<T, N>) -> Self::Output {
571 Mask::splat(self) ^ rhs
572 }
573}
574
575impl<T, const N: usize> core::ops::Not for Mask<T, N>
576where
577 T: MaskElement,
578{
579 type Output = Mask<T, N>;
580 #[inline]
581 fn not(self) -> Self::Output {
582 Self::splat(true) ^ self
583 }
584}
585
586impl<T, const N: usize> core::ops::BitAndAssign for Mask<T, N>
587where
588 T: MaskElement,
589{
590 #[inline]
591 fn bitand_assign(&mut self, rhs: Self) {
592 *self = *self & rhs;
593 }
594}
595
596impl<T, const N: usize> core::ops::BitAndAssign<bool> for Mask<T, N>
597where
598 T: MaskElement,
599{
600 #[inline]
601 fn bitand_assign(&mut self, rhs: bool) {
602 *self &= Self::splat(rhs);
603 }
604}
605
606impl<T, const N: usize> core::ops::BitOrAssign for Mask<T, N>
607where
608 T: MaskElement,
609{
610 #[inline]
611 fn bitor_assign(&mut self, rhs: Self) {
612 *self = *self | rhs;
613 }
614}
615
616impl<T, const N: usize> core::ops::BitOrAssign<bool> for Mask<T, N>
617where
618 T: MaskElement,
619{
620 #[inline]
621 fn bitor_assign(&mut self, rhs: bool) {
622 *self |= Self::splat(rhs);
623 }
624}
625
626impl<T, const N: usize> core::ops::BitXorAssign for Mask<T, N>
627where
628 T: MaskElement,
629{
630 #[inline]
631 fn bitxor_assign(&mut self, rhs: Self) {
632 *self = *self ^ rhs;
633 }
634}
635
636impl<T, const N: usize> core::ops::BitXorAssign<bool> for Mask<T, N>
637where
638 T: MaskElement,
639{
640 #[inline]
641 fn bitxor_assign(&mut self, rhs: bool) {
642 *self ^= Self::splat(rhs);
643 }
644}
645
646macro_rules! impl_from {
647 { $from:ty => $($to:ty),* } => {
648 $(
649 impl<const N: usize> From<Mask<$from, N>> for Mask<$to, N>
650 {
651 #[inline]
652 fn from(value: Mask<$from, N>) -> Self {
653 value.cast()
654 }
655 }
656 )*
657 }
658}
659impl_from! { i8 => i16, i32, i64, isize }
660impl_from! { i16 => i32, i64, isize, i8 }
661impl_from! { i32 => i64, isize, i8, i16 }
662impl_from! { i64 => isize, i8, i16, i32 }
663impl_from! { isize => i8, i16, i32, i64 }