Skip to content

Commit 875b31c

Browse files
committed
Implement reductions
1 parent 926cf3a commit 875b31c

File tree

8 files changed

+288
-117
lines changed

8 files changed

+288
-117
lines changed

crates/core_simd/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ mod first;
1111
mod permute;
1212
#[macro_use]
1313
mod transmute;
14+
#[macro_use]
15+
mod reduction;
1416

1517
mod comparisons;
1618
mod fmt;

crates/core_simd/src/masks/bitmask.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::LanesAtMost32;
33
/// A mask where each lane is represented by a single bit.
44
#[derive(Copy, Clone, Debug)]
55
#[repr(transparent)]
6-
pub struct BitMask<const LANES: usize>(u64)
6+
pub struct BitMask<const LANES: usize>(pub(crate) u64)
77
where
88
BitMask<LANES>: LanesAtMost32;
99

crates/core_simd/src/masks/full_masks.rs

Lines changed: 72 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,41 @@ impl core::fmt::Display for TryFromMaskError {
1414
}
1515

1616
macro_rules! define_mask {
17-
{ $(#[$attr:meta])* struct $name:ident<const $lanes:ident: usize>($type:ty); } => {
17+
{
18+
$(#[$attr:meta])*
19+
struct $name:ident<const $lanes:ident: usize>(
20+
crate::$type:ident<$lanes2:ident>
21+
);
22+
} => {
1823
$(#[$attr])*
1924
#[derive(Default, PartialEq, PartialOrd, Eq, Ord, Hash)]
2025
#[repr(transparent)]
21-
pub struct $name<const $lanes: usize>($type)
26+
pub struct $name<const $lanes: usize>(crate::$type<$lanes2>)
2227
where
23-
$type: crate::LanesAtMost32;
28+
crate::$type<LANES>: crate::LanesAtMost32;
2429

2530
impl<const LANES: usize> Copy for $name<LANES>
2631
where
27-
$type: crate::LanesAtMost32,
32+
crate::$type<LANES>: crate::LanesAtMost32,
2833
{}
2934

3035
impl<const LANES: usize> Clone for $name<LANES>
3136
where
32-
$type: crate::LanesAtMost32,
37+
crate::$type<LANES>: crate::LanesAtMost32,
3338
{
3439
#[inline]
3540
fn clone(&self) -> Self {
3641
*self
3742
}
3843
}
3944

40-
impl<const $lanes: usize> $name<$lanes>
45+
impl<const LANES: usize> $name<LANES>
4146
where
42-
$type: crate::LanesAtMost32,
47+
crate::$type<LANES>: crate::LanesAtMost32,
4348
{
4449
/// Construct a mask by setting all lanes to the given value.
4550
pub fn splat(value: bool) -> Self {
46-
Self(<$type>::splat(
51+
Self(<crate::$type<LANES>>::splat(
4752
if value {
4853
-1
4954
} else {
@@ -76,64 +81,73 @@ macro_rules! define_mask {
7681
}
7782
}
7883

79-
/// Creates a mask from an integer vector.
84+
/// Converts the mask to the equivalent integer representation, where -1 represents
85+
/// "set" and 0 represents "unset".
86+
#[inline]
87+
pub fn to_int(self) -> crate::$type<LANES> {
88+
self.0
89+
}
90+
91+
/// Creates a mask from the equivalent integer representation, where -1 represents
92+
/// "set" and 0 represents "unset".
8093
///
81-
/// # Safety
82-
/// All lanes must be either 0 or -1.
94+
/// Each provided lane must be either 0 or -1.
8395
#[inline]
84-
pub unsafe fn from_int_unchecked(value: $type) -> Self {
96+
pub unsafe fn from_int_unchecked(value: crate::$type<LANES>) -> Self {
8597
Self(value)
8698
}
8799

88-
/// Creates a mask from an integer vector.
100+
/// Creates a mask from the equivalent integer representation, where -1 represents
101+
/// "set" and 0 represents "unset".
89102
///
90103
/// # Panics
91104
/// Panics if any lane is not 0 or -1.
92105
#[inline]
93-
pub fn from_int(value: $type) -> Self {
106+
pub fn from_int(value: crate::$type<LANES>) -> Self {
94107
use core::convert::TryInto;
95108
value.try_into().unwrap()
96109
}
97110
}
98111

99-
impl<const $lanes: usize> core::convert::From<bool> for $name<$lanes>
112+
impl<const LANES: usize> core::convert::From<bool> for $name<LANES>
100113
where
101-
$type: crate::LanesAtMost32,
114+
crate::$type<LANES>: crate::LanesAtMost32,
102115
{
103116
fn from(value: bool) -> Self {
104117
Self::splat(value)
105118
}
106119
}
107120

108-
impl<const $lanes: usize> core::convert::TryFrom<$type> for $name<$lanes>
121+
impl<const LANES: usize> core::convert::TryFrom<crate::$type<LANES>> for $name<LANES>
109122
where
110-
$type: crate::LanesAtMost32,
123+
crate::$type<LANES>: crate::LanesAtMost32,
111124
{
112125
type Error = TryFromMaskError;
113-
fn try_from(value: $type) -> Result<Self, Self::Error> {
114-
if value.as_slice().iter().all(|x| *x == 0 || *x == -1) {
126+
fn try_from(value: crate::$type<LANES>) -> Result<Self, Self::Error> {
127+
let valid = (value.lanes_eq(crate::$type::<LANES>::splat(0)) | value.lanes_eq(crate::$type::<LANES>::splat(-1))).all();
128+
if valid {
115129
Ok(Self(value))
116130
} else {
117131
Err(TryFromMaskError(()))
118132
}
119133
}
120134
}
121135

122-
impl<const $lanes: usize> core::convert::From<$name<$lanes>> for $type
136+
impl<const LANES: usize> core::convert::From<$name<LANES>> for crate::$type<LANES>
123137
where
124-
$type: crate::LanesAtMost32,
138+
crate::$type<LANES>: crate::LanesAtMost32,
125139
{
126-
fn from(value: $name<$lanes>) -> Self {
140+
fn from(value: $name<LANES>) -> Self {
127141
value.0
128142
}
129143
}
130144

131-
impl<const $lanes: usize> core::convert::From<crate::BitMask<$lanes>> for $name<$lanes>
145+
impl<const LANES: usize> core::convert::From<crate::BitMask<LANES>> for $name<LANES>
132146
where
133-
$type: crate::LanesAtMost32,
134-
crate::BitMask<$lanes>: crate::LanesAtMost32,
147+
crate::$type<LANES>: crate::LanesAtMost32,
148+
crate::BitMask<LANES>: crate::LanesAtMost32,
135149
{
136-
fn from(value: crate::BitMask<$lanes>) -> Self {
150+
fn from(value: crate::BitMask<LANES>) -> Self {
137151
// TODO use an intrinsic to do this efficiently (with LLVM's sext instruction)
138152
let mut mask = Self::splat(false);
139153
for lane in 0..LANES {
@@ -143,10 +157,10 @@ macro_rules! define_mask {
143157
}
144158
}
145159

146-
impl<const $lanes: usize> core::convert::From<$name<$lanes>> for crate::BitMask<$lanes>
160+
impl<const LANES: usize> core::convert::From<$name<LANES>> for crate::BitMask<LANES>
147161
where
148-
$type: crate::LanesAtMost32,
149-
crate::BitMask<$lanes>: crate::LanesAtMost32,
162+
crate::$type<LANES>: crate::LanesAtMost32,
163+
crate::BitMask<LANES>: crate::LanesAtMost32,
150164
{
151165
fn from(value: $name<$lanes>) -> Self {
152166
// TODO use an intrinsic to do this efficiently (with LLVM's trunc instruction)
@@ -158,9 +172,9 @@ macro_rules! define_mask {
158172
}
159173
}
160174

161-
impl<const $lanes: usize> core::fmt::Debug for $name<$lanes>
175+
impl<const LANES: usize> core::fmt::Debug for $name<LANES>
162176
where
163-
$type: crate::LanesAtMost32,
177+
crate::$type<LANES>: crate::LanesAtMost32,
164178
{
165179
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
166180
f.debug_list()
@@ -169,36 +183,36 @@ macro_rules! define_mask {
169183
}
170184
}
171185

172-
impl<const $lanes: usize> core::fmt::Binary for $name<$lanes>
186+
impl<const LANES: usize> core::fmt::Binary for $name<LANES>
173187
where
174-
$type: crate::LanesAtMost32,
188+
crate::$type<LANES>: crate::LanesAtMost32,
175189
{
176190
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
177191
core::fmt::Binary::fmt(&self.0, f)
178192
}
179193
}
180194

181-
impl<const $lanes: usize> core::fmt::Octal for $name<$lanes>
195+
impl<const LANES: usize> core::fmt::Octal for $name<LANES>
182196
where
183-
$type: crate::LanesAtMost32,
197+
crate::$type<LANES>: crate::LanesAtMost32,
184198
{
185199
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
186200
core::fmt::Octal::fmt(&self.0, f)
187201
}
188202
}
189203

190-
impl<const $lanes: usize> core::fmt::LowerHex for $name<$lanes>
204+
impl<const LANES: usize> core::fmt::LowerHex for $name<LANES>
191205
where
192-
$type: crate::LanesAtMost32,
206+
crate::$type<LANES>: crate::LanesAtMost32,
193207
{
194208
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
195209
core::fmt::LowerHex::fmt(&self.0, f)
196210
}
197211
}
198212

199-
impl<const $lanes: usize> core::fmt::UpperHex for $name<$lanes>
213+
impl<const LANES: usize> core::fmt::UpperHex for $name<LANES>
200214
where
201-
$type: crate::LanesAtMost32,
215+
crate::$type<LANES>: crate::LanesAtMost32,
202216
{
203217
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
204218
core::fmt::UpperHex::fmt(&self.0, f)
@@ -207,7 +221,7 @@ macro_rules! define_mask {
207221

208222
impl<const LANES: usize> core::ops::BitAnd for $name<LANES>
209223
where
210-
$type: crate::LanesAtMost32,
224+
crate::$type<LANES>: crate::LanesAtMost32,
211225
{
212226
type Output = Self;
213227
#[inline]
@@ -218,7 +232,7 @@ macro_rules! define_mask {
218232

219233
impl<const LANES: usize> core::ops::BitAnd<bool> for $name<LANES>
220234
where
221-
$type: crate::LanesAtMost32,
235+
crate::$type<LANES>: crate::LanesAtMost32,
222236
{
223237
type Output = Self;
224238
#[inline]
@@ -229,7 +243,7 @@ macro_rules! define_mask {
229243

230244
impl<const LANES: usize> core::ops::BitAnd<$name<LANES>> for bool
231245
where
232-
$type: crate::LanesAtMost32,
246+
crate::$type<LANES>: crate::LanesAtMost32,
233247
{
234248
type Output = $name<LANES>;
235249
#[inline]
@@ -240,7 +254,7 @@ macro_rules! define_mask {
240254

241255
impl<const LANES: usize> core::ops::BitOr for $name<LANES>
242256
where
243-
$type: crate::LanesAtMost32,
257+
crate::$type<LANES>: crate::LanesAtMost32,
244258
{
245259
type Output = Self;
246260
#[inline]
@@ -251,7 +265,7 @@ macro_rules! define_mask {
251265

252266
impl<const LANES: usize> core::ops::BitOr<bool> for $name<LANES>
253267
where
254-
$type: crate::LanesAtMost32,
268+
crate::$type<LANES>: crate::LanesAtMost32,
255269
{
256270
type Output = Self;
257271
#[inline]
@@ -262,7 +276,7 @@ macro_rules! define_mask {
262276

263277
impl<const LANES: usize> core::ops::BitOr<$name<LANES>> for bool
264278
where
265-
$type: crate::LanesAtMost32,
279+
crate::$type<LANES>: crate::LanesAtMost32,
266280
{
267281
type Output = $name<LANES>;
268282
#[inline]
@@ -273,7 +287,7 @@ macro_rules! define_mask {
273287

274288
impl<const LANES: usize> core::ops::BitXor for $name<LANES>
275289
where
276-
$type: crate::LanesAtMost32,
290+
crate::$type<LANES>: crate::LanesAtMost32,
277291
{
278292
type Output = Self;
279293
#[inline]
@@ -284,7 +298,7 @@ macro_rules! define_mask {
284298

285299
impl<const LANES: usize> core::ops::BitXor<bool> for $name<LANES>
286300
where
287-
$type: crate::LanesAtMost32,
301+
crate::$type<LANES>: crate::LanesAtMost32,
288302
{
289303
type Output = Self;
290304
#[inline]
@@ -295,7 +309,7 @@ macro_rules! define_mask {
295309

296310
impl<const LANES: usize> core::ops::BitXor<$name<LANES>> for bool
297311
where
298-
$type: crate::LanesAtMost32,
312+
crate::$type<LANES>: crate::LanesAtMost32,
299313
{
300314
type Output = $name<LANES>;
301315
#[inline]
@@ -306,7 +320,7 @@ macro_rules! define_mask {
306320

307321
impl<const LANES: usize> core::ops::Not for $name<LANES>
308322
where
309-
$type: crate::LanesAtMost32,
323+
crate::$type<LANES>: crate::LanesAtMost32,
310324
{
311325
type Output = $name<LANES>;
312326
#[inline]
@@ -317,7 +331,7 @@ macro_rules! define_mask {
317331

318332
impl<const LANES: usize> core::ops::BitAndAssign for $name<LANES>
319333
where
320-
$type: crate::LanesAtMost32,
334+
crate::$type<LANES>: crate::LanesAtMost32,
321335
{
322336
#[inline]
323337
fn bitand_assign(&mut self, rhs: Self) {
@@ -327,7 +341,7 @@ macro_rules! define_mask {
327341

328342
impl<const LANES: usize> core::ops::BitAndAssign<bool> for $name<LANES>
329343
where
330-
$type: crate::LanesAtMost32,
344+
crate::$type<LANES>: crate::LanesAtMost32,
331345
{
332346
#[inline]
333347
fn bitand_assign(&mut self, rhs: bool) {
@@ -337,7 +351,7 @@ macro_rules! define_mask {
337351

338352
impl<const LANES: usize> core::ops::BitOrAssign for $name<LANES>
339353
where
340-
$type: crate::LanesAtMost32,
354+
crate::$type<LANES>: crate::LanesAtMost32,
341355
{
342356
#[inline]
343357
fn bitor_assign(&mut self, rhs: Self) {
@@ -347,7 +361,7 @@ macro_rules! define_mask {
347361

348362
impl<const LANES: usize> core::ops::BitOrAssign<bool> for $name<LANES>
349363
where
350-
$type: crate::LanesAtMost32,
364+
crate::$type<LANES>: crate::LanesAtMost32,
351365
{
352366
#[inline]
353367
fn bitor_assign(&mut self, rhs: bool) {
@@ -357,7 +371,7 @@ macro_rules! define_mask {
357371

358372
impl<const LANES: usize> core::ops::BitXorAssign for $name<LANES>
359373
where
360-
$type: crate::LanesAtMost32,
374+
crate::$type<LANES>: crate::LanesAtMost32,
361375
{
362376
#[inline]
363377
fn bitxor_assign(&mut self, rhs: Self) {
@@ -367,13 +381,15 @@ macro_rules! define_mask {
367381

368382
impl<const LANES: usize> core::ops::BitXorAssign<bool> for $name<LANES>
369383
where
370-
$type: crate::LanesAtMost32,
384+
crate::$type<LANES>: crate::LanesAtMost32,
371385
{
372386
#[inline]
373387
fn bitxor_assign(&mut self, rhs: bool) {
374388
*self ^= Self::splat(rhs);
375389
}
376390
}
391+
392+
impl_full_mask_reductions! { $name, $type }
377393
}
378394
}
379395

0 commit comments

Comments
 (0)