Commit 1d61f54

Merge pull request #1931 from sayantn/use-intrinsics
Fix mistake in #1928
2 parents: 03ad8a7 + 6072d4e

File tree: 4 files changed (+297, -57 lines)


crates/core_arch/src/x86/avx2.rs

Lines changed: 30 additions & 10 deletions
@@ -2778,7 +2778,7 @@ pub fn _mm256_bslli_epi128<const IMM8: i32>(a: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vpsllvd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_sllv_epi32(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(simd_shl(a.as_u32x4(), count.as_u32x4())) }
+    unsafe { transmute(psllvd(a.as_i32x4(), count.as_i32x4())) }
 }

 /// Shifts packed 32-bit integers in `a` left by the amount
@@ -2791,7 +2791,7 @@ pub fn _mm_sllv_epi32(a: __m128i, count: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(vpsllvd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_sllv_epi32(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(simd_shl(a.as_u32x8(), count.as_u32x8())) }
+    unsafe { transmute(psllvd256(a.as_i32x8(), count.as_i32x8())) }
 }

 /// Shifts packed 64-bit integers in `a` left by the amount
@@ -2804,7 +2804,7 @@ pub fn _mm256_sllv_epi32(a: __m256i, count: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vpsllvq))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_sllv_epi64(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(simd_shl(a.as_u64x2(), count.as_u64x2())) }
+    unsafe { transmute(psllvq(a.as_i64x2(), count.as_i64x2())) }
 }

 /// Shifts packed 64-bit integers in `a` left by the amount
@@ -2817,7 +2817,7 @@ pub fn _mm_sllv_epi64(a: __m128i, count: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(vpsllvq))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_sllv_epi64(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(simd_shl(a.as_u64x4(), count.as_u64x4())) }
+    unsafe { transmute(psllvq256(a.as_i64x4(), count.as_i64x4())) }
 }

 /// Shifts packed 16-bit integers in `a` right by `count` while
@@ -2881,7 +2881,7 @@ pub fn _mm256_srai_epi32<const IMM8: i32>(a: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vpsravd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_srav_epi32(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(simd_shr(a.as_i32x4(), count.as_i32x4())) }
+    unsafe { transmute(psravd(a.as_i32x4(), count.as_i32x4())) }
 }

 /// Shifts packed 32-bit integers in `a` right by the amount specified by the
@@ -2893,7 +2893,7 @@ pub fn _mm_srav_epi32(a: __m128i, count: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(vpsravd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_srav_epi32(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(simd_shr(a.as_i32x8(), count.as_i32x8())) }
+    unsafe { transmute(psravd256(a.as_i32x8(), count.as_i32x8())) }
 }

 /// Shifts 128-bit lanes in `a` right by `imm8` bytes while shifting in zeros.
@@ -3076,7 +3076,7 @@ pub fn _mm256_srli_epi64<const IMM8: i32>(a: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vpsrlvd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_srlv_epi32(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(simd_shr(a.as_u32x4(), count.as_u32x4())) }
+    unsafe { transmute(psrlvd(a.as_i32x4(), count.as_i32x4())) }
 }

 /// Shifts packed 32-bit integers in `a` right by the amount specified by
@@ -3088,7 +3088,7 @@ pub fn _mm_srlv_epi32(a: __m128i, count: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(vpsrlvd))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_srlv_epi32(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(simd_shr(a.as_u32x8(), count.as_u32x8())) }
+    unsafe { transmute(psrlvd256(a.as_i32x8(), count.as_i32x8())) }
 }

 /// Shifts packed 64-bit integers in `a` right by the amount specified by
@@ -3100,7 +3100,7 @@ pub fn _mm256_srlv_epi32(a: __m256i, count: __m256i) -> __m256i {
 #[cfg_attr(test, assert_instr(vpsrlvq))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm_srlv_epi64(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(simd_shr(a.as_u64x2(), count.as_u64x2())) }
+    unsafe { transmute(psrlvq(a.as_i64x2(), count.as_i64x2())) }
 }

 /// Shifts packed 64-bit integers in `a` right by the amount specified by
@@ -3112,7 +3112,7 @@ pub fn _mm_srlv_epi64(a: __m128i, count: __m128i) -> __m128i {
 #[cfg_attr(test, assert_instr(vpsrlvq))]
 #[stable(feature = "simd_x86", since = "1.27.0")]
 pub fn _mm256_srlv_epi64(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(simd_shr(a.as_u64x4(), count.as_u64x4())) }
+    unsafe { transmute(psrlvq256(a.as_i64x4(), count.as_i64x4())) }
 }

 /// Load 256-bits of integer data from memory into dst using a non-temporal memory hint. mem_addr
@@ -3687,16 +3687,36 @@ unsafe extern "C" {
     fn pslld(a: i32x8, count: i32x4) -> i32x8;
     #[link_name = "llvm.x86.avx2.psll.q"]
     fn psllq(a: i64x4, count: i64x2) -> i64x4;
+    #[link_name = "llvm.x86.avx2.psllv.d"]
+    fn psllvd(a: i32x4, count: i32x4) -> i32x4;
+    #[link_name = "llvm.x86.avx2.psllv.d.256"]
+    fn psllvd256(a: i32x8, count: i32x8) -> i32x8;
+    #[link_name = "llvm.x86.avx2.psllv.q"]
+    fn psllvq(a: i64x2, count: i64x2) -> i64x2;
+    #[link_name = "llvm.x86.avx2.psllv.q.256"]
+    fn psllvq256(a: i64x4, count: i64x4) -> i64x4;
     #[link_name = "llvm.x86.avx2.psra.w"]
     fn psraw(a: i16x16, count: i16x8) -> i16x16;
     #[link_name = "llvm.x86.avx2.psra.d"]
     fn psrad(a: i32x8, count: i32x4) -> i32x8;
+    #[link_name = "llvm.x86.avx2.psrav.d"]
+    fn psravd(a: i32x4, count: i32x4) -> i32x4;
+    #[link_name = "llvm.x86.avx2.psrav.d.256"]
+    fn psravd256(a: i32x8, count: i32x8) -> i32x8;
     #[link_name = "llvm.x86.avx2.psrl.w"]
     fn psrlw(a: i16x16, count: i16x8) -> i16x16;
     #[link_name = "llvm.x86.avx2.psrl.d"]
     fn psrld(a: i32x8, count: i32x4) -> i32x8;
     #[link_name = "llvm.x86.avx2.psrl.q"]
     fn psrlq(a: i64x4, count: i64x2) -> i64x4;
+    #[link_name = "llvm.x86.avx2.psrlv.d"]
+    fn psrlvd(a: i32x4, count: i32x4) -> i32x4;
+    #[link_name = "llvm.x86.avx2.psrlv.d.256"]
+    fn psrlvd256(a: i32x8, count: i32x8) -> i32x8;
+    #[link_name = "llvm.x86.avx2.psrlv.q"]
+    fn psrlvq(a: i64x2, count: i64x2) -> i64x2;
+    #[link_name = "llvm.x86.avx2.psrlv.q.256"]
+    fn psrlvq256(a: i64x4, count: i64x4) -> i64x4;
     #[link_name = "llvm.x86.avx2.pshuf.b"]
     fn pshufb(a: u8x32, b: u8x32) -> u8x32;
     #[link_name = "llvm.x86.avx2.permd"]
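
Note (not part of the commit): routing these wrappers back through the LLVM x86 intrinsics instead of the generic simd_shl/simd_shr presumably matters for out-of-range per-element shift counts, which appears to be the mistake referenced in #1928. The x86 variable-shift instructions (VPSLLVD, VPSLLVQ, VPSRLVD, VPSRLVQ, VPSRAVD) define a count of element-width or more to produce 0 (or all sign bits for the arithmetic form), while a generic SIMD shift by that much is not well-defined. A minimal sketch, using only the stable core::arch wrapper _mm_sllv_epi32 shown above, that exercises an out-of-range lane count:

// Illustrative only, not part of the commit: checks the documented VPSLLVD
// behavior that _mm_sllv_epi32 exposes, where a per-element shift count of
// 32 or more yields 0 instead of being undefined.
#[cfg(target_arch = "x86_64")]
fn main() {
    if !is_x86_feature_detected!("avx2") {
        println!("AVX2 not available, skipping");
        return;
    }
    unsafe { demo() }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn demo() {
    use core::arch::x86_64::*;
    unsafe {
        let a = _mm_set1_epi32(1);
        // _mm_set_epi32 takes the highest lane first, so the per-lane counts
        // in memory order are 0, 1, 31, 32.
        let counts = _mm_set_epi32(32, 31, 1, 0);
        let r = _mm_sllv_epi32(a, counts);
        let mut out = [0i32; 4];
        _mm_storeu_si128(out.as_mut_ptr() as *mut __m128i, r);
        // 1 << 31 wraps to i32::MIN; a count of 32 gives 0 rather than UB.
        assert_eq!(out, [1, 2, i32::MIN, 0]);
    }
}

#[cfg(not(target_arch = "x86_64"))]
fn main() {}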

crates/core_arch/src/x86/avx512bw.rs

Lines changed: 30 additions & 9 deletions
@@ -6852,7 +6852,7 @@ pub fn _mm_maskz_slli_epi16<const IMM8: u32>(k: __mmask8, a: __m128i) -> __m128i
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsllvw))]
 pub fn _mm512_sllv_epi16(a: __m512i, count: __m512i) -> __m512i {
-    unsafe { transmute(simd_shl(a.as_u16x32(), count.as_u16x32())) }
+    unsafe { transmute(vpsllvw(a.as_i16x32(), count.as_i16x32())) }
 }

 /// Shift packed 16-bit integers in a left by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -6891,7 +6891,7 @@ pub fn _mm512_maskz_sllv_epi16(k: __mmask32, a: __m512i, count: __m512i) -> __m5
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsllvw))]
 pub fn _mm256_sllv_epi16(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(simd_shl(a.as_u16x16(), count.as_u16x16())) }
+    unsafe { transmute(vpsllvw256(a.as_i16x16(), count.as_i16x16())) }
 }

 /// Shift packed 16-bit integers in a left by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -6930,7 +6930,7 @@ pub fn _mm256_maskz_sllv_epi16(k: __mmask16, a: __m256i, count: __m256i) -> __m2
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsllvw))]
 pub fn _mm_sllv_epi16(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(simd_shl(a.as_u16x8(), count.as_u16x8())) }
+    unsafe { transmute(vpsllvw128(a.as_i16x8(), count.as_i16x8())) }
 }

 /// Shift packed 16-bit integers in a left by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -7188,7 +7188,7 @@ pub fn _mm_maskz_srli_epi16<const IMM8: i32>(k: __mmask8, a: __m128i) -> __m128i
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsrlvw))]
 pub fn _mm512_srlv_epi16(a: __m512i, count: __m512i) -> __m512i {
-    unsafe { transmute(simd_shr(a.as_u16x32(), count.as_u16x32())) }
+    unsafe { transmute(vpsrlvw(a.as_i16x32(), count.as_i16x32())) }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -7227,7 +7227,7 @@ pub fn _mm512_maskz_srlv_epi16(k: __mmask32, a: __m512i, count: __m512i) -> __m5
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsrlvw))]
 pub fn _mm256_srlv_epi16(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(simd_shr(a.as_u16x16(), count.as_u16x16())) }
+    unsafe { transmute(vpsrlvw256(a.as_i16x16(), count.as_i16x16())) }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -7266,7 +7266,7 @@ pub fn _mm256_maskz_srlv_epi16(k: __mmask16, a: __m256i, count: __m256i) -> __m2
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsrlvw))]
 pub fn _mm_srlv_epi16(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(simd_shr(a.as_u16x8(), count.as_u16x8())) }
+    unsafe { transmute(vpsrlvw128(a.as_i16x8(), count.as_i16x8())) }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in zeros, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -7511,7 +7511,7 @@ pub fn _mm_maskz_srai_epi16<const IMM8: u32>(k: __mmask8, a: __m128i) -> __m128i
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsravw))]
 pub fn _mm512_srav_epi16(a: __m512i, count: __m512i) -> __m512i {
-    unsafe { transmute(simd_shr(a.as_i16x32(), count.as_i16x32())) }
+    unsafe { transmute(vpsravw(a.as_i16x32(), count.as_i16x32())) }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in sign bits, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -7550,7 +7550,7 @@ pub fn _mm512_maskz_srav_epi16(k: __mmask32, a: __m512i, count: __m512i) -> __m5
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsravw))]
 pub fn _mm256_srav_epi16(a: __m256i, count: __m256i) -> __m256i {
-    unsafe { transmute(simd_shr(a.as_i16x16(), count.as_i16x16())) }
+    unsafe { transmute(vpsravw256(a.as_i16x16(), count.as_i16x16())) }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in sign bits, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -7589,7 +7589,7 @@ pub fn _mm256_maskz_srav_epi16(k: __mmask16, a: __m256i, count: __m256i) -> __m2
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 #[cfg_attr(test, assert_instr(vpsravw))]
 pub fn _mm_srav_epi16(a: __m128i, count: __m128i) -> __m128i {
-    unsafe { transmute(simd_shr(a.as_i16x8(), count.as_i16x8())) }
+    unsafe { transmute(vpsravw128(a.as_i16x8(), count.as_i16x8())) }
 }

 /// Shift packed 16-bit integers in a right by the amount specified by the corresponding element in count while shifting in sign bits, and store the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -11645,12 +11645,33 @@ unsafe extern "C" {
     #[link_name = "llvm.x86.avx512.psll.w.512"]
     fn vpsllw(a: i16x32, count: i16x8) -> i16x32;

+    #[link_name = "llvm.x86.avx512.psllv.w.512"]
+    fn vpsllvw(a: i16x32, b: i16x32) -> i16x32;
+    #[link_name = "llvm.x86.avx512.psllv.w.256"]
+    fn vpsllvw256(a: i16x16, b: i16x16) -> i16x16;
+    #[link_name = "llvm.x86.avx512.psllv.w.128"]
+    fn vpsllvw128(a: i16x8, b: i16x8) -> i16x8;
+
     #[link_name = "llvm.x86.avx512.psrl.w.512"]
     fn vpsrlw(a: i16x32, count: i16x8) -> i16x32;

+    #[link_name = "llvm.x86.avx512.psrlv.w.512"]
+    fn vpsrlvw(a: i16x32, b: i16x32) -> i16x32;
+    #[link_name = "llvm.x86.avx512.psrlv.w.256"]
+    fn vpsrlvw256(a: i16x16, b: i16x16) -> i16x16;
+    #[link_name = "llvm.x86.avx512.psrlv.w.128"]
+    fn vpsrlvw128(a: i16x8, b: i16x8) -> i16x8;
+
     #[link_name = "llvm.x86.avx512.psra.w.512"]
     fn vpsraw(a: i16x32, count: i16x8) -> i16x32;

+    #[link_name = "llvm.x86.avx512.psrav.w.512"]
+    fn vpsravw(a: i16x32, count: i16x32) -> i16x32;
+    #[link_name = "llvm.x86.avx512.psrav.w.256"]
+    fn vpsravw256(a: i16x16, count: i16x16) -> i16x16;
+    #[link_name = "llvm.x86.avx512.psrav.w.128"]
+    fn vpsravw128(a: i16x8, count: i16x8) -> i16x8;
+
     #[link_name = "llvm.x86.avx512.vpermi2var.hi.512"]
     fn vpermi2w(a: i16x32, idx: i16x32, b: i16x32) -> i16x32;
     #[link_name = "llvm.x86.avx512.vpermi2var.hi.256"]
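
The same consideration applies to the 16-bit AVX-512BW/VL forms above: VPSLLVW/VPSRLVW zero a lane whose count is 16 or more, and VPSRAVW fills it with the sign bit. A hedged sketch (assuming a toolchain where the stdarch_x86_avx512 intrinsics are stable, per the since = "1.89" attributes above) using _mm_srav_epi16:

// Illustrative only, not part of the commit: checks the documented VPSRAVW
// behavior that _mm_srav_epi16 exposes, where an out-of-range count (16 or
// more) fills the lane with the sign bit instead of being undefined.
#[cfg(target_arch = "x86_64")]
fn main() {
    if !(is_x86_feature_detected!("avx512bw") && is_x86_feature_detected!("avx512vl")) {
        println!("AVX-512BW/VL not available, skipping");
        return;
    }
    unsafe { demo() }
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512bw,avx512vl")]
unsafe fn demo() {
    use core::arch::x86_64::*;
    unsafe {
        // _mm_set_epi16 takes the highest lane first; in memory order the
        // values are 8, 8, 8, 8, -8, -8, -8, -8 with counts 0, 3, 16, 100 twice.
        let a = _mm_set_epi16(-8, -8, -8, -8, 8, 8, 8, 8);
        let counts = _mm_set_epi16(100, 16, 3, 0, 100, 16, 3, 0);
        let r = _mm_srav_epi16(a, counts);
        let mut out = [0i16; 8];
        _mm_storeu_si128(out.as_mut_ptr() as *mut __m128i, r);
        // Out-of-range counts give 0 for non-negative lanes and -1 for
        // negative lanes (all sign bits), never undefined behavior.
        assert_eq!(out, [8, 1, 0, 0, -8, -1, -1, -1]);
    }
}

#[cfg(not(target_arch = "x86_64"))]
fn main() {}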
