core_simd/simd/cmp/
ord.rs

1use crate::simd::{
2    Mask, Select, Simd,
3    cmp::SimdPartialEq,
4    ptr::{SimdConstPtr, SimdMutPtr},
5};
6
7/// Parallel `PartialOrd`.
8pub trait SimdPartialOrd: SimdPartialEq {
9    /// Test if each element is less than the corresponding element in `other`.
10    #[must_use = "method returns a new mask and does not mutate the original value"]
11    fn simd_lt(self, other: Self) -> Self::Mask;
12
13    /// Test if each element is less than or equal to the corresponding element in `other`.
14    #[must_use = "method returns a new mask and does not mutate the original value"]
15    fn simd_le(self, other: Self) -> Self::Mask;
16
17    /// Test if each element is greater than the corresponding element in `other`.
18    #[must_use = "method returns a new mask and does not mutate the original value"]
19    fn simd_gt(self, other: Self) -> Self::Mask;
20
21    /// Test if each element is greater than or equal to the corresponding element in `other`.
22    #[must_use = "method returns a new mask and does not mutate the original value"]
23    fn simd_ge(self, other: Self) -> Self::Mask;
24}
25
26/// Parallel `Ord`.
27pub trait SimdOrd: SimdPartialOrd {
28    /// Returns the element-wise maximum with `other`.
29    #[must_use = "method returns a new vector and does not mutate the original value"]
30    fn simd_max(self, other: Self) -> Self;
31
32    /// Returns the element-wise minimum with `other`.
33    #[must_use = "method returns a new vector and does not mutate the original value"]
34    fn simd_min(self, other: Self) -> Self;
35
36    /// Restrict each element to a certain interval.
37    ///
38    /// For each element, returns `max` if `self` is greater than `max`, and `min` if `self` is
39    /// less than `min`. Otherwise returns `self`.
40    ///
41    /// # Panics
42    ///
43    /// Panics if `min > max` on any element.
44    #[must_use = "method returns a new vector and does not mutate the original value"]
45    fn simd_clamp(self, min: Self, max: Self) -> Self;
46}
47
48macro_rules! impl_integer {
49    { $($integer:ty),* } => {
50        $(
51        impl<const N: usize> SimdPartialOrd for Simd<$integer, N>
52        {
53            #[inline]
54            fn simd_lt(self, other: Self) -> Self::Mask {
55                // Safety: `self` is a vector, and the result of the comparison
56                // is always a valid mask.
57                unsafe { Mask::from_simd_unchecked(core::intrinsics::simd::simd_lt(self, other)) }
58            }
59
60            #[inline]
61            fn simd_le(self, other: Self) -> Self::Mask {
62                // Safety: `self` is a vector, and the result of the comparison
63                // is always a valid mask.
64                unsafe { Mask::from_simd_unchecked(core::intrinsics::simd::simd_le(self, other)) }
65            }
66
67            #[inline]
68            fn simd_gt(self, other: Self) -> Self::Mask {
69                // Safety: `self` is a vector, and the result of the comparison
70                // is always a valid mask.
71                unsafe { Mask::from_simd_unchecked(core::intrinsics::simd::simd_gt(self, other)) }
72            }
73
74            #[inline]
75            fn simd_ge(self, other: Self) -> Self::Mask {
76                // Safety: `self` is a vector, and the result of the comparison
77                // is always a valid mask.
78                unsafe { Mask::from_simd_unchecked(core::intrinsics::simd::simd_ge(self, other)) }
79            }
80        }
81
82        impl<const N: usize> SimdOrd for Simd<$integer, N>
83        {
84            #[inline]
85            fn simd_max(self, other: Self) -> Self {
86                self.simd_lt(other).select(other, self)
87            }
88
89            #[inline]
90            fn simd_min(self, other: Self) -> Self {
91                self.simd_gt(other).select(other, self)
92            }
93
94            #[inline]
95            #[track_caller]
96            fn simd_clamp(self, min: Self, max: Self) -> Self {
97                assert!(
98                    min.simd_le(max).all(),
99                    "each element in `min` must be less than or equal to the corresponding element in `max`",
100                );
101                self.simd_max(min).simd_min(max)
102            }
103        }
104        )*
105    }
106}
107
108impl_integer! { u8, u16, u32, u64, usize, i8, i16, i32, i64, isize }
109
110macro_rules! impl_float {
111    { $($float:ty),* } => {
112        $(
113        impl<const N: usize> SimdPartialOrd for Simd<$float, N>
114        {
115            #[inline]
116            fn simd_lt(self, other: Self) -> Self::Mask {
117                // Safety: `self` is a vector, and the result of the comparison
118                // is always a valid mask.
119                unsafe { Mask::from_simd_unchecked(core::intrinsics::simd::simd_lt(self, other)) }
120            }
121
122            #[inline]
123            fn simd_le(self, other: Self) -> Self::Mask {
124                // Safety: `self` is a vector, and the result of the comparison
125                // is always a valid mask.
126                unsafe { Mask::from_simd_unchecked(core::intrinsics::simd::simd_le(self, other)) }
127            }
128
129            #[inline]
130            fn simd_gt(self, other: Self) -> Self::Mask {
131                // Safety: `self` is a vector, and the result of the comparison
132                // is always a valid mask.
133                unsafe { Mask::from_simd_unchecked(core::intrinsics::simd::simd_gt(self, other)) }
134            }
135
136            #[inline]
137            fn simd_ge(self, other: Self) -> Self::Mask {
138                // Safety: `self` is a vector, and the result of the comparison
139                // is always a valid mask.
140                unsafe { Mask::from_simd_unchecked(core::intrinsics::simd::simd_ge(self, other)) }
141            }
142        }
143        )*
144    }
145}
146
147impl_float! { f32, f64 }
148
149macro_rules! impl_mask {
150    { $($integer:ty),* } => {
151        $(
152        impl<const N: usize> SimdPartialOrd for Mask<$integer, N>
153        {
154            #[inline]
155            fn simd_lt(self, other: Self) -> Self::Mask {
156                // Safety: `self` is a vector, and the result of the comparison
157                // is always a valid mask.
158                unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_lt(self.to_simd(), other.to_simd())) }
159            }
160
161            #[inline]
162            fn simd_le(self, other: Self) -> Self::Mask {
163                // Safety: `self` is a vector, and the result of the comparison
164                // is always a valid mask.
165                unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_le(self.to_simd(), other.to_simd())) }
166            }
167
168            #[inline]
169            fn simd_gt(self, other: Self) -> Self::Mask {
170                // Safety: `self` is a vector, and the result of the comparison
171                // is always a valid mask.
172                unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_gt(self.to_simd(), other.to_simd())) }
173            }
174
175            #[inline]
176            fn simd_ge(self, other: Self) -> Self::Mask {
177                // Safety: `self` is a vector, and the result of the comparison
178                // is always a valid mask.
179                unsafe { Self::from_simd_unchecked(core::intrinsics::simd::simd_ge(self.to_simd(), other.to_simd())) }
180            }
181        }
182
183        impl<const N: usize> SimdOrd for Mask<$integer, N>
184        {
185            #[inline]
186            fn simd_max(self, other: Self) -> Self {
187                self.simd_gt(other).select(other, self)
188            }
189
190            #[inline]
191            fn simd_min(self, other: Self) -> Self {
192                self.simd_lt(other).select(other, self)
193            }
194
195            #[inline]
196            #[track_caller]
197            fn simd_clamp(self, min: Self, max: Self) -> Self {
198                assert!(
199                    min.simd_le(max).all(),
200                    "each element in `min` must be less than or equal to the corresponding element in `max`",
201                );
202                self.simd_max(min).simd_min(max)
203            }
204        }
205        )*
206    }
207}
208
209impl_mask! { i8, i16, i32, i64, isize }
210
211impl<T, const N: usize> SimdPartialOrd for Simd<*const T, N> {
212    #[inline]
213    fn simd_lt(self, other: Self) -> Self::Mask {
214        self.addr().simd_lt(other.addr())
215    }
216
217    #[inline]
218    fn simd_le(self, other: Self) -> Self::Mask {
219        self.addr().simd_le(other.addr())
220    }
221
222    #[inline]
223    fn simd_gt(self, other: Self) -> Self::Mask {
224        self.addr().simd_gt(other.addr())
225    }
226
227    #[inline]
228    fn simd_ge(self, other: Self) -> Self::Mask {
229        self.addr().simd_ge(other.addr())
230    }
231}
232
233impl<T, const N: usize> SimdOrd for Simd<*const T, N> {
234    #[inline]
235    fn simd_max(self, other: Self) -> Self {
236        self.simd_lt(other).select(other, self)
237    }
238
239    #[inline]
240    fn simd_min(self, other: Self) -> Self {
241        self.simd_gt(other).select(other, self)
242    }
243
244    #[inline]
245    #[track_caller]
246    fn simd_clamp(self, min: Self, max: Self) -> Self {
247        assert!(
248            min.simd_le(max).all(),
249            "each element in `min` must be less than or equal to the corresponding element in `max`",
250        );
251        self.simd_max(min).simd_min(max)
252    }
253}
254
255impl<T, const N: usize> SimdPartialOrd for Simd<*mut T, N> {
256    #[inline]
257    fn simd_lt(self, other: Self) -> Self::Mask {
258        self.addr().simd_lt(other.addr())
259    }
260
261    #[inline]
262    fn simd_le(self, other: Self) -> Self::Mask {
263        self.addr().simd_le(other.addr())
264    }
265
266    #[inline]
267    fn simd_gt(self, other: Self) -> Self::Mask {
268        self.addr().simd_gt(other.addr())
269    }
270
271    #[inline]
272    fn simd_ge(self, other: Self) -> Self::Mask {
273        self.addr().simd_ge(other.addr())
274    }
275}
276
277impl<T, const N: usize> SimdOrd for Simd<*mut T, N> {
278    #[inline]
279    fn simd_max(self, other: Self) -> Self {
280        self.simd_lt(other).select(other, self)
281    }
282
283    #[inline]
284    fn simd_min(self, other: Self) -> Self {
285        self.simd_gt(other).select(other, self)
286    }
287
288    #[inline]
289    #[track_caller]
290    fn simd_clamp(self, min: Self, max: Self) -> Self {
291        assert!(
292            min.simd_le(max).all(),
293            "each element in `min` must be less than or equal to the corresponding element in `max`",
294        );
295        self.simd_max(min).simd_min(max)
296    }
297}