Commit 40a3048

Implement arithmetic operation traits for x86 SIMD types
1 parent 97bf36d · commit 40a3048

5 files changed: +108 −34 lines changed

crates/core_arch/src/macros.rs
Lines changed: 57 additions & 0 deletions

@@ -163,3 +163,60 @@ macro_rules! simd_extract {
     ($x:expr, $idx:expr $(,)?) => {{ $crate::intrinsics::simd::simd_extract($x, const { $idx }) }};
     ($x:expr, $idx:expr, $ty:ty $(,)?) => {{ $crate::intrinsics::simd::simd_extract::<_, $ty>($x, const { $idx }) }};
 }
+
+#[allow(unused)]
+macro_rules! impl_arith_op {
+    (__internal $op:ident, $intrinsic:ident $_:ident) => {
+        #[inline]
+        fn $op(self, rhs: Self) -> Self {
+            unsafe { crate::intrinsics::simd::$intrinsic(self, rhs) }
+        }
+    };
+    (__internal $op:ident, $intrinsic:ident) => {
+        #[inline]
+        fn $op(self) -> Self {
+            unsafe { crate::intrinsics::simd::$intrinsic(self) }
+        }
+    };
+    (: $($tt:tt)*) => {};
+    (
+        $type:ty $(, $other_types:ty )* : $(
+            $Trait:ident, $op:ident $(, $TraitAssign:ident, $op_assign:ident)? = $intrinsic:ident
+        );* $(;)?
+    ) => {
+        $(
+            #[stable(feature = "stdarch_arith_ops", since = "CURRENT_RUSTC_VERSION")]
+            impl crate::ops::$Trait for $type {
+                type Output = Self;
+
+                impl_arith_op!(__internal $op, $intrinsic $( $TraitAssign )?);
+            }
+
+            $(
+                #[stable(feature = "stdarch_arith_ops", since = "CURRENT_RUSTC_VERSION")]
+                impl crate::ops::$TraitAssign for $type {
+                    #[inline]
+                    fn $op_assign(&mut self, rhs: Self) {
+                        *self = crate::ops::$Trait::$op(*self, rhs)
+                    }
+                }
+            )?
+        )*
+
+        impl_arith_op!($($other_types),* : $($Trait, $op $(, $TraitAssign, $op_assign)? = $intrinsic);*);
+    };
+}
+
+macro_rules! impl_not {
+    ($($type:ty),*) => {$(
+        #[stable(feature = "stdarch_arith_ops", since = "CURRENT_RUSTC_VERSION")]
+        impl crate::ops::Not for $type {
+            type Output = Self;
+
+            #[inline]
+            fn not(self) -> Self {
+                unsafe { crate::intrinsics::simd::simd_xor(<$type>::splat(!0), self) }
+            }
+        }
+    )*};
+}
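The macro is recursive: each expansion emits the trait impls for the first type in the list and then re-invokes itself on the remaining types, with the `(: $($tt:tt)*) => {};` arm terminating the recursion once the type list is empty. The `__internal` arms select the method body by arity: when a `$TraitAssign` ident is present it is forwarded as a marker token, matching the binary `fn $op(self, rhs: Self)` arm, while unary traits such as `Neg` (which have no assign counterpart) fall through to the one-argument arm. As a rough illustration (stability attributes omitted), `impl_arith_op!(__m128i: BitAnd, bitand, BitAndAssign, bitand_assign = simd_and);` expands to approximately:

    impl crate::ops::BitAnd for __m128i {
        type Output = Self;

        #[inline]
        fn bitand(self, rhs: Self) -> Self {
            // forwards to the platform-independent SIMD intrinsic
            unsafe { crate::intrinsics::simd::simd_and(self, rhs) }
        }
    }

    impl crate::ops::BitAndAssign for __m128i {
        #[inline]
        fn bitand_assign(&mut self, rhs: Self) {
            *self = crate::ops::BitAnd::bitand(*self, rhs)
        }
    }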

crates/core_arch/src/x86/avx2.rs
Lines changed: 4 additions & 10 deletions

@@ -248,7 +248,7 @@ pub fn _mm256_alignr_epi8<const IMM8: i32>(a: __m256i, b: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vandps))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_and_si256(a: __m256i, b: __m256i) -> __m256i {
-    unsafe { transmute(simd_and(a.as_i64x4(), b.as_i64x4())) }
+    a & b
 }

 /// Computes the bitwise NOT of 256 bits (representing integer data)
@@ -260,13 +260,7 @@ pub fn _mm256_and_si256(a: __m256i, b: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vandnps))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_andnot_si256(a: __m256i, b: __m256i) -> __m256i {
-    unsafe {
-        let all_ones = _mm256_set1_epi8(-1);
-        transmute(simd_and(
-            simd_xor(a.as_i64x4(), all_ones.as_i64x4()),
-            b.as_i64x4(),
-        ))
-    }
+    !a & b
 }

 /// Averages packed unsigned 16-bit integers in `a` and `b`.
@@ -2184,7 +2178,7 @@ pub fn _mm256_mulhrs_epi16(a: __m256i, b: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vorps))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_or_si256(a: __m256i, b: __m256i) -> __m256i {
-    unsafe { transmute(simd_or(a.as_i32x8(), b.as_i32x8())) }
+    a | b
 }

 /// Converts packed 16-bit integers from `a` and `b` to packed 8-bit integers
@@ -3557,7 +3551,7 @@ pub fn _mm256_unpacklo_epi64(a: __m256i, b: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vxorps))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_xor_si256(a: __m256i, b: __m256i) -> __m256i {
-    unsafe { transmute(simd_xor(a.as_i64x4(), b.as_i64x4())) }
+    a ^ b
 }

 /// Extracts an 8-bit integer from `a`, selected with `INDEX`. Returns a 32-bit
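Note that `!a & b` parses as `(!a) & b`, since unary `!` binds more tightly than `&`, so the operator form preserves the old NOT-then-AND semantics, and the unchanged `assert_instr` attributes still verify that each function lowers to the expected single instruction. A sanity check in the style of the crate's existing `simd_test` harness (the test name and constants here are illustrative, not part of the commit):

    #[simd_test(enable = "avx2")]
    unsafe fn test_mm256_andnot_si256_operator_form() {
        let a = _mm256_set1_epi8(0b0101_0101);
        let b = _mm256_set1_epi8(0b0011_0011);
        // NOT(0101_0101) & 0011_0011 == 0010_0010
        let r = _mm256_andnot_si256(a, b);
        assert_eq_m256i(r, _mm256_set1_epi8(0b0010_0010));
    }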

crates/core_arch/src/x86/avx512f.rs
Lines changed: 20 additions & 20 deletions

@@ -28149,7 +28149,7 @@ pub fn _mm_maskz_alignr_epi64<const IMM8: i32>(k: __mmask8, a: __m128i, b: __m12
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpandq))] //should be vpandd, but generate vpandq
 pub fn _mm512_and_epi32(a: __m512i, b: __m512i) -> __m512i {
-    unsafe { transmute(simd_and(a.as_i32x16(), b.as_i32x16())) }
+    a & b
 }

 /// Performs element-by-element bitwise AND between packed 32-bit integer elements of a and b, storing the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28244,7 +28244,7 @@ pub fn _mm_maskz_and_epi32(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpandq))]
 pub fn _mm512_and_epi64(a: __m512i, b: __m512i) -> __m512i {
-    unsafe { transmute(simd_and(a.as_i64x8(), b.as_i64x8())) }
+    a & b
 }

 /// Compute the bitwise AND of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28339,7 +28339,7 @@ pub fn _mm_maskz_and_epi64(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpandq))]
 pub fn _mm512_and_si512(a: __m512i, b: __m512i) -> __m512i {
-    unsafe { transmute(simd_and(a.as_i32x16(), b.as_i32x16())) }
+    a & b
 }

 /// Compute the bitwise OR of packed 32-bit integers in a and b, and store the results in dst.
@@ -28350,7 +28350,7 @@ pub fn _mm512_and_si512(a: __m512i, b: __m512i) -> __m512i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vporq))]
 pub fn _mm512_or_epi32(a: __m512i, b: __m512i) -> __m512i {
-    unsafe { transmute(simd_or(a.as_i32x16(), b.as_i32x16())) }
+    a | b
 }

 /// Compute the bitwise OR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28389,7 +28389,7 @@ pub fn _mm512_maskz_or_epi32(k: __mmask16, a: __m512i, b: __m512i) -> __m512i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vor))] //should be vpord
 pub fn _mm256_or_epi32(a: __m256i, b: __m256i) -> __m256i {
-    unsafe { transmute(simd_or(a.as_i32x8(), b.as_i32x8())) }
+    a | b
 }

 /// Compute the bitwise OR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28428,7 +28428,7 @@ pub fn _mm256_maskz_or_epi32(k: __mmask8, a: __m256i, b: __m256i) -> __m256i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vor))] //should be vpord
 pub fn _mm_or_epi32(a: __m128i, b: __m128i) -> __m128i {
-    unsafe { transmute(simd_or(a.as_i32x4(), b.as_i32x4())) }
+    a | b
 }

 /// Compute the bitwise OR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28467,7 +28467,7 @@ pub fn _mm_maskz_or_epi32(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vporq))]
 pub fn _mm512_or_epi64(a: __m512i, b: __m512i) -> __m512i {
-    unsafe { transmute(simd_or(a.as_i64x8(), b.as_i64x8())) }
+    a | b
 }

 /// Compute the bitwise OR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28506,7 +28506,7 @@ pub fn _mm512_maskz_or_epi64(k: __mmask8, a: __m512i, b: __m512i) -> __m512i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vor))] //should be vporq
 pub fn _mm256_or_epi64(a: __m256i, b: __m256i) -> __m256i {
-    unsafe { transmute(simd_or(a.as_i64x4(), b.as_i64x4())) }
+    a | b
 }

 /// Compute the bitwise OR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28545,7 +28545,7 @@ pub fn _mm256_maskz_or_epi64(k: __mmask8, a: __m256i, b: __m256i) -> __m256i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vor))] //should be vporq
 pub fn _mm_or_epi64(a: __m128i, b: __m128i) -> __m128i {
-    unsafe { transmute(simd_or(a.as_i64x2(), b.as_i64x2())) }
+    a | b
 }

 /// Compute the bitwise OR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28584,7 +28584,7 @@ pub fn _mm_maskz_or_epi64(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vporq))]
 pub fn _mm512_or_si512(a: __m512i, b: __m512i) -> __m512i {
-    unsafe { transmute(simd_or(a.as_i32x16(), b.as_i32x16())) }
+    a | b
 }

 /// Compute the bitwise XOR of packed 32-bit integers in a and b, and store the results in dst.
@@ -28595,7 +28595,7 @@ pub fn _mm512_or_si512(a: __m512i, b: __m512i) -> __m512i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpxorq))] //should be vpxord
 pub fn _mm512_xor_epi32(a: __m512i, b: __m512i) -> __m512i {
-    unsafe { transmute(simd_xor(a.as_i32x16(), b.as_i32x16())) }
+    a ^ b
 }

 /// Compute the bitwise XOR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28634,7 +28634,7 @@ pub fn _mm512_maskz_xor_epi32(k: __mmask16, a: __m512i, b: __m512i) -> __m512i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vxor))] //should be vpxord
 pub fn _mm256_xor_epi32(a: __m256i, b: __m256i) -> __m256i {
-    unsafe { transmute(simd_xor(a.as_i32x8(), b.as_i32x8())) }
+    a ^ b
 }

 /// Compute the bitwise XOR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28673,7 +28673,7 @@ pub fn _mm256_maskz_xor_epi32(k: __mmask8, a: __m256i, b: __m256i) -> __m256i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vxor))] //should be vpxord
 pub fn _mm_xor_epi32(a: __m128i, b: __m128i) -> __m128i {
-    unsafe { transmute(simd_xor(a.as_i32x4(), b.as_i32x4())) }
+    a ^ b
 }

 /// Compute the bitwise XOR of packed 32-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28712,7 +28712,7 @@ pub fn _mm_maskz_xor_epi32(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpxorq))]
 pub fn _mm512_xor_epi64(a: __m512i, b: __m512i) -> __m512i {
-    unsafe { transmute(simd_xor(a.as_i64x8(), b.as_i64x8())) }
+    a ^ b
 }

 /// Compute the bitwise XOR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28751,7 +28751,7 @@ pub fn _mm512_maskz_xor_epi64(k: __mmask8, a: __m512i, b: __m512i) -> __m512i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vxor))] //should be vpxorq
 pub fn _mm256_xor_epi64(a: __m256i, b: __m256i) -> __m256i {
-    unsafe { transmute(simd_xor(a.as_i64x4(), b.as_i64x4())) }
+    a ^ b
 }

 /// Compute the bitwise XOR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28790,7 +28790,7 @@ pub fn _mm256_maskz_xor_epi64(k: __mmask8, a: __m256i, b: __m256i) -> __m256i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vxor))] //should be vpxorq
 pub fn _mm_xor_epi64(a: __m128i, b: __m128i) -> __m128i {
-    unsafe { transmute(simd_xor(a.as_i64x2(), b.as_i64x2())) }
+    a ^ b
 }

 /// Compute the bitwise XOR of packed 64-bit integers in a and b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28829,7 +28829,7 @@ pub fn _mm_maskz_xor_epi64(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpxorq))]
 pub fn _mm512_xor_si512(a: __m512i, b: __m512i) -> __m512i {
-    unsafe { transmute(simd_xor(a.as_i32x16(), b.as_i32x16())) }
+    a ^ b
 }

 /// Compute the bitwise NOT of packed 32-bit integers in a and then AND with b, and store the results in dst.
@@ -28840,7 +28840,7 @@ pub fn _mm512_xor_si512(a: __m512i, b: __m512i) -> __m512i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpandnq))] //should be vpandnd
 pub fn _mm512_andnot_epi32(a: __m512i, b: __m512i) -> __m512i {
-    _mm512_and_epi32(_mm512_xor_epi32(a, _mm512_set1_epi32(u32::MAX as i32)), b)
+    !a & b
 }

 /// Compute the bitwise NOT of packed 32-bit integers in a and then AND with b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -28939,7 +28939,7 @@ pub fn _mm_maskz_andnot_epi32(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpandnq))] //should be vpandnd
 pub fn _mm512_andnot_epi64(a: __m512i, b: __m512i) -> __m512i {
-    _mm512_and_epi64(_mm512_xor_epi64(a, _mm512_set1_epi64(u64::MAX as i64)), b)
+    !a & b
 }

 /// Compute the bitwise NOT of packed 64-bit integers in a and then AND with b, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -29038,7 +29038,7 @@ pub fn _mm_maskz_andnot_epi64(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpandnq))]
 pub fn _mm512_andnot_si512(a: __m512i, b: __m512i) -> __m512i {
-    _mm512_and_epi64(_mm512_xor_epi64(a, _mm512_set1_epi64(u64::MAX as i64)), b)
+    !a & b
 }

 /// Convert 16-bit mask a into an integer value, and store the result in dst.
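These rewrites are behavior-preserving: `!a` goes through the `Not` impl generated by `impl_not!`, whose body is `simd_xor(<$type>::splat(!0), self)`, i.e. the same all-ones XOR that the old `_mm512_set1_epi32(u32::MAX as i32)` formulation spelled out by hand. Roughly, the new one-liner desugars as follows (illustrative, not literal library code):

    // `!a & b` after operator desugaring
    crate::ops::BitAnd::bitand(crate::ops::Not::not(a), b)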

crates/core_arch/src/x86/mod.rs
Lines changed: 23 additions & 0 deletions

@@ -471,6 +471,29 @@ impl bf16 {
     }
 }

+impl_arith_op!(
+    __m128, __m128d, __m128h,
+    __m256, __m256d, __m256h,
+    __m512, __m512d, __m512h:
+    Add, add, AddAssign, add_assign = simd_add;
+    Sub, sub, SubAssign, sub_assign = simd_sub;
+    Mul, mul, MulAssign, mul_assign = simd_mul;
+    Div, div, DivAssign, div_assign = simd_div;
+    Rem, rem, RemAssign, rem_assign = simd_rem;
+    Neg, neg = simd_neg;
+);
+
+impl_arith_op!(
+    __m128i, __m256i, __m512i:
+    BitOr, bitor, BitOrAssign, bitor_assign = simd_or;
+    BitAnd, bitand, BitAndAssign, bitand_assign = simd_and;
+    BitXor, bitxor, BitXorAssign, bitxor_assign = simd_xor;
+);
+
+impl_not!(__m128i, __m256i, __m512i);
+
+// TODO: should we have `Rem` and `Not`?
+
 /// The `__mmask64` type used in AVX-512 intrinsics, a 64-bit integer
 #[allow(non_camel_case_types)]
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]

crates/core_arch/src/x86/sse2.rs
Lines changed: 4 additions & 4 deletions

@@ -823,7 +823,7 @@ pub fn _mm_srl_epi64(a: __m128i, count: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(andps))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_and_si128(a: __m128i, b: __m128i) -> __m128i {
-    unsafe { simd_and(a, b) }
+    a & b
 }

 /// Computes the bitwise NOT of 128 bits (representing integer data) in `a` and
@@ -835,7 +835,7 @@ pub fn _mm_and_si128(a: __m128i, b: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(andnps))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_andnot_si128(a: __m128i, b: __m128i) -> __m128i {
-    unsafe { simd_and(simd_xor(_mm_set1_epi8(-1), a), b) }
+    !a & b
 }

 /// Computes the bitwise OR of 128 bits (representing integer data) in `a` and
@@ -847,7 +847,7 @@ pub fn _mm_andnot_si128(a: __m128i, b: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(orps))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_or_si128(a: __m128i, b: __m128i) -> __m128i {
-    unsafe { simd_or(a, b) }
+    a | b
 }

 /// Computes the bitwise XOR of 128 bits (representing integer data) in `a` and
@@ -859,7 +859,7 @@ pub fn _mm_or_si128(a: __m128i, b: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(xorps))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_xor_si128(a: __m128i, b: __m128i) -> __m128i {
-    unsafe { simd_xor(a, b) }
+    a ^ b
 }

 /// Compares packed 8-bit integers in `a` and `b` for equality.
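Because SSE2 is part of the x86-64 baseline, the operator forms are usable from safe code without feature detection. An illustrative snippet (names and values are ours):

    use core::arch::x86_64::*;

    fn mask_demo(a: __m128i, b: __m128i) -> __m128i {
        let mut acc = a & b; // _mm_and_si128
        acc |= !a & b;       // the _mm_andnot_si128 pattern
        acc ^= b;            // _mm_xor_si128, via the BitXorAssign impl
        acc
    }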
