Skip to content

Commit 6dfd5de

Browse files
authored
Merge pull request #1928 from sayantn/use-intrinsics
Use SIMD intrinsics wherever possible
2 parents 91f0c19 + 27b1620 commit 6dfd5de

File tree

14 files changed

+426
-507
lines changed

14 files changed

+426
-507
lines changed

crates/core_arch/src/x86/adx.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ use stdarch_test::assert_instr;
55
unsafe extern "unadjusted" {
66
#[link_name = "llvm.x86.addcarry.32"]
77
fn llvm_addcarry_u32(a: u8, b: u32, c: u32) -> (u8, u32);
8-
#[link_name = "llvm.x86.addcarryx.u32"]
9-
fn llvm_addcarryx_u32(a: u8, b: u32, c: u32, d: *mut u32) -> u8;
108
#[link_name = "llvm.x86.subborrow.32"]
119
fn llvm_subborrow_u32(a: u8, b: u32, c: u32) -> (u8, u32);
1210
}
@@ -35,7 +33,7 @@ pub unsafe fn _addcarry_u32(c_in: u8, a: u32, b: u32, out: &mut u32) -> u8 {
3533
#[cfg_attr(test, assert_instr(adc))]
3634
#[stable(feature = "simd_x86_adx", since = "1.33.0")]
3735
pub unsafe fn _addcarryx_u32(c_in: u8, a: u32, b: u32, out: &mut u32) -> u8 {
38-
llvm_addcarryx_u32(c_in, a, b, out as *mut _)
36+
_addcarry_u32(c_in, a, b, out)
3937
}
4038

4139
/// Adds unsigned 32-bit integers `a` and `b` with unsigned 8-bit carry-in `c_in`

crates/core_arch/src/x86/avx.rs

Lines changed: 81 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,11 @@ pub fn _mm256_dp_ps<const IMM8: i32>(a: __m256, b: __m256) -> __m256 {
587587
#[cfg_attr(test, assert_instr(vhaddpd))]
588588
#[stable(feature = "simd_x86", since = "1.27.0")]
589589
pub fn _mm256_hadd_pd(a: __m256d, b: __m256d) -> __m256d {
590-
unsafe { vhaddpd(a, b) }
590+
unsafe {
591+
let even = simd_shuffle!(a, b, [0, 4, 2, 6]);
592+
let odd = simd_shuffle!(a, b, [1, 5, 3, 7]);
593+
simd_add(even, odd)
594+
}
591595
}
592596

593597
/// Horizontal addition of adjacent pairs in the two packed vectors
@@ -602,7 +606,11 @@ pub fn _mm256_hadd_pd(a: __m256d, b: __m256d) -> __m256d {
602606
#[cfg_attr(test, assert_instr(vhaddps))]
603607
#[stable(feature = "simd_x86", since = "1.27.0")]
604608
pub fn _mm256_hadd_ps(a: __m256, b: __m256) -> __m256 {
605-
unsafe { vhaddps(a, b) }
609+
unsafe {
610+
let even = simd_shuffle!(a, b, [0, 2, 8, 10, 4, 6, 12, 14]);
611+
let odd = simd_shuffle!(a, b, [1, 3, 9, 11, 5, 7, 13, 15]);
612+
simd_add(even, odd)
613+
}
606614
}
607615

608616
/// Horizontal subtraction of adjacent pairs in the two packed vectors
@@ -616,7 +624,11 @@ pub fn _mm256_hadd_ps(a: __m256, b: __m256) -> __m256 {
616624
#[cfg_attr(test, assert_instr(vhsubpd))]
617625
#[stable(feature = "simd_x86", since = "1.27.0")]
618626
pub fn _mm256_hsub_pd(a: __m256d, b: __m256d) -> __m256d {
619-
unsafe { vhsubpd(a, b) }
627+
unsafe {
628+
let even = simd_shuffle!(a, b, [0, 4, 2, 6]);
629+
let odd = simd_shuffle!(a, b, [1, 5, 3, 7]);
630+
simd_sub(even, odd)
631+
}
620632
}
621633

622634
/// Horizontal subtraction of adjacent pairs in the two packed vectors
@@ -631,7 +643,11 @@ pub fn _mm256_hsub_pd(a: __m256d, b: __m256d) -> __m256d {
631643
#[cfg_attr(test, assert_instr(vhsubps))]
632644
#[stable(feature = "simd_x86", since = "1.27.0")]
633645
pub fn _mm256_hsub_ps(a: __m256, b: __m256) -> __m256 {
634-
unsafe { vhsubps(a, b) }
646+
unsafe {
647+
let even = simd_shuffle!(a, b, [0, 2, 8, 10, 4, 6, 12, 14]);
648+
let odd = simd_shuffle!(a, b, [1, 3, 9, 11, 5, 7, 13, 15]);
649+
simd_sub(even, odd)
650+
}
635651
}
636652

637653
/// Computes the bitwise XOR of packed double-precision (64-bit) floating-point
@@ -1218,7 +1234,10 @@ pub fn _mm_permute_pd<const IMM2: i32>(a: __m128d) -> __m128d {
12181234
#[stable(feature = "simd_x86", since = "1.27.0")]
12191235
pub fn _mm256_permute2f128_ps<const IMM8: i32>(a: __m256, b: __m256) -> __m256 {
12201236
static_assert_uimm_bits!(IMM8, 8);
1221-
unsafe { vperm2f128ps256(a, b, IMM8 as i8) }
1237+
_mm256_castsi256_ps(_mm256_permute2f128_si256::<IMM8>(
1238+
_mm256_castps_si256(a),
1239+
_mm256_castps_si256(b),
1240+
))
12221241
}
12231242

12241243
/// Shuffles 256 bits (composed of 4 packed double-precision (64-bit)
@@ -1232,7 +1251,10 @@ pub fn _mm256_permute2f128_ps<const IMM8: i32>(a: __m256, b: __m256) -> __m256 {
12321251
#[stable(feature = "simd_x86", since = "1.27.0")]
12331252
pub fn _mm256_permute2f128_pd<const IMM8: i32>(a: __m256d, b: __m256d) -> __m256d {
12341253
static_assert_uimm_bits!(IMM8, 8);
1235-
unsafe { vperm2f128pd256(a, b, IMM8 as i8) }
1254+
_mm256_castsi256_pd(_mm256_permute2f128_si256::<IMM8>(
1255+
_mm256_castpd_si256(a),
1256+
_mm256_castpd_si256(b),
1257+
))
12361258
}
12371259

12381260
/// Shuffles 128-bits (composed of integer data) selected by `imm8`
@@ -1246,7 +1268,35 @@ pub fn _mm256_permute2f128_pd<const IMM8: i32>(a: __m256d, b: __m256d) -> __m256
12461268
#[stable(feature = "simd_x86", since = "1.27.0")]
12471269
pub fn _mm256_permute2f128_si256<const IMM8: i32>(a: __m256i, b: __m256i) -> __m256i {
12481270
static_assert_uimm_bits!(IMM8, 8);
1249-
unsafe { transmute(vperm2f128si256(a.as_i32x8(), b.as_i32x8(), IMM8 as i8)) }
1271+
const fn idx(imm8: i32, pos: u32) -> u32 {
1272+
let part = if pos < 2 {
1273+
imm8 & 0xf
1274+
} else {
1275+
(imm8 & 0xf0) >> 4
1276+
};
1277+
2 * (part as u32 & 0b11) + (pos & 1)
1278+
}
1279+
const fn idx0(imm8: i32, pos: u32) -> u32 {
1280+
let part = if pos < 2 {
1281+
imm8 & 0xf
1282+
} else {
1283+
(imm8 & 0xf0) >> 4
1284+
};
1285+
if part & 0b1000 != 0 { 4 } else { pos }
1286+
}
1287+
unsafe {
1288+
let r = simd_shuffle!(
1289+
a.as_i64x4(),
1290+
b.as_i64x4(),
1291+
[idx(IMM8, 0), idx(IMM8, 1), idx(IMM8, 2), idx(IMM8, 3)]
1292+
);
1293+
let r: i64x4 = simd_shuffle!(
1294+
r,
1295+
i64x4::ZERO,
1296+
[idx0(IMM8, 0), idx0(IMM8, 1), idx0(IMM8, 2), idx0(IMM8, 3)]
1297+
);
1298+
r.as_m256i()
1299+
}
12501300
}
12511301

12521302
/// Broadcasts a single-precision (32-bit) floating-point element from memory
@@ -1933,7 +1983,10 @@ pub fn _mm256_unpacklo_ps(a: __m256, b: __m256) -> __m256 {
19331983
#[cfg_attr(test, assert_instr(vptest))]
19341984
#[stable(feature = "simd_x86", since = "1.27.0")]
19351985
pub fn _mm256_testz_si256(a: __m256i, b: __m256i) -> i32 {
1936-
unsafe { ptestz256(a.as_i64x4(), b.as_i64x4()) }
1986+
unsafe {
1987+
let r = simd_and(a.as_i64x4(), b.as_i64x4());
1988+
(0i64 == simd_reduce_or(r)) as i32
1989+
}
19371990
}
19381991

19391992
/// Computes the bitwise AND of 256 bits (representing integer data) in `a` and
@@ -1947,7 +2000,10 @@ pub fn _mm256_testz_si256(a: __m256i, b: __m256i) -> i32 {
19472000
#[cfg_attr(test, assert_instr(vptest))]
19482001
#[stable(feature = "simd_x86", since = "1.27.0")]
19492002
pub fn _mm256_testc_si256(a: __m256i, b: __m256i) -> i32 {
1950-
unsafe { ptestc256(a.as_i64x4(), b.as_i64x4()) }
2003+
unsafe {
2004+
let r = simd_and(simd_xor(a.as_i64x4(), i64x4::splat(!0)), b.as_i64x4());
2005+
(0i64 == simd_reduce_or(r)) as i32
2006+
}
19512007
}
19522008

19532009
/// Computes the bitwise AND of 256 bits (representing integer data) in `a` and
@@ -2031,7 +2087,10 @@ pub fn _mm256_testnzc_pd(a: __m256d, b: __m256d) -> i32 {
20312087
#[cfg_attr(test, assert_instr(vtestpd))]
20322088
#[stable(feature = "simd_x86", since = "1.27.0")]
20332089
pub fn _mm_testz_pd(a: __m128d, b: __m128d) -> i32 {
2034-
unsafe { vtestzpd(a, b) }
2090+
unsafe {
2091+
let r: i64x2 = simd_lt(transmute(_mm_and_pd(a, b)), i64x2::ZERO);
2092+
(0i64 == simd_reduce_or(r)) as i32
2093+
}
20352094
}
20362095

20372096
/// Computes the bitwise AND of 128 bits (representing double-precision (64-bit)
@@ -2048,7 +2107,10 @@ pub fn _mm_testz_pd(a: __m128d, b: __m128d) -> i32 {
20482107
#[cfg_attr(test, assert_instr(vtestpd))]
20492108
#[stable(feature = "simd_x86", since = "1.27.0")]
20502109
pub fn _mm_testc_pd(a: __m128d, b: __m128d) -> i32 {
2051-
unsafe { vtestcpd(a, b) }
2110+
unsafe {
2111+
let r: i64x2 = simd_lt(transmute(_mm_andnot_pd(a, b)), i64x2::ZERO);
2112+
(0i64 == simd_reduce_or(r)) as i32
2113+
}
20522114
}
20532115

20542116
/// Computes the bitwise AND of 128 bits (representing double-precision (64-bit)
@@ -2135,7 +2197,10 @@ pub fn _mm256_testnzc_ps(a: __m256, b: __m256) -> i32 {
21352197
#[cfg_attr(test, assert_instr(vtestps))]
21362198
#[stable(feature = "simd_x86", since = "1.27.0")]
21372199
pub fn _mm_testz_ps(a: __m128, b: __m128) -> i32 {
2138-
unsafe { vtestzps(a, b) }
2200+
unsafe {
2201+
let r: i32x4 = simd_lt(transmute(_mm_and_ps(a, b)), i32x4::ZERO);
2202+
(0i32 == simd_reduce_or(r)) as i32
2203+
}
21392204
}
21402205

21412206
/// Computes the bitwise AND of 128 bits (representing single-precision (32-bit)
@@ -2152,7 +2217,10 @@ pub fn _mm_testz_ps(a: __m128, b: __m128) -> i32 {
21522217
#[cfg_attr(test, assert_instr(vtestps))]
21532218
#[stable(feature = "simd_x86", since = "1.27.0")]
21542219
pub fn _mm_testc_ps(a: __m128, b: __m128) -> i32 {
2155-
unsafe { vtestcps(a, b) }
2220+
unsafe {
2221+
let r: i32x4 = simd_lt(transmute(_mm_andnot_ps(a, b)), i32x4::ZERO);
2222+
(0i32 == simd_reduce_or(r)) as i32
2223+
}
21562224
}
21572225

21582226
/// Computes the bitwise AND of 128 bits (representing single-precision (32-bit)
@@ -3044,14 +3112,6 @@ unsafe extern "C" {
30443112
fn roundps256(a: __m256, b: i32) -> __m256;
30453113
#[link_name = "llvm.x86.avx.dp.ps.256"]
30463114
fn vdpps(a: __m256, b: __m256, imm8: i8) -> __m256;
3047-
#[link_name = "llvm.x86.avx.hadd.pd.256"]
3048-
fn vhaddpd(a: __m256d, b: __m256d) -> __m256d;
3049-
#[link_name = "llvm.x86.avx.hadd.ps.256"]
3050-
fn vhaddps(a: __m256, b: __m256) -> __m256;
3051-
#[link_name = "llvm.x86.avx.hsub.pd.256"]
3052-
fn vhsubpd(a: __m256d, b: __m256d) -> __m256d;
3053-
#[link_name = "llvm.x86.avx.hsub.ps.256"]
3054-
fn vhsubps(a: __m256, b: __m256) -> __m256;
30553115
#[link_name = "llvm.x86.sse2.cmp.pd"]
30563116
fn vcmppd(a: __m128d, b: __m128d, imm8: i8) -> __m128d;
30573117
#[link_name = "llvm.x86.avx.cmp.pd.256"]
@@ -3084,12 +3144,6 @@ unsafe extern "C" {
30843144
fn vpermilpd256(a: __m256d, b: i64x4) -> __m256d;
30853145
#[link_name = "llvm.x86.avx.vpermilvar.pd"]
30863146
fn vpermilpd(a: __m128d, b: i64x2) -> __m128d;
3087-
#[link_name = "llvm.x86.avx.vperm2f128.ps.256"]
3088-
fn vperm2f128ps256(a: __m256, b: __m256, imm8: i8) -> __m256;
3089-
#[link_name = "llvm.x86.avx.vperm2f128.pd.256"]
3090-
fn vperm2f128pd256(a: __m256d, b: __m256d, imm8: i8) -> __m256d;
3091-
#[link_name = "llvm.x86.avx.vperm2f128.si.256"]
3092-
fn vperm2f128si256(a: i32x8, b: i32x8, imm8: i8) -> i32x8;
30933147
#[link_name = "llvm.x86.avx.maskload.pd.256"]
30943148
fn maskloadpd256(mem_addr: *const i8, mask: i64x4) -> __m256d;
30953149
#[link_name = "llvm.x86.avx.maskstore.pd.256"]
@@ -3112,10 +3166,6 @@ unsafe extern "C" {
31123166
fn vrcpps(a: __m256) -> __m256;
31133167
#[link_name = "llvm.x86.avx.rsqrt.ps.256"]
31143168
fn vrsqrtps(a: __m256) -> __m256;
3115-
#[link_name = "llvm.x86.avx.ptestz.256"]
3116-
fn ptestz256(a: i64x4, b: i64x4) -> i32;
3117-
#[link_name = "llvm.x86.avx.ptestc.256"]
3118-
fn ptestc256(a: i64x4, b: i64x4) -> i32;
31193169
#[link_name = "llvm.x86.avx.ptestnzc.256"]
31203170
fn ptestnzc256(a: i64x4, b: i64x4) -> i32;
31213171
#[link_name = "llvm.x86.avx.vtestz.pd.256"]
@@ -3124,10 +3174,6 @@ unsafe extern "C" {
31243174
fn vtestcpd256(a: __m256d, b: __m256d) -> i32;
31253175
#[link_name = "llvm.x86.avx.vtestnzc.pd.256"]
31263176
fn vtestnzcpd256(a: __m256d, b: __m256d) -> i32;
3127-
#[link_name = "llvm.x86.avx.vtestz.pd"]
3128-
fn vtestzpd(a: __m128d, b: __m128d) -> i32;
3129-
#[link_name = "llvm.x86.avx.vtestc.pd"]
3130-
fn vtestcpd(a: __m128d, b: __m128d) -> i32;
31313177
#[link_name = "llvm.x86.avx.vtestnzc.pd"]
31323178
fn vtestnzcpd(a: __m128d, b: __m128d) -> i32;
31333179
#[link_name = "llvm.x86.avx.vtestz.ps.256"]
@@ -3136,10 +3182,6 @@ unsafe extern "C" {
31363182
fn vtestcps256(a: __m256, b: __m256) -> i32;
31373183
#[link_name = "llvm.x86.avx.vtestnzc.ps.256"]
31383184
fn vtestnzcps256(a: __m256, b: __m256) -> i32;
3139-
#[link_name = "llvm.x86.avx.vtestz.ps"]
3140-
fn vtestzps(a: __m128, b: __m128) -> i32;
3141-
#[link_name = "llvm.x86.avx.vtestc.ps"]
3142-
fn vtestcps(a: __m128, b: __m128) -> i32;
31433185
#[link_name = "llvm.x86.avx.vtestnzc.ps"]
31443186
fn vtestnzcps(a: __m128, b: __m128) -> i32;
31453187
#[link_name = "llvm.x86.avx.min.ps.256"]

0 commit comments

Comments
 (0)