Skip to content

Commit 0697a43

Browse files
fix: update the implementation of _kshiftri_mask8 and _kshiftli_mask8 to
zero out when the amount of shift exceeds the bit length of the input argument.
1 parent 6e263ec commit 0697a43

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

crates/core_arch/src/x86/avx512dq.rs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4602,7 +4602,7 @@ pub fn _kortestz_mask8_u8(a: __mmask8, b: __mmask8) -> u8 {
46024602
#[rustc_legacy_const_generics(1)]
46034603
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
46044604
pub fn _kshiftli_mask8<const COUNT: u32>(a: __mmask8) -> __mmask8 {
4605-
a << COUNT
4605+
a.unbounded_shl(COUNT)
46064606
}
46074607

46084608
/// Shift 8-bit mask a right by count bits while shifting in zeros, and store the result in dst.
@@ -4613,7 +4613,7 @@ pub fn _kshiftli_mask8<const COUNT: u32>(a: __mmask8) -> __mmask8 {
46134613
#[rustc_legacy_const_generics(1)]
46144614
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
46154615
pub fn _kshiftri_mask8<const COUNT: u32>(a: __mmask8) -> __mmask8 {
4616-
a >> COUNT
4616+
a.unbounded_shr(COUNT)
46174617
}
46184618

46194619
/// Compute the bitwise AND of 16-bit masks a and b, and if the result is all zeros, store 1 in dst,
@@ -9856,13 +9856,37 @@ mod tests {
98569856
let r = _kshiftli_mask8::<3>(a);
98579857
let e: __mmask8 = 0b01001000;
98589858
assert_eq!(r, e);
9859+
9860+
let r = _kshiftli_mask8::<7>(a);
9861+
let e: __mmask8 = 0b10000000;
9862+
assert_eq!(r, e);
9863+
9864+
let r = _kshiftli_mask8::<8>(a);
9865+
let e: __mmask8 = 0b00000000;
9866+
assert_eq!(r, e);
9867+
9868+
let r = _kshiftli_mask8::<9>(a);
9869+
let e: __mmask8 = 0b00000000;
9870+
assert_eq!(r, e);
98599871
}
98609872

98619873
#[simd_test(enable = "avx512dq")]
98629874
unsafe fn test_kshiftri_mask8() {
9863-
let a: __mmask8 = 0b01101001;
9875+
let a: __mmask8 = 0b10101001;
98649876
let r = _kshiftri_mask8::<3>(a);
9865-
let e: __mmask8 = 0b00001101;
9877+
let e: __mmask8 = 0b00010101;
9878+
assert_eq!(r, e);
9879+
9880+
let r = _kshiftri_mask8::<7>(a);
9881+
let e: __mmask8 = 0b00000001;
9882+
assert_eq!(r, e);
9883+
9884+
let r = _kshiftri_mask8::<8>(a);
9885+
let e: __mmask8 = 0b00000000;
9886+
assert_eq!(r, e);
9887+
9888+
let r = _kshiftri_mask8::<9>(a);
9889+
let e: __mmask8 = 0b00000000;
98669890
assert_eq!(r, e);
98679891
}
98689892

0 commit comments

Comments
 (0)