Skip to content

Commit 13410d0

Browse files
committed
Use SIMD intrinsics for vperm2 intrinsics
1 parent ee10e3b commit 13410d0

File tree

2 files changed

+38
-12
lines changed

2 files changed

+38
-12
lines changed

crates/core_arch/src/x86/avx.rs

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,7 +1234,10 @@ pub fn _mm_permute_pd<const IMM2: i32>(a: __m128d) -> __m128d {
12341234
#[stable(feature = "simd_x86", since = "1.27.0")]
12351235
pub fn _mm256_permute2f128_ps<const IMM8: i32>(a: __m256, b: __m256) -> __m256 {
12361236
static_assert_uimm_bits!(IMM8, 8);
1237-
unsafe { vperm2f128ps256(a, b, IMM8 as i8) }
1237+
_mm256_castsi256_ps(_mm256_permute2f128_si256::<IMM8>(
1238+
_mm256_castps_si256(a),
1239+
_mm256_castps_si256(b),
1240+
))
12381241
}
12391242

12401243
/// Shuffles 256 bits (composed of 4 packed double-precision (64-bit)
@@ -1248,7 +1251,10 @@ pub fn _mm256_permute2f128_ps<const IMM8: i32>(a: __m256, b: __m256) -> __m256 {
12481251
#[stable(feature = "simd_x86", since = "1.27.0")]
12491252
pub fn _mm256_permute2f128_pd<const IMM8: i32>(a: __m256d, b: __m256d) -> __m256d {
12501253
static_assert_uimm_bits!(IMM8, 8);
1251-
unsafe { vperm2f128pd256(a, b, IMM8 as i8) }
1254+
_mm256_castsi256_pd(_mm256_permute2f128_si256::<IMM8>(
1255+
_mm256_castpd_si256(a),
1256+
_mm256_castpd_si256(b),
1257+
))
12521258
}
12531259

12541260
/// Shuffles 128-bits (composed of integer data) selected by `imm8`
@@ -1262,7 +1268,35 @@ pub fn _mm256_permute2f128_pd<const IMM8: i32>(a: __m256d, b: __m256d) -> __m256
12621268
#[stable(feature = "simd_x86", since = "1.27.0")]
12631269
pub fn _mm256_permute2f128_si256<const IMM8: i32>(a: __m256i, b: __m256i) -> __m256i {
12641270
static_assert_uimm_bits!(IMM8, 8);
1265-
unsafe { transmute(vperm2f128si256(a.as_i32x8(), b.as_i32x8(), IMM8 as i8)) }
1271+
const fn idx(imm8: i32, pos: u32) -> u32 {
1272+
let part = if pos < 2 {
1273+
imm8 & 0xf
1274+
} else {
1275+
(imm8 & 0xf0) >> 4
1276+
};
1277+
2 * (part as u32 & 0b11) + (pos & 1)
1278+
}
1279+
const fn idx0(imm8: i32, pos: u32) -> u32 {
1280+
let part = if pos < 2 {
1281+
imm8 & 0xf
1282+
} else {
1283+
(imm8 & 0xf0) >> 4
1284+
};
1285+
if part & 0b1000 != 0 { 4 } else { pos }
1286+
}
1287+
unsafe {
1288+
let r = simd_shuffle!(
1289+
a.as_i64x4(),
1290+
b.as_i64x4(),
1291+
[idx(IMM8, 0), idx(IMM8, 1), idx(IMM8, 2), idx(IMM8, 3)]
1292+
);
1293+
let r: i64x4 = simd_shuffle!(
1294+
r,
1295+
i64x4::ZERO,
1296+
[idx0(IMM8, 0), idx0(IMM8, 1), idx0(IMM8, 2), idx0(IMM8, 3)]
1297+
);
1298+
r.as_m256i()
1299+
}
12661300
}
12671301

12681302
/// Broadcasts a single-precision (32-bit) floating-point element from memory
@@ -3092,12 +3126,6 @@ unsafe extern "C" {
30923126
fn vpermilpd256(a: __m256d, b: i64x4) -> __m256d;
30933127
#[link_name = "llvm.x86.avx.vpermilvar.pd"]
30943128
fn vpermilpd(a: __m128d, b: i64x2) -> __m128d;
3095-
#[link_name = "llvm.x86.avx.vperm2f128.ps.256"]
3096-
fn vperm2f128ps256(a: __m256, b: __m256, imm8: i8) -> __m256;
3097-
#[link_name = "llvm.x86.avx.vperm2f128.pd.256"]
3098-
fn vperm2f128pd256(a: __m256d, b: __m256d, imm8: i8) -> __m256d;
3099-
#[link_name = "llvm.x86.avx.vperm2f128.si.256"]
3100-
fn vperm2f128si256(a: i32x8, b: i32x8, imm8: i8) -> i32x8;
31013129
#[link_name = "llvm.x86.avx.maskload.pd.256"]
31023130
fn maskloadpd256(mem_addr: *const i8, mask: i64x4) -> __m256d;
31033131
#[link_name = "llvm.x86.avx.maskstore.pd.256"]

crates/core_arch/src/x86/avx2.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2330,7 +2330,7 @@ pub fn _mm256_permute4x64_epi64<const IMM8: i32>(a: __m256i) -> __m256i {
23302330
#[stable(feature = "simd_x86", since = "1.27.0")]
23312331
pub fn _mm256_permute2x128_si256<const IMM8: i32>(a: __m256i, b: __m256i) -> __m256i {
23322332
static_assert_uimm_bits!(IMM8, 8);
2333-
unsafe { transmute(vperm2i128(a.as_i64x4(), b.as_i64x4(), IMM8 as i8)) }
2333+
_mm256_permute2f128_si256::<IMM8>(a, b)
23342334
}
23352335

23362336
/// Shuffles 64-bit floating-point elements in `a` across lanes using the
@@ -3703,8 +3703,6 @@ unsafe extern "C" {
37033703
fn permd(a: u32x8, b: u32x8) -> u32x8;
37043704
#[link_name = "llvm.x86.avx2.permps"]
37053705
fn permps(a: __m256, b: i32x8) -> __m256;
3706-
#[link_name = "llvm.x86.avx2.vperm2i128"]
3707-
fn vperm2i128(a: i64x4, b: i64x4, imm8: i8) -> i64x4;
37083706
#[link_name = "llvm.x86.avx2.gather.d.d"]
37093707
fn pgatherdd(src: i32x4, slice: *const i8, offsets: i32x4, mask: i32x4, scale: i8) -> i32x4;
37103708
#[link_name = "llvm.x86.avx2.gather.d.d.256"]

0 commit comments

Comments
 (0)