1#![allow(non_camel_case_types)]
4
5use crate::simd::{Select, Simd, SimdCast, SimdElement};
6use core::cmp::Ordering;
7use core::{fmt, mem};
8
/// Adjusts the bit order of an integer bitmask according to the target's endianness.
///
/// On little-endian targets the value passes through untouched; on big-endian targets
/// its bits are reversed.
pub(crate) trait FixEndianness {
    /// Returns `self` with its bits reversed when compiling for a big-endian target,
    /// and `self` unchanged otherwise.
    fn fix_endianness(self) -> Self;
}

// Implements `FixEndianness` for every integer type in the comma-separated list.
macro_rules! impl_fix_endianness {
    { $($int:ty),* } => {
        $(
        impl FixEndianness for $int {
            #[inline(always)]
            fn fix_endianness(self) -> Self {
                // `cfg!` expands to a boolean literal, so the choice is made at compile time.
                match cfg!(target_endian = "big") {
                    true => <$int>::reverse_bits(self),
                    false => self,
                }
            }
        }
        )*
    }
}

impl_fix_endianness! { u8, u16, u32, u64 }
31
32mod sealed {
33 use super::*;
34
35 pub trait Sealed {
42 fn valid<const N: usize>(values: Simd<Self, N>) -> bool
43 where
44 Self: SimdElement;
45
46 fn eq(self, other: Self) -> bool;
47
48 fn to_usize(self) -> usize;
49 fn max_unsigned() -> u64;
50
51 type Unsigned: SimdElement;
52
53 const TRUE: Self;
54
55 const FALSE: Self;
56 }
57}
58use sealed::Sealed;
59
60pub unsafe trait MaskElement: SimdElement<Mask = Self> + SimdCast + Sealed {}
65
66macro_rules! impl_element {
67 { $ty:ty, $unsigned:ty } => {
68 impl Sealed for $ty {
69 #[inline]
70 fn valid<const N: usize>(value: Simd<Self, N>) -> bool
71 {
72 unsafe {
76 use core::intrinsics::simd;
77 let falses: Simd<Self, N> = simd::simd_eq(value, Simd::splat(0 as _));
78 let trues: Simd<Self, N> = simd::simd_eq(value, Simd::splat(-1 as _));
79 let valid: Simd<Self, N> = simd::simd_or(falses, trues);
80 simd::simd_reduce_all(valid)
81 }
82 }
83
84 #[inline]
85 fn eq(self, other: Self) -> bool { self == other }
86
87 #[inline]
88 fn to_usize(self) -> usize {
89 self as usize
90 }
91
92 #[inline]
93 fn max_unsigned() -> u64 {
94 <$unsigned>::MAX as u64
95 }
96
97 type Unsigned = $unsigned;
98
99 const TRUE: Self = -1;
100 const FALSE: Self = 0;
101 }
102
103 unsafe impl MaskElement for $ty {}
105 }
106}
107
108impl_element! { i8, u8 }
109impl_element! { i16, u16 }
110impl_element! { i32, u32 }
111impl_element! { i64, u64 }
112impl_element! { isize, usize }
113
114#[repr(transparent)]
125pub struct Mask<T, const N: usize>(Simd<T, N>)
126where
127 T: MaskElement;
128
129impl<T, const N: usize> Copy for Mask<T, N> where T: MaskElement {}
130
131impl<T, const N: usize> Clone for Mask<T, N>
132where
133 T: MaskElement,
134{
135 #[inline]
136 fn clone(&self) -> Self {
137 *self
138 }
139}
140
141impl<T, const N: usize> Mask<T, N>
142where
143 T: MaskElement,
144{
145 #[inline]
147 #[rustc_const_unstable(feature = "portable_simd", issue = "86656")]
148 pub const fn splat(value: bool) -> Self {
149 Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
150 }
151
152 #[inline]
154 pub fn from_array(array: [bool; N]) -> Self {
155 unsafe {
162 let bytes: [u8; N] = mem::transmute_copy(&array);
163 let bools: Simd<i8, N> =
164 core::intrinsics::simd::simd_ne(Simd::from_array(bytes), Simd::splat(0u8));
165 Mask::from_simd_unchecked(core::intrinsics::simd::simd_cast(bools))
166 }
167 }
168
169 #[inline]
171 pub fn to_array(self) -> [bool; N] {
172 unsafe {
183 let mut bytes: Simd<i8, N> = core::intrinsics::simd::simd_cast(self.to_simd());
184 bytes &= Simd::splat(1i8);
185 mem::transmute_copy(&bytes)
186 }
187 }
188
189 #[inline]
195 #[must_use = "method returns a new mask and does not mutate the original value"]
196 pub unsafe fn from_simd_unchecked(value: Simd<T, N>) -> Self {
197 unsafe {
199 core::intrinsics::assume(<T as Sealed>::valid(value));
200 }
201 Self(value)
202 }
203
204 #[inline]
210 #[must_use = "method returns a new mask and does not mutate the original value"]
211 #[track_caller]
212 pub fn from_simd(value: Simd<T, N>) -> Self {
213 assert!(T::valid(value), "all values must be either 0 or -1",);
214 unsafe { Self::from_simd_unchecked(value) }
216 }
217
218 #[inline]
221 #[must_use = "method returns a new vector and does not mutate the original value"]
222 pub fn to_simd(self) -> Simd<T, N> {
223 self.0
224 }
225
226 #[inline]
228 #[must_use = "method returns a new mask and does not mutate the original value"]
229 pub fn cast<U: MaskElement>(self) -> Mask<U, N> {
230 unsafe { Mask(core::intrinsics::simd::simd_as(self.0)) }
232 }
233
234 #[inline]
239 #[must_use = "method returns a new bool and does not mutate the original value"]
240 pub unsafe fn test_unchecked(&self, index: usize) -> bool {
241 unsafe { T::eq(*self.0.as_array().get_unchecked(index), T::TRUE) }
243 }
244
245 #[inline]
250 #[must_use = "method returns a new bool and does not mutate the original value"]
251 #[track_caller]
252 pub fn test(&self, index: usize) -> bool {
253 T::eq(self.0[index], T::TRUE)
254 }
255
256 #[inline]
261 pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) {
262 unsafe {
264 *self.0.as_mut_array().get_unchecked_mut(index) = if value { T::TRUE } else { T::FALSE }
265 }
266 }
267
268 #[inline]
273 #[track_caller]
274 pub fn set(&mut self, index: usize, value: bool) {
275 self.0[index] = if value { T::TRUE } else { T::FALSE }
276 }
277
278 #[inline]
280 #[must_use = "method returns a new bool and does not mutate the original value"]
281 pub fn any(self) -> bool {
282 unsafe { core::intrinsics::simd::simd_reduce_any(self.0) }
284 }
285
286 #[inline]
288 #[must_use = "method returns a new bool and does not mutate the original value"]
289 pub fn all(self) -> bool {
290 unsafe { core::intrinsics::simd::simd_reduce_all(self.0) }
292 }
293
294 #[inline]
298 #[must_use = "method returns a new integer and does not mutate the original value"]
299 pub fn to_bitmask(self) -> u64 {
300 const {
301 assert!(N <= 64, "number of elements can't be greater than 64");
302 }
303
304 #[inline]
305 unsafe fn to_bitmask_impl<T, U: FixEndianness, const M: usize, const N: usize>(
306 mask: Mask<T, N>,
307 ) -> U
308 where
309 T: MaskElement,
310 {
311 let resized = mask.resize::<M>(false);
312
313 let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized.0) };
315
316 bitmask.fix_endianness()
318 }
319
320 if N <= 8 {
322 unsafe { to_bitmask_impl::<T, u8, 8, N>(self) as u64 }
324 } else if N <= 16 {
325 unsafe { to_bitmask_impl::<T, u16, 16, N>(self) as u64 }
327 } else if N <= 32 {
328 unsafe { to_bitmask_impl::<T, u32, 32, N>(self) as u64 }
330 } else {
331 unsafe { to_bitmask_impl::<T, u64, 64, N>(self) }
333 }
334 }
335
336 #[inline]
341 #[must_use = "method returns a new mask and does not mutate the original value"]
342 pub fn from_bitmask(bitmask: u64) -> Self {
343 Self(bitmask.select(Simd::splat(T::TRUE), Simd::splat(T::FALSE)))
344 }
345
346 #[inline]
360 #[must_use = "method returns the index and does not mutate the original value"]
361 pub fn first_set(self) -> Option<usize> {
362 if cfg!(target_feature = "sse") && N <= 64 {
364 let tz = self.to_bitmask().trailing_zeros();
365 return if tz == 64 { None } else { Some(tz as usize) };
366 }
367
368 let index: Simd<T, N> = const {
375 let mut index = [0; N];
376 let mut i = 0;
377 while i < N {
378 index[i] = i;
379 i += 1;
380 }
381 unsafe { core::intrinsics::simd::simd_cast(Simd::from_array(index)) }
383 };
384
385 let masked_index: Simd<T, N> =
387 unsafe { core::intrinsics::simd::simd_or((!self).to_simd(), index) };
388
389 let masked_index: Simd<T::Unsigned, N> =
391 unsafe { core::intrinsics::simd::simd_cast(masked_index) };
392
393 let min_index: T::Unsigned =
395 unsafe { core::intrinsics::simd::simd_reduce_min(masked_index) };
396
397 let min_index: T = unsafe { core::mem::transmute_copy(&min_index) };
399
400 if min_index.eq(T::TRUE) {
401 None
402 } else {
403 Some(min_index.to_usize())
404 }
405 }
406}
407
408impl<T, const N: usize> From<[bool; N]> for Mask<T, N>
410where
411 T: MaskElement,
412{
413 #[inline]
414 fn from(array: [bool; N]) -> Self {
415 Self::from_array(array)
416 }
417}
418
419impl<T, const N: usize> From<Mask<T, N>> for [bool; N]
420where
421 T: MaskElement,
422{
423 #[inline]
424 fn from(vector: Mask<T, N>) -> Self {
425 vector.to_array()
426 }
427}
428
429impl<T, const N: usize> Default for Mask<T, N>
430where
431 T: MaskElement,
432{
433 #[inline]
434 fn default() -> Self {
435 Self::splat(false)
436 }
437}
438
439impl<T, const N: usize> PartialEq for Mask<T, N>
440where
441 T: MaskElement + PartialEq,
442{
443 #[inline]
444 fn eq(&self, other: &Self) -> bool {
445 self.0 == other.0
446 }
447}
448
449impl<T, const N: usize> PartialOrd for Mask<T, N>
450where
451 T: MaskElement + PartialOrd,
452{
453 #[inline]
454 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
455 self.0.partial_cmp(&other.0)
456 }
457}
458
459impl<T, const N: usize> fmt::Debug for Mask<T, N>
460where
461 T: MaskElement + fmt::Debug,
462{
463 #[inline]
464 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
465 f.debug_list()
466 .entries((0..N).map(|i| self.test(i)))
467 .finish()
468 }
469}
470
471impl<T, const N: usize> core::ops::BitAnd for Mask<T, N>
472where
473 T: MaskElement,
474{
475 type Output = Self;
476 #[inline]
477 fn bitand(self, rhs: Self) -> Self {
478 unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) }
480 }
481}
482
483impl<T, const N: usize> core::ops::BitAnd<bool> for Mask<T, N>
484where
485 T: MaskElement,
486{
487 type Output = Self;
488 #[inline]
489 fn bitand(self, rhs: bool) -> Self {
490 self & Self::splat(rhs)
491 }
492}
493
494impl<T, const N: usize> core::ops::BitAnd<Mask<T, N>> for bool
495where
496 T: MaskElement,
497{
498 type Output = Mask<T, N>;
499 #[inline]
500 fn bitand(self, rhs: Mask<T, N>) -> Mask<T, N> {
501 Mask::splat(self) & rhs
502 }
503}
504
505impl<T, const N: usize> core::ops::BitOr for Mask<T, N>
506where
507 T: MaskElement,
508{
509 type Output = Self;
510 #[inline]
511 fn bitor(self, rhs: Self) -> Self {
512 unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) }
514 }
515}
516
517impl<T, const N: usize> core::ops::BitOr<bool> for Mask<T, N>
518where
519 T: MaskElement,
520{
521 type Output = Self;
522 #[inline]
523 fn bitor(self, rhs: bool) -> Self {
524 self | Self::splat(rhs)
525 }
526}
527
528impl<T, const N: usize> core::ops::BitOr<Mask<T, N>> for bool
529where
530 T: MaskElement,
531{
532 type Output = Mask<T, N>;
533 #[inline]
534 fn bitor(self, rhs: Mask<T, N>) -> Mask<T, N> {
535 Mask::splat(self) | rhs
536 }
537}
538
539impl<T, const N: usize> core::ops::BitXor for Mask<T, N>
540where
541 T: MaskElement,
542{
543 type Output = Self;
544 #[inline]
545 fn bitxor(self, rhs: Self) -> Self::Output {
546 unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) }
548 }
549}
550
551impl<T, const N: usize> core::ops::BitXor<bool> for Mask<T, N>
552where
553 T: MaskElement,
554{
555 type Output = Self;
556 #[inline]
557 fn bitxor(self, rhs: bool) -> Self::Output {
558 self ^ Self::splat(rhs)
559 }
560}
561
562impl<T, const N: usize> core::ops::BitXor<Mask<T, N>> for bool
563where
564 T: MaskElement,
565{
566 type Output = Mask<T, N>;
567 #[inline]
568 fn bitxor(self, rhs: Mask<T, N>) -> Self::Output {
569 Mask::splat(self) ^ rhs
570 }
571}
572
573impl<T, const N: usize> core::ops::Not for Mask<T, N>
574where
575 T: MaskElement,
576{
577 type Output = Mask<T, N>;
578 #[inline]
579 fn not(self) -> Self::Output {
580 Self::splat(true) ^ self
581 }
582}
583
584impl<T, const N: usize> core::ops::BitAndAssign for Mask<T, N>
585where
586 T: MaskElement,
587{
588 #[inline]
589 fn bitand_assign(&mut self, rhs: Self) {
590 *self = *self & rhs;
591 }
592}
593
594impl<T, const N: usize> core::ops::BitAndAssign<bool> for Mask<T, N>
595where
596 T: MaskElement,
597{
598 #[inline]
599 fn bitand_assign(&mut self, rhs: bool) {
600 *self &= Self::splat(rhs);
601 }
602}
603
604impl<T, const N: usize> core::ops::BitOrAssign for Mask<T, N>
605where
606 T: MaskElement,
607{
608 #[inline]
609 fn bitor_assign(&mut self, rhs: Self) {
610 *self = *self | rhs;
611 }
612}
613
614impl<T, const N: usize> core::ops::BitOrAssign<bool> for Mask<T, N>
615where
616 T: MaskElement,
617{
618 #[inline]
619 fn bitor_assign(&mut self, rhs: bool) {
620 *self |= Self::splat(rhs);
621 }
622}
623
624impl<T, const N: usize> core::ops::BitXorAssign for Mask<T, N>
625where
626 T: MaskElement,
627{
628 #[inline]
629 fn bitxor_assign(&mut self, rhs: Self) {
630 *self = *self ^ rhs;
631 }
632}
633
634impl<T, const N: usize> core::ops::BitXorAssign<bool> for Mask<T, N>
635where
636 T: MaskElement,
637{
638 #[inline]
639 fn bitxor_assign(&mut self, rhs: bool) {
640 *self ^= Self::splat(rhs);
641 }
642}
643
644macro_rules! impl_from {
645 { $from:ty => $($to:ty),* } => {
646 $(
647 impl<const N: usize> From<Mask<$from, N>> for Mask<$to, N>
648 {
649 #[inline]
650 fn from(value: Mask<$from, N>) -> Self {
651 value.cast()
652 }
653 }
654 )*
655 }
656}
657impl_from! { i8 => i16, i32, i64, isize }
658impl_from! { i16 => i32, i64, isize, i8 }
659impl_from! { i32 => i64, isize, i8, i16 }
660impl_from! { i64 => isize, i8, i16, i32 }
661impl_from! { isize => i8, i16, i32, i64 }