#![allow(non_camel_case_types)]
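//! Types and traits for SIMD vector masks and their elements.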

#[cfg_attr(
    not(all(target_arch = "x86_64", target_feature = "avx512f")),
    path = "masks/full_masks.rs"
)]
#[cfg_attr(
    all(target_arch = "x86_64", target_feature = "avx512f"),
    path = "masks/bitmask.rs"
)]
mod mask_impl;

use crate::simd::{LaneCount, Simd, SimdCast, SimdElement, SupportedLaneCount};
use core::cmp::Ordering;
use core::{fmt, mem};

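// `Sealed` lives in a private module so that downstream code cannot implement it, which
// in turn keeps `MaskElement` restricted to the integer types listed in this file.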
mod sealed {
    use super::*;

    pub trait Sealed {
        fn valid<const N: usize>(values: Simd<Self, N>) -> bool
        where
            LaneCount<N>: SupportedLaneCount,
            Self: SimdElement;

        fn eq(self, other: Self) -> bool;

        fn to_usize(self) -> usize;
        fn max_unsigned() -> u64;

        type Unsigned: SimdElement;

        const TRUE: Self;

        const FALSE: Self;
    }
}
use sealed::Sealed;

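/// Marker trait for types which may be used as SIMD mask elements.
///
/// # Safety
/// Implementing this trait asserts that the type is valid as a mask element: every lane
/// of a mask of this type is either all-zeros (`FALSE`, i.e. `0`) or all-ones (`TRUE`,
/// i.e. `-1`), as checked by `Sealed::valid`.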
pub unsafe trait MaskElement: SimdElement<Mask = Self> + SimdCast + Sealed {}

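// Implements `Sealed` and `MaskElement` for a signed integer type, with `$unsigned` as
// its same-width unsigned counterpart. `valid` reports whether every lane of a vector
// holds either all-zeros (false) or all-ones (true).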
macro_rules! impl_element {
    { $ty:ty, $unsigned:ty } => {
        impl Sealed for $ty {
            #[inline]
            fn valid<const N: usize>(value: Simd<Self, N>) -> bool
            where
                LaneCount<N>: SupportedLaneCount,
            {
                unsafe {
                    use core::intrinsics::simd;
                    let falses: Simd<Self, N> = simd::simd_eq(value, Simd::splat(0 as _));
                    let trues: Simd<Self, N> = simd::simd_eq(value, Simd::splat(-1 as _));
                    let valid: Simd<Self, N> = simd::simd_or(falses, trues);
                    simd::simd_reduce_all(valid)
                }
            }

            #[inline]
            fn eq(self, other: Self) -> bool { self == other }

            #[inline]
            fn to_usize(self) -> usize {
                self as usize
            }

            #[inline]
            fn max_unsigned() -> u64 {
                <$unsigned>::MAX as u64
            }

            type Unsigned = $unsigned;

            const TRUE: Self = -1;
            const FALSE: Self = 0;
        }

        // SAFETY: This is a valid mask element type.
        unsafe impl MaskElement for $ty {}
    }
}

impl_element! { i8, u8 }
impl_element! { i16, u16 }
impl_element! { i32, u32 }
impl_element! { i64, u64 }
impl_element! { isize, usize }

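/// A SIMD vector mask for `N` elements of width matching the element type `T`.
///
/// Each lane is conceptually a `bool`: it is stored as either all-zeros ("unset") or
/// all-ones ("set") in the representation selected by `mask_impl`.
///
/// Illustrative sketch (assumes the nightly `core::simd` facade; not compiled as a doctest):
/// ```ignore
/// #![feature(portable_simd)]
/// use core::simd::Mask;
///
/// let mask = Mask::<i32, 4>::from_array([true, false, true, false]);
/// assert_eq!(mask.to_array(), [true, false, true, false]);
/// assert!(mask.any());
/// ```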
#[repr(transparent)]
pub struct Mask<T, const N: usize>(mask_impl::Mask<T, N>)
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount;

impl<T, const N: usize> Copy for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
}

impl<T, const N: usize> Clone for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<T, const N: usize> Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
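    /// Constructs a mask with `value` replicated into all `N` lanes.
    ///
    /// Illustrative sketch (not compiled as a doctest):
    /// ```ignore
    /// let all_set = Mask::<i32, 4>::splat(true);
    /// assert!(all_set.all());
    /// ```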
    #[inline]
    #[rustc_const_unstable(feature = "portable_simd", issue = "86656")]
    pub const fn splat(value: bool) -> Self {
        Self(mask_impl::Mask::splat(value))
    }

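    /// Converts an array of `bool`s to a mask, one lane per array element.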
    #[inline]
    pub fn from_array(array: [bool; N]) -> Self {
        // SAFETY: Each `bool` is guaranteed to have the layout of a `u8` holding 0 or 1,
        // so comparing the reinterpreted bytes against 0 produces a valid mask vector.
        unsafe {
            let bytes: [u8; N] = mem::transmute_copy(&array);
            let bools: Simd<i8, N> =
                core::intrinsics::simd::simd_ne(Simd::from_array(bytes), Simd::splat(0u8));
            Mask::from_int_unchecked(core::intrinsics::simd::simd_cast(bools))
        }
    }

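    /// Converts the mask to an array of `bool`s, one array element per lane.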
    #[inline]
    pub fn to_array(self) -> [bool; N] {
        // SAFETY: A `bool` must hold exactly 0 or 1, so the all-ones lanes are masked
        // down to 1 before the byte vector is reinterpreted as `[bool; N]`.
        unsafe {
            let mut bytes: Simd<i8, N> = core::intrinsics::simd::simd_cast(self.to_int());
            bytes &= Simd::splat(1i8);
            mem::transmute_copy(&bytes)
        }
    }

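    /// Converts a vector of integers to a mask, where 0 represents `false` and -1
    /// represents `true`, without checking the values.
    ///
    /// # Safety
    /// Every lane of `value` must be either 0 or -1.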
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub unsafe fn from_int_unchecked(value: Simd<T, N>) -> Self {
        // SAFETY: The caller guarantees that every lane is 0 or -1.
        unsafe {
            core::intrinsics::assume(<T as Sealed>::valid(value));
            Self(mask_impl::Mask::from_int_unchecked(value))
        }
    }

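    /// Converts a vector of integers to a mask, where 0 represents `false` and -1
    /// represents `true`.
    ///
    /// # Panics
    /// Panics if any lane is not 0 or -1.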
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    #[track_caller]
    pub fn from_int(value: Simd<T, N>) -> Self {
        assert!(T::valid(value), "all values must be either 0 or -1",);
        // SAFETY: The assertion above guarantees the required invariant.
        unsafe { Self::from_int_unchecked(value) }
    }

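    /// Converts the mask to a vector of integers, where 0 represents `false` and -1
    /// represents `true`.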
    #[inline]
    #[must_use = "method returns a new vector and does not mutate the original value"]
    pub fn to_int(self) -> Simd<T, N> {
        self.0.to_int()
    }

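    /// Converts the mask to a mask with a different element type, preserving the value
    /// of each lane.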
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn cast<U: MaskElement>(self) -> Mask<U, N> {
        Mask(self.0.convert())
    }

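    /// Tests the value of the specified lane without bounds checking.
    ///
    /// # Safety
    /// `index` must be less than `N`.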
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub unsafe fn test_unchecked(&self, index: usize) -> bool {
        // SAFETY: The caller guarantees that `index` is in range.
        unsafe { self.0.test_unchecked(index) }
    }

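    /// Tests the value of the specified lane.
    ///
    /// # Panics
    /// Panics if `index` is greater than or equal to `N`.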
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    #[track_caller]
    pub fn test(&self, index: usize) -> bool {
        assert!(index < N, "element index out of range");
        // SAFETY: The index has just been bounds-checked.
        unsafe { self.test_unchecked(index) }
    }

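    /// Sets the value of the specified lane without bounds checking.
    ///
    /// # Safety
    /// `index` must be less than `N`.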
    #[inline]
    pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) {
        // SAFETY: The caller guarantees that `index` is in range.
        unsafe {
            self.0.set_unchecked(index, value);
        }
    }

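    /// Sets the value of the specified lane.
    ///
    /// # Panics
    /// Panics if `index` is greater than or equal to `N`.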
    #[inline]
    #[track_caller]
    pub fn set(&mut self, index: usize, value: bool) {
        assert!(index < N, "element index out of range");
        // SAFETY: The index has just been bounds-checked.
        unsafe {
            self.set_unchecked(index, value);
        }
    }

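    /// Returns `true` if any lane is set, or `false` otherwise.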
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn any(self) -> bool {
        self.0.any()
    }

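    /// Returns `true` if all lanes are set, or `false` otherwise.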
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub fn all(self) -> bool {
        self.0.all()
    }

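    /// Packs the mask into the low bits of a `u64` bitmask: bit `i` is set exactly when
    /// lane `i` is set, and bits at positions `N` and above are left unset.
    ///
    /// Illustrative sketch (not compiled as a doctest):
    /// ```ignore
    /// let mask = Mask::<i32, 4>::from_array([true, false, true, false]);
    /// assert_eq!(mask.to_bitmask(), 0b0101);
    /// ```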
    #[inline]
    #[must_use = "method returns a new integer and does not mutate the original value"]
    pub fn to_bitmask(self) -> u64 {
        self.0.to_bitmask_integer()
    }

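    /// Builds a mask from the low bits of a `u64` bitmask: lane `i` is set exactly when
    /// bit `i` is set, and bits at positions `N` and above are ignored.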
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub fn from_bitmask(bitmask: u64) -> Self {
        Self(mask_impl::Mask::from_bitmask_integer(bitmask))
    }

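    /// Returns the index of the first set lane, or `None` if no lane is set.
    ///
    /// Illustrative sketch (not compiled as a doctest):
    /// ```ignore
    /// let mask = Mask::<i32, 4>::from_array([false, true, true, false]);
    /// assert_eq!(mask.first_set(), Some(1));
    /// assert_eq!(Mask::<i32, 4>::splat(false).first_set(), None);
    /// ```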
    #[inline]
    #[must_use = "method returns the index and does not mutate the original value"]
    pub fn first_set(self) -> Option<usize> {
        // Fast path: when extracting a bitmask is cheap (gated here on SSE) and the mask
        // fits in a `u64`, the first set lane is the number of trailing zeros.
        if cfg!(target_feature = "sse") && N <= 64 {
            let tz = self.to_bitmask().trailing_zeros();
            return if tz == 64 { None } else { Some(tz as usize) };
        }

        // Build the constant vector of lane indices [0, 1, ..., N - 1].
        let index = Simd::from_array(
            const {
                let mut index = [0; N];
                let mut i = 0;
                while i < N {
                    index[i] = i;
                    i += 1;
                }
                index
            },
        );

        // SAFETY: The input and output are integer vectors with the same number of lanes.
        let index: Simd<T, N> = unsafe { core::intrinsics::simd::simd_cast(index) };

        // Keep the index where the lane is set; unset lanes become all-ones (`TRUE`).
        let masked_index = self.select(index, Self::splat(true).to_int());

        // Reinterpret as unsigned so the all-ones sentinel becomes the maximum value.
        // SAFETY: The input and output are integer vectors with the same number of lanes.
        let masked_index: Simd<T::Unsigned, N> =
            unsafe { core::intrinsics::simd::simd_cast(masked_index) };

        // The minimum is either the first set lane's index or the sentinel.
        // SAFETY: The input is an integer vector.
        let min_index: T::Unsigned =
            unsafe { core::intrinsics::simd::simd_reduce_min(masked_index) };

        // SAFETY: `T` and `T::Unsigned` are integers of the same size.
        let min_index: T = unsafe { core::mem::transmute_copy(&min_index) };

        if min_index.eq(T::TRUE) {
            None
        } else {
            Some(min_index.to_usize())
        }
    }
}

impl<T, const N: usize> From<[bool; N]> for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn from(array: [bool; N]) -> Self {
        Self::from_array(array)
    }
}

impl<T, const N: usize> From<Mask<T, N>> for [bool; N]
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn from(vector: Mask<T, N>) -> Self {
        vector.to_array()
    }
}

impl<T, const N: usize> Default for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn default() -> Self {
        Self::splat(false)
    }
}

impl<T, const N: usize> PartialEq for Mask<T, N>
where
    T: MaskElement + PartialEq,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

impl<T, const N: usize> PartialOrd for Mask<T, N>
where
    T: MaskElement + PartialOrd,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        self.0.partial_cmp(&other.0)
    }
}

impl<T, const N: usize> fmt::Debug for Mask<T, N>
where
    T: MaskElement + fmt::Debug,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_list()
            .entries((0..N).map(|i| self.test(i)))
            .finish()
    }
}

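// Lane-wise bitwise operators. Each operator also has `bool` variants, which broadcast
// the scalar with `splat` before applying the lane-wise operation.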
impl<T, const N: usize> core::ops::BitAnd for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    fn bitand(self, rhs: Self) -> Self {
        Self(self.0 & rhs.0)
    }
}

impl<T, const N: usize> core::ops::BitAnd<bool> for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    fn bitand(self, rhs: bool) -> Self {
        self & Self::splat(rhs)
    }
}

impl<T, const N: usize> core::ops::BitAnd<Mask<T, N>> for bool
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Mask<T, N>;
    #[inline]
    fn bitand(self, rhs: Mask<T, N>) -> Mask<T, N> {
        Mask::splat(self) & rhs
    }
}

impl<T, const N: usize> core::ops::BitOr for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    fn bitor(self, rhs: Self) -> Self {
        Self(self.0 | rhs.0)
    }
}

impl<T, const N: usize> core::ops::BitOr<bool> for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    fn bitor(self, rhs: bool) -> Self {
        self | Self::splat(rhs)
    }
}

impl<T, const N: usize> core::ops::BitOr<Mask<T, N>> for bool
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Mask<T, N>;
    #[inline]
    fn bitor(self, rhs: Mask<T, N>) -> Mask<T, N> {
        Mask::splat(self) | rhs
    }
}

impl<T, const N: usize> core::ops::BitXor for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    fn bitxor(self, rhs: Self) -> Self::Output {
        Self(self.0 ^ rhs.0)
    }
}

impl<T, const N: usize> core::ops::BitXor<bool> for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    #[inline]
    fn bitxor(self, rhs: bool) -> Self::Output {
        self ^ Self::splat(rhs)
    }
}

impl<T, const N: usize> core::ops::BitXor<Mask<T, N>> for bool
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Mask<T, N>;
    #[inline]
    fn bitxor(self, rhs: Mask<T, N>) -> Self::Output {
        Mask::splat(self) ^ rhs
    }
}

impl<T, const N: usize> core::ops::Not for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Mask<T, N>;
    #[inline]
    fn not(self) -> Self::Output {
        Self(!self.0)
    }
}

impl<T, const N: usize> core::ops::BitAndAssign for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn bitand_assign(&mut self, rhs: Self) {
        self.0 = self.0 & rhs.0;
    }
}

impl<T, const N: usize> core::ops::BitAndAssign<bool> for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn bitand_assign(&mut self, rhs: bool) {
        *self &= Self::splat(rhs);
    }
}

impl<T, const N: usize> core::ops::BitOrAssign for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn bitor_assign(&mut self, rhs: Self) {
        self.0 = self.0 | rhs.0;
    }
}

impl<T, const N: usize> core::ops::BitOrAssign<bool> for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn bitor_assign(&mut self, rhs: bool) {
        *self |= Self::splat(rhs);
    }
}

impl<T, const N: usize> core::ops::BitXorAssign for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn bitxor_assign(&mut self, rhs: Self) {
        self.0 = self.0 ^ rhs.0;
    }
}

impl<T, const N: usize> core::ops::BitXorAssign<bool> for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    #[inline]
    fn bitxor_assign(&mut self, rhs: bool) {
        *self ^= Self::splat(rhs);
    }
}

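// `From` conversions between masks with different element widths, delegating to `Mask::cast`.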
macro_rules! impl_from {
    { $from:ty => $($to:ty),* } => {
        $(
            impl<const N: usize> From<Mask<$from, N>> for Mask<$to, N>
            where
                LaneCount<N>: SupportedLaneCount,
            {
                #[inline]
                fn from(value: Mask<$from, N>) -> Self {
                    value.cast()
                }
            }
        )*
    }
}
impl_from! { i8 => i16, i32, i64, isize }
impl_from! { i16 => i32, i64, isize, i8 }
impl_from! { i32 => i64, isize, i8, i16 }
impl_from! { i64 => isize, i8, i16, i32 }
impl_from! { isize => i8, i16, i32, i64 }