Skip to content

Commit 03ad8a7

Browse files
authored
Merge pull request #1930 from madhav-madhusoodanan/x86_fix_kshift_instructions
`core_arch::x86` : Fix the implementation of `_kshift` instructions
2 parents 6dfd5de + 29027b6 commit 03ad8a7

File tree

3 files changed

+116
-16
lines changed

3 files changed

+116
-16
lines changed

crates/core_arch/src/x86/avx512bw.rs

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10440,7 +10440,7 @@ pub fn _kortestz_mask64_u8(a: __mmask64, b: __mmask64) -> u8 {
1044010440
#[rustc_legacy_const_generics(1)]
1044110441
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
1044210442
pub fn _kshiftli_mask32<const COUNT: u32>(a: __mmask32) -> __mmask32 {
10443-
a << COUNT
10443+
a.unbounded_shl(COUNT)
1044410444
}
1044510445

1044610446
/// Shift the bits of 64-bit mask a left by count while shifting in zeros, and store the least significant 32 bits of the result in k.
@@ -10451,7 +10451,7 @@ pub fn _kshiftli_mask32<const COUNT: u32>(a: __mmask32) -> __mmask32 {
1045110451
#[rustc_legacy_const_generics(1)]
1045210452
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
1045310453
pub fn _kshiftli_mask64<const COUNT: u32>(a: __mmask64) -> __mmask64 {
10454-
a << COUNT
10454+
a.unbounded_shl(COUNT)
1045510455
}
1045610456

1045710457
/// Shift the bits of 32-bit mask a right by count while shifting in zeros, and store the least significant 32 bits of the result in k.
@@ -10462,7 +10462,7 @@ pub fn _kshiftli_mask64<const COUNT: u32>(a: __mmask64) -> __mmask64 {
1046210462
#[rustc_legacy_const_generics(1)]
1046310463
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
1046410464
pub fn _kshiftri_mask32<const COUNT: u32>(a: __mmask32) -> __mmask32 {
10465-
a >> COUNT
10465+
a.unbounded_shr(COUNT)
1046610466
}
1046710467

1046810468
/// Shift the bits of 64-bit mask a right by count while shifting in zeros, and store the least significant 32 bits of the result in k.
@@ -10473,7 +10473,7 @@ pub fn _kshiftri_mask32<const COUNT: u32>(a: __mmask32) -> __mmask32 {
1047310473
#[rustc_legacy_const_generics(1)]
1047410474
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
1047510475
pub fn _kshiftri_mask64<const COUNT: u32>(a: __mmask64) -> __mmask64 {
10476-
a >> COUNT
10476+
a.unbounded_shr(COUNT)
1047710477
}
1047810478

1047910479
/// Compute the bitwise AND of 32-bit masks a and b, and if the result is all zeros, store 1 in dst,
@@ -20315,6 +20315,18 @@ mod tests {
2031520315
let r = _kshiftli_mask32::<3>(a);
2031620316
let e: __mmask32 = 0b0100101101001011_0100101101001000;
2031720317
assert_eq!(r, e);
20318+
20319+
let r = _kshiftli_mask32::<31>(a);
20320+
let e: __mmask32 = 0b1000000000000000_0000000000000000;
20321+
assert_eq!(r, e);
20322+
20323+
let r = _kshiftli_mask32::<32>(a);
20324+
let e: __mmask32 = 0b0000000000000000_0000000000000000;
20325+
assert_eq!(r, e);
20326+
20327+
let r = _kshiftli_mask32::<33>(a);
20328+
let e: __mmask32 = 0b0000000000000000_0000000000000000;
20329+
assert_eq!(r, e);
2031820330
}
2031920331

2032020332
#[simd_test(enable = "avx512bw")]
@@ -20323,21 +20335,61 @@ mod tests {
2032320335
let r = _kshiftli_mask64::<3>(a);
2032420336
let e: __mmask64 = 0b0110100101101001011_0100101101001000;
2032520337
assert_eq!(r, e);
20338+
20339+
let r = _kshiftli_mask64::<63>(a);
20340+
let e: __mmask64 = 0b1000000000000000_0000000000000000_0000000000000000_0000000000000000;
20341+
assert_eq!(r, e);
20342+
20343+
let r = _kshiftli_mask64::<64>(a);
20344+
let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000000;
20345+
assert_eq!(r, e);
20346+
20347+
let r = _kshiftli_mask64::<65>(a);
20348+
let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000000;
20349+
assert_eq!(r, e);
2032620350
}
2032720351

2032820352
#[simd_test(enable = "avx512bw")]
2032920353
unsafe fn test_kshiftri_mask32() {
20330-
let a: __mmask32 = 0b0110100101101001_0110100101101001;
20354+
let a: __mmask32 = 0b1010100101101001_0110100101101001;
2033120355
let r = _kshiftri_mask32::<3>(a);
20332-
let e: __mmask32 = 0b0000110100101101_0010110100101101;
20356+
let e: __mmask32 = 0b0001010100101101_0010110100101101;
20357+
assert_eq!(r, e);
20358+
20359+
let r = _kshiftri_mask32::<31>(a);
20360+
let e: __mmask32 = 0b0000000000000000_0000000000000001;
20361+
assert_eq!(r, e);
20362+
20363+
let r = _kshiftri_mask32::<32>(a);
20364+
let e: __mmask32 = 0b0000000000000000_0000000000000000;
20365+
assert_eq!(r, e);
20366+
20367+
let r = _kshiftri_mask32::<33>(a);
20368+
let e: __mmask32 = 0b0000000000000000_0000000000000000;
2033320369
assert_eq!(r, e);
2033420370
}
2033520371

2033620372
#[simd_test(enable = "avx512bw")]
2033720373
unsafe fn test_kshiftri_mask64() {
20338-
let a: __mmask64 = 0b0110100101101001011_0100101101001000;
20374+
let a: __mmask64 = 0b1010100101101001011_0100101101001000;
2033920375
let r = _kshiftri_mask64::<3>(a);
20340-
let e: __mmask64 = 0b0110100101101001_0110100101101001;
20376+
let e: __mmask64 = 0b1010100101101001_0110100101101001;
20377+
assert_eq!(r, e);
20378+
20379+
let r = _kshiftri_mask64::<34>(a);
20380+
let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000001;
20381+
assert_eq!(r, e);
20382+
20383+
let r = _kshiftri_mask64::<35>(a);
20384+
let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000000;
20385+
assert_eq!(r, e);
20386+
20387+
let r = _kshiftri_mask64::<64>(a);
20388+
let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000000;
20389+
assert_eq!(r, e);
20390+
20391+
let r = _kshiftri_mask64::<65>(a);
20392+
let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000000;
2034120393
assert_eq!(r, e);
2034220394
}
2034320395

crates/core_arch/src/x86/avx512dq.rs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4602,7 +4602,7 @@ pub fn _kortestz_mask8_u8(a: __mmask8, b: __mmask8) -> u8 {
46024602
#[rustc_legacy_const_generics(1)]
46034603
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
46044604
pub fn _kshiftli_mask8<const COUNT: u32>(a: __mmask8) -> __mmask8 {
4605-
a << COUNT
4605+
a.unbounded_shl(COUNT)
46064606
}
46074607

46084608
/// Shift 8-bit mask a right by count bits while shifting in zeros, and store the result in dst.
@@ -4613,7 +4613,7 @@ pub fn _kshiftli_mask8<const COUNT: u32>(a: __mmask8) -> __mmask8 {
46134613
#[rustc_legacy_const_generics(1)]
46144614
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
46154615
pub fn _kshiftri_mask8<const COUNT: u32>(a: __mmask8) -> __mmask8 {
4616-
a >> COUNT
4616+
a.unbounded_shr(COUNT)
46174617
}
46184618

46194619
/// Compute the bitwise AND of 16-bit masks a and b, and if the result is all zeros, store 1 in dst,
@@ -9856,13 +9856,37 @@ mod tests {
98569856
let r = _kshiftli_mask8::<3>(a);
98579857
let e: __mmask8 = 0b01001000;
98589858
assert_eq!(r, e);
9859+
9860+
let r = _kshiftli_mask8::<7>(a);
9861+
let e: __mmask8 = 0b10000000;
9862+
assert_eq!(r, e);
9863+
9864+
let r = _kshiftli_mask8::<8>(a);
9865+
let e: __mmask8 = 0b00000000;
9866+
assert_eq!(r, e);
9867+
9868+
let r = _kshiftli_mask8::<9>(a);
9869+
let e: __mmask8 = 0b00000000;
9870+
assert_eq!(r, e);
98599871
}
98609872

98619873
#[simd_test(enable = "avx512dq")]
98629874
unsafe fn test_kshiftri_mask8() {
9863-
let a: __mmask8 = 0b01101001;
9875+
let a: __mmask8 = 0b10101001;
98649876
let r = _kshiftri_mask8::<3>(a);
9865-
let e: __mmask8 = 0b00001101;
9877+
let e: __mmask8 = 0b00010101;
9878+
assert_eq!(r, e);
9879+
9880+
let r = _kshiftri_mask8::<7>(a);
9881+
let e: __mmask8 = 0b00000001;
9882+
assert_eq!(r, e);
9883+
9884+
let r = _kshiftri_mask8::<8>(a);
9885+
let e: __mmask8 = 0b00000000;
9886+
assert_eq!(r, e);
9887+
9888+
let r = _kshiftri_mask8::<9>(a);
9889+
let e: __mmask8 = 0b00000000;
98669890
assert_eq!(r, e);
98679891
}
98689892

crates/core_arch/src/x86/avx512f.rs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29090,7 +29090,7 @@ pub fn _kortestz_mask16_u8(a: __mmask16, b: __mmask16) -> u8 {
2909029090
#[rustc_legacy_const_generics(1)]
2909129091
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
2909229092
pub fn _kshiftli_mask16<const COUNT: u32>(a: __mmask16) -> __mmask16 {
29093-
a << COUNT
29093+
a.unbounded_shl(COUNT)
2909429094
}
2909529095

2909629096
/// Shift 16-bit mask a right by count bits while shifting in zeros, and store the result in dst.
@@ -29101,7 +29101,7 @@ pub fn _kshiftli_mask16<const COUNT: u32>(a: __mmask16) -> __mmask16 {
2910129101
#[rustc_legacy_const_generics(1)]
2910229102
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
2910329103
pub fn _kshiftri_mask16<const COUNT: u32>(a: __mmask16) -> __mmask16 {
29104-
a >> COUNT
29104+
a.unbounded_shr(COUNT)
2910529105
}
2910629106

2910729107
/// Load 16-bit mask from memory
@@ -56001,13 +56001,37 @@ mod tests {
5600156001
let r = _kshiftli_mask16::<3>(a);
5600256002
let e: __mmask16 = 0b1011011000011000;
5600356003
assert_eq!(r, e);
56004+
56005+
let r = _kshiftli_mask16::<15>(a);
56006+
let e: __mmask16 = 0b1000000000000000;
56007+
assert_eq!(r, e);
56008+
56009+
let r = _kshiftli_mask16::<16>(a);
56010+
let e: __mmask16 = 0b0000000000000000;
56011+
assert_eq!(r, e);
56012+
56013+
let r = _kshiftli_mask16::<17>(a);
56014+
let e: __mmask16 = 0b0000000000000000;
56015+
assert_eq!(r, e);
5600456016
}
5600556017

5600656018
#[simd_test(enable = "avx512dq")]
5600756019
unsafe fn test_kshiftri_mask16() {
56008-
let a: __mmask16 = 0b0110100100111100;
56020+
let a: __mmask16 = 0b1010100100111100;
5600956021
let r = _kshiftri_mask16::<3>(a);
56010-
let e: __mmask16 = 0b0000110100100111;
56022+
let e: __mmask16 = 0b0001010100100111;
56023+
assert_eq!(r, e);
56024+
56025+
let r = _kshiftri_mask16::<15>(a);
56026+
let e: __mmask16 = 0b0000000000000001;
56027+
assert_eq!(r, e);
56028+
56029+
let r = _kshiftri_mask16::<16>(a);
56030+
let e: __mmask16 = 0b0000000000000000;
56031+
assert_eq!(r, e);
56032+
56033+
let r = _kshiftri_mask16::<17>(a);
56034+
let e: __mmask16 = 0b0000000000000000;
5601156035
assert_eq!(r, e);
5601256036
}
5601356037

0 commit comments

Comments
 (0)