From 6e263ec133215ac7837b0d4e48e4d0989e2eab4c Mon Sep 17 00:00:00 2001
From: Madhav Madhusoodanan
Date: Fri, 3 Oct 2025 02:20:50 +0530
Subject: [PATCH 1/3] fix: update the implementation of _kshiftri_mask32,
 _kshiftri_mask64, _kshiftli_mask32 and _kshiftli_mask64 to zero out when the
 shift amount equals or exceeds the bit width of the input argument.

---
 crates/core_arch/src/x86/avx512bw.rs | 68 ++++++++++++++++++++++++----
 1 file changed, 60 insertions(+), 8 deletions(-)

diff --git a/crates/core_arch/src/x86/avx512bw.rs b/crates/core_arch/src/x86/avx512bw.rs
index 1771f19659..094b89f3ac 100644
--- a/crates/core_arch/src/x86/avx512bw.rs
+++ b/crates/core_arch/src/x86/avx512bw.rs
@@ -10440,7 +10440,7 @@ pub fn _kortestz_mask64_u8(a: __mmask64, b: __mmask64) -> u8 {
 #[rustc_legacy_const_generics(1)]
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 pub fn _kshiftli_mask32<const COUNT: u32>(a: __mmask32) -> __mmask32 {
-    a << COUNT
+    a.unbounded_shl(COUNT)
 }
 
 /// Shift the bits of 64-bit mask a left by count while shifting in zeros, and store the least significant 32 bits of the result in k.
@@ -10451,7 +10451,7 @@ pub fn _kshiftli_mask32<const COUNT: u32>(a: __mmask32) -> __mmask32 {
 #[rustc_legacy_const_generics(1)]
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 pub fn _kshiftli_mask64<const COUNT: u32>(a: __mmask64) -> __mmask64 {
-    a << COUNT
+    a.unbounded_shl(COUNT)
 }
 
 /// Shift the bits of 32-bit mask a right by count while shifting in zeros, and store the least significant 32 bits of the result in k.
@@ -10462,7 +10462,7 @@ pub fn _kshiftli_mask64<const COUNT: u32>(a: __mmask64) -> __mmask64 {
 #[rustc_legacy_const_generics(1)]
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 pub fn _kshiftri_mask32<const COUNT: u32>(a: __mmask32) -> __mmask32 {
-    a >> COUNT
+    a.unbounded_shr(COUNT)
 }
 
 /// Shift the bits of 64-bit mask a right by count while shifting in zeros, and store the least significant 32 bits of the result in k.
@@ -10473,7 +10473,7 @@ pub fn _kshiftri_mask32<const COUNT: u32>(a: __mmask32) -> __mmask32 {
 #[rustc_legacy_const_generics(1)]
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 pub fn _kshiftri_mask64<const COUNT: u32>(a: __mmask64) -> __mmask64 {
-    a >> COUNT
+    a.unbounded_shr(COUNT)
 }
 
 /// Compute the bitwise AND of 32-bit masks a and b, and if the result is all zeros, store 1 in dst,
@@ -20315,6 +20315,18 @@ mod tests {
         let r = _kshiftli_mask32::<3>(a);
         let e: __mmask32 = 0b0100101101001011_0100101101001000;
         assert_eq!(r, e);
+
+        let r = _kshiftli_mask32::<31>(a);
+        let e: __mmask32 = 0b1000000000000000_0000000000000000;
+        assert_eq!(r, e);
+
+        let r = _kshiftli_mask32::<32>(a);
+        let e: __mmask32 = 0b0000000000000000_0000000000000000;
+        assert_eq!(r, e);
+
+        let r = _kshiftli_mask32::<33>(a);
+        let e: __mmask32 = 0b0000000000000000_0000000000000000;
+        assert_eq!(r, e);
     }
 
     #[simd_test(enable = "avx512bw")]
@@ -20323,21 +20335,61 @@ mod tests {
         let r = _kshiftli_mask64::<3>(a);
         let e: __mmask64 = 0b0110100101101001011_0100101101001000;
         assert_eq!(r, e);
+
+        let r = _kshiftli_mask64::<63>(a);
+        let e: __mmask64 = 0b1000000000000000_0000000000000000_0000000000000000_0000000000000000;
+        assert_eq!(r, e);
+
+        let r = _kshiftli_mask64::<64>(a);
+        let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000000;
+        assert_eq!(r, e);
+
+        let r = _kshiftli_mask64::<65>(a);
+        let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000000;
+        assert_eq!(r, e);
     }
 
     #[simd_test(enable = "avx512bw")]
     unsafe fn test_kshiftri_mask32() {
-        let a: __mmask32 = 0b0110100101101001_0110100101101001;
+        let a: __mmask32 = 0b1010100101101001_0110100101101001;
         let r = _kshiftri_mask32::<3>(a);
-        let e: __mmask32 = 0b0000110100101101_0010110100101101;
+        let e: __mmask32 = 0b0001010100101101_0010110100101101;
         assert_eq!(r, e);
+
+        let r = _kshiftri_mask32::<31>(a);
+        let e: __mmask32 = 0b0000000000000000_0000000000000001;
+        assert_eq!(r, e);
+
+        let r = _kshiftri_mask32::<32>(a);
+        let e: __mmask32 = 0b0000000000000000_0000000000000000;
+        assert_eq!(r, e);
+
+        let r = _kshiftri_mask32::<33>(a);
+        let e: __mmask32 = 0b0000000000000000_0000000000000000;
+        assert_eq!(r, e);
     }
 
     #[simd_test(enable = "avx512bw")]
     unsafe fn test_kshiftri_mask64() {
-        let a: __mmask64 = 0b0110100101101001011_0100101101001000;
+        let a: __mmask64 = 0b1010100101101001011_0100101101001000;
         let r = _kshiftri_mask64::<3>(a);
-        let e: __mmask64 = 0b0110100101101001_0110100101101001;
+        let e: __mmask64 = 0b1010100101101001_0110100101101001;
         assert_eq!(r, e);
+
+        let r = _kshiftri_mask64::<34>(a);
+        let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000001;
+        assert_eq!(r, e);
+
+        let r = _kshiftri_mask64::<35>(a);
+        let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000000;
+        assert_eq!(r, e);
+
+        let r = _kshiftri_mask64::<64>(a);
+        let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000000;
+        assert_eq!(r, e);
+
+        let r = _kshiftri_mask64::<65>(a);
+        let e: __mmask64 = 0b0000000000000000_0000000000000000_0000000000000000_0000000000000000;
+        assert_eq!(r, e);
     }
 

From 0697a436fd7bb27a55150c04b948e25cb8ca65d9 Mon Sep 17 00:00:00 2001
From: Madhav Madhusoodanan
Date: Fri, 3 Oct 2025 02:27:15 +0530
Subject: [PATCH 2/3] fix: update the implementation of _kshiftri_mask8 and
 _kshiftli_mask8 to zero out when the shift amount equals or exceeds the bit
 width of the input argument.
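
All three patches in this series rely on the same standard-library
behavior: for an unsigned integer type, `unbounded_shl`/`unbounded_shr`
(stabilized in Rust 1.87) return `self << rhs` / `self >> rhs` when `rhs`
is less than the bit width, and 0 otherwise. This matches the KSHIFT*
instructions, which zero the destination mask once the immediate count
reaches the mask width, whereas a plain `a << COUNT` with an oversized
count is an arithmetic overflow in Rust (a panic under debug assertions)
rather than zero. A minimal sketch of the semantics being relied on,
illustrative only and not part of the diff:

    // What `a.unbounded_shl(COUNT)` computes for an 8-bit mask:
    fn unbounded_shl_sketch(a: u8, count: u32) -> u8 {
        if count < u8::BITS { a << count } else { 0 }
    }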
---
 crates/core_arch/src/x86/avx512dq.rs | 32 ++++++++++++++++++++++++----
 1 file changed, 28 insertions(+), 4 deletions(-)

diff --git a/crates/core_arch/src/x86/avx512dq.rs b/crates/core_arch/src/x86/avx512dq.rs
index c90ec894f2..afeb548a55 100644
--- a/crates/core_arch/src/x86/avx512dq.rs
+++ b/crates/core_arch/src/x86/avx512dq.rs
@@ -4602,7 +4602,7 @@ pub fn _kortestz_mask8_u8(a: __mmask8, b: __mmask8) -> u8 {
 #[rustc_legacy_const_generics(1)]
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 pub fn _kshiftli_mask8<const COUNT: u32>(a: __mmask8) -> __mmask8 {
-    a << COUNT
+    a.unbounded_shl(COUNT)
 }
 
 /// Shift 8-bit mask a right by count bits while shifting in zeros, and store the result in dst.
@@ -4613,7 +4613,7 @@ pub fn _kshiftli_mask8<const COUNT: u32>(a: __mmask8) -> __mmask8 {
 #[rustc_legacy_const_generics(1)]
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 pub fn _kshiftri_mask8<const COUNT: u32>(a: __mmask8) -> __mmask8 {
-    a >> COUNT
+    a.unbounded_shr(COUNT)
 }
 
 /// Compute the bitwise AND of 16-bit masks a and b, and if the result is all zeros, store 1 in dst,
@@ -9856,13 +9856,37 @@ mod tests {
         let r = _kshiftli_mask8::<3>(a);
         let e: __mmask8 = 0b01001000;
         assert_eq!(r, e);
+
+        let r = _kshiftli_mask8::<7>(a);
+        let e: __mmask8 = 0b10000000;
+        assert_eq!(r, e);
+
+        let r = _kshiftli_mask8::<8>(a);
+        let e: __mmask8 = 0b00000000;
+        assert_eq!(r, e);
+
+        let r = _kshiftli_mask8::<9>(a);
+        let e: __mmask8 = 0b00000000;
+        assert_eq!(r, e);
     }
 
     #[simd_test(enable = "avx512dq")]
     unsafe fn test_kshiftri_mask8() {
-        let a: __mmask8 = 0b01101001;
+        let a: __mmask8 = 0b10101001;
         let r = _kshiftri_mask8::<3>(a);
-        let e: __mmask8 = 0b00001101;
+        let e: __mmask8 = 0b00010101;
         assert_eq!(r, e);
+
+        let r = _kshiftri_mask8::<7>(a);
+        let e: __mmask8 = 0b00000001;
+        assert_eq!(r, e);
+
+        let r = _kshiftri_mask8::<8>(a);
+        let e: __mmask8 = 0b00000000;
+        assert_eq!(r, e);
+
+        let r = _kshiftri_mask8::<9>(a);
+        let e: __mmask8 = 0b00000000;
+        assert_eq!(r, e);
     }
 

From 29027b6f07171a718ddbbdafe185f3c5c95c4f4c Mon Sep 17 00:00:00 2001
From: Madhav Madhusoodanan
Date: Fri, 3 Oct 2025 02:33:11 +0530
Subject: [PATCH 3/3] fix: update the implementation of _kshiftri_mask16 and
 _kshiftli_mask16 to zero out when the shift amount equals or exceeds 16.

---
 crates/core_arch/src/x86/avx512f.rs | 32 +++++++++++++++++++++++++----
 1 file changed, 28 insertions(+), 4 deletions(-)

diff --git a/crates/core_arch/src/x86/avx512f.rs b/crates/core_arch/src/x86/avx512f.rs
index 002534a65d..001b877812 100644
--- a/crates/core_arch/src/x86/avx512f.rs
+++ b/crates/core_arch/src/x86/avx512f.rs
@@ -29090,7 +29090,7 @@ pub fn _kortestz_mask16_u8(a: __mmask16, b: __mmask16) -> u8 {
 #[rustc_legacy_const_generics(1)]
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 pub fn _kshiftli_mask16<const COUNT: u32>(a: __mmask16) -> __mmask16 {
-    a << COUNT
+    a.unbounded_shl(COUNT)
 }
 
 /// Shift 16-bit mask a right by count bits while shifting in zeros, and store the result in dst.
@@ -29101,7 +29101,7 @@ pub fn _kshiftli_mask16<const COUNT: u32>(a: __mmask16) -> __mmask16 {
 #[rustc_legacy_const_generics(1)]
 #[stable(feature = "stdarch_x86_avx512", since = "1.89")]
 pub fn _kshiftri_mask16<const COUNT: u32>(a: __mmask16) -> __mmask16 {
-    a >> COUNT
+    a.unbounded_shr(COUNT)
 }
 
 /// Load 16-bit mask from memory
@@ -56001,13 +56001,37 @@ mod tests {
         let r = _kshiftli_mask16::<3>(a);
         let e: __mmask16 = 0b1011011000011000;
         assert_eq!(r, e);
+
+        let r = _kshiftli_mask16::<15>(a);
+        let e: __mmask16 = 0b1000000000000000;
+        assert_eq!(r, e);
+
+        let r = _kshiftli_mask16::<16>(a);
+        let e: __mmask16 = 0b0000000000000000;
+        assert_eq!(r, e);
+
+        let r = _kshiftli_mask16::<17>(a);
+        let e: __mmask16 = 0b0000000000000000;
+        assert_eq!(r, e);
     }
 
     #[simd_test(enable = "avx512dq")]
     unsafe fn test_kshiftri_mask16() {
-        let a: __mmask16 = 0b0110100100111100;
+        let a: __mmask16 = 0b1010100100111100;
         let r = _kshiftri_mask16::<3>(a);
-        let e: __mmask16 = 0b0000110100100111;
+        let e: __mmask16 = 0b0001010100100111;
         assert_eq!(r, e);
+
+        let r = _kshiftri_mask16::<15>(a);
+        let e: __mmask16 = 0b0000000000000001;
+        assert_eq!(r, e);
+
+        let r = _kshiftri_mask16::<16>(a);
+        let e: __mmask16 = 0b0000000000000000;
+        assert_eq!(r, e);
+
+        let r = _kshiftri_mask16::<17>(a);
+        let e: __mmask16 = 0b0000000000000000;
+        assert_eq!(r, e);
     }
 
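
A standalone check of the boundary behavior the new tests pin down; this
sketch uses plain integers on stable Rust and is illustrative only, not
part of the series:

    fn main() {
        let a: u32 = 0b0110100101101001_0110100101101001;
        assert_eq!(a.unbounded_shl(31), 0x8000_0000); // bit 0 of a lands in bit 31
        assert_eq!(a.unbounded_shl(32), 0); // count == width: all zeros
        assert_eq!(a.unbounded_shl(33), 0); // count >  width: all zeros
        assert_eq!(a.unbounded_shr(31), 0); // bit 31 of a is 0, so nothing survives
        assert_eq!(a.unbounded_shr(32), 0); // count == width: all zeros
    }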