Skip to content

Commit 4be142a

Browse files
fix: update the implementation of _kshiftri_mask32, _kshiftri_mask64,
_kshiftli_mask32 and _kshiftli_mask64 to zero out when the shift amount is greater than or equal to the bit width of the input argument.
1 parent 6dfd5de commit 4be142a

File tree

1 file changed

+60
-8
lines changed

1 file changed

+60
-8
lines changed

crates/core_arch/src/x86/avx512bw.rs

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10440,7 +10440,7 @@ pub fn _kortestz_mask64_u8(a: __mmask64, b: __mmask64) -> u8 {
1044010440
#[rustc_legacy_const_generics(1)]
1044110441
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
1044210442
pub fn _kshiftli_mask32<const COUNT: u32>(a: __mmask32) -> __mmask32 {
10443-
a << COUNT
10443+
a.unbounded_shl(COUNT)
1044410444
}
1044510445

1044610446
/// Shift the bits of 64-bit mask a left by count while shifting in zeros, and store the least significant 64 bits of the result in k.
@@ -10451,7 +10451,7 @@ pub fn _kshiftli_mask32<const COUNT: u32>(a: __mmask32) -> __mmask32 {
1045110451
#[rustc_legacy_const_generics(1)]
1045210452
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
1045310453
pub fn _kshiftli_mask64<const COUNT: u32>(a: __mmask64) -> __mmask64 {
10454-
a << COUNT
10454+
a.unbounded_shl(COUNT)
1045510455
}
1045610456

1045710457
/// Shift the bits of 32-bit mask a right by count while shifting in zeros, and store the least significant 32 bits of the result in k.
@@ -10462,7 +10462,7 @@ pub fn _kshiftli_mask64<const COUNT: u32>(a: __mmask64) -> __mmask64 {
1046210462
#[rustc_legacy_const_generics(1)]
1046310463
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
1046410464
pub fn _kshiftri_mask32<const COUNT: u32>(a: __mmask32) -> __mmask32 {
10465-
a >> COUNT
10465+
a.unbounded_shr(COUNT)
1046610466
}
1046710467

1046810468
/// Shift the bits of 64-bit mask a right by count while shifting in zeros, and store the least significant 64 bits of the result in k.
@@ -10473,7 +10473,7 @@ pub fn _kshiftri_mask32<const COUNT: u32>(a: __mmask32) -> __mmask32 {
1047310473
#[rustc_legacy_const_generics(1)]
1047410474
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
1047510475
pub fn _kshiftri_mask64<const COUNT: u32>(a: __mmask64) -> __mmask64 {
10476-
a >> COUNT
10476+
a.unbounded_shr(COUNT)
1047710477
}
1047810478

1047910479
/// Compute the bitwise AND of 32-bit masks a and b, and if the result is all zeros, store 1 in dst,
@@ -20315,6 +20315,18 @@ mod tests {
2031520315
let r = _kshiftli_mask32::<3>(a);
2031620316
let e: __mmask32 = 0b0100101101001011_0100101101001000;
2031720317
assert_eq!(r, e);
20318+
20319+
let r = _kshiftli_mask32::<31>(a);
20320+
let e: __mmask32 = 0b1000000000000000_0000000000000000;
20321+
assert_eq!(r, e);
20322+
20323+
let r = _kshiftli_mask32::<32>(a);
20324+
let e: __mmask32 = 0b0000000000000000_0000000000000000;
20325+
assert_eq!(r, e);
20326+
20327+
let r = _kshiftli_mask32::<33>(a);
20328+
let e: __mmask32 = 0b0000000000000000_0000000000000000;
20329+
assert_eq!(r, e);
2031820330
}
2031920331

2032020332
#[simd_test(enable = "avx512bw")]
@@ -20323,21 +20335,61 @@ mod tests {
2032320335
let r = _kshiftli_mask64::<3>(a);
2032420336
let e: __mmask64 = 0b0110100101101001011_0100101101001000;
2032520337
assert_eq!(r, e);
20338+
20339+
let r = _kshiftli_mask64::<63>(a);
20340+
let e: __mmask64 = 0b1000000000000000_0000000000000000_0000000000000000_0000000000000000;
20341+
assert_eq!(r, e);
20342+
20343+
let r = _kshiftli_mask64::<64>(a);
20344+
let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000000;
20345+
assert_eq!(r, e);
20346+
20347+
let r = _kshiftli_mask64::<65>(a);
20348+
let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000000;
20349+
assert_eq!(r, e);
2032620350
}
2032720351

2032820352
#[simd_test(enable = "avx512bw")]
2032920353
unsafe fn test_kshiftri_mask32() {
20330-
let a: __mmask32 = 0b0110100101101001_0110100101101001;
20354+
let a: __mmask32 = 0b1010100101101001_0110100101101001;
2033120355
let r = _kshiftri_mask32::<3>(a);
20332-
let e: __mmask32 = 0b0000110100101101_0010110100101101;
20356+
let e: __mmask32 = 0b0001010100101101_0010110100101101;
20357+
assert_eq!(r, e);
20358+
20359+
let r = _kshiftri_mask32::<31>(a);
20360+
let e: __mmask32 = 0b0000000000000000_0000000000000001;
20361+
assert_eq!(r, e);
20362+
20363+
let r = _kshiftri_mask32::<32>(a);
20364+
let e: __mmask32 = 0b0000000000000000_0000000000000000;
20365+
assert_eq!(r, e);
20366+
20367+
let r = _kshiftri_mask32::<33>(a);
20368+
let e: __mmask32 = 0b0000000000000000_0000000000000000;
2033320369
assert_eq!(r, e);
2033420370
}
2033520371

2033620372
#[simd_test(enable = "avx512bw")]
2033720373
unsafe fn test_kshiftri_mask64() {
20338-
let a: __mmask64 = 0b0110100101101001011_0100101101001000;
20374+
let a: __mmask64 = 0b1010100101101001011_0100101101001000;
2033920375
let r = _kshiftri_mask64::<3>(a);
20340-
let e: __mmask64 = 0b0110100101101001_0110100101101001;
20376+
let e: __mmask64 = 0b1010100101101001_0110100101101001;
20377+
assert_eq!(r, e);
20378+
20379+
let r = _kshiftri_mask64::<34>(a);
20380+
let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000001;
20381+
assert_eq!(r, e);
20382+
20383+
let r = _kshiftri_mask64::<35>(a);
20384+
let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000000;
20385+
assert_eq!(r, e);
20386+
20387+
let r = _kshiftri_mask64::<64>(a);
20388+
let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000000;
20389+
assert_eq!(r, e);
20390+
20391+
let r = _kshiftri_mask64::<65>(a);
20392+
let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000000;
2034120393
assert_eq!(r, e);
2034220394
}
2034320395

0 commit comments

Comments
 (0)