Skip to content

Commit 3bb41b2

Browse files
committed
add an avx512 psad shim
also combine the sse2 and avx2 versions into one generic function shared by all three (sse2, avx2, avx512)
1 parent 4894162 commit 3bb41b2

File tree

5 files changed

+98
-64
lines changed

5 files changed

+98
-64
lines changed

src/shims/x86/avx2.rs

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi;
66

77
use super::{
88
ShiftOp, horizontal_bin_op, mask_load, mask_store, mpsadbw, packssdw, packsswb, packusdw,
9-
packuswb, pmulhrsw, psign, shift_simd_by_scalar, shift_simd_by_simd,
9+
packuswb, pmulhrsw, psadbw, psign, shift_simd_by_scalar, shift_simd_by_simd,
1010
};
1111
use crate::*;
1212

@@ -241,41 +241,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
241241
}
242242
}
243243
// Used to implement the _mm256_sad_epu8 function.
244-
// Compute the absolute differences of packed unsigned 8-bit integers
245-
// in `left` and `right`, then horizontally sum each consecutive 8
246-
// differences to produce four unsigned 16-bit integers, and pack
247-
// these unsigned 16-bit integers in the low 16 bits of 64-bit elements
248-
// in `dest`.
249-
// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_sad_epu8
250244
"psad.bw" => {
251245
let [left, right] =
252246
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
253247

254-
let (left, left_len) = this.project_to_simd(left)?;
255-
let (right, right_len) = this.project_to_simd(right)?;
256-
let (dest, dest_len) = this.project_to_simd(dest)?;
257-
258-
assert_eq!(left_len, right_len);
259-
assert_eq!(left_len, dest_len.strict_mul(8));
260-
261-
for i in 0..dest_len {
262-
let dest = this.project_index(&dest, i)?;
263-
264-
let mut acc: u16 = 0;
265-
for j in 0..8 {
266-
let src_index = i.strict_mul(8).strict_add(j);
267-
268-
let left = this.project_index(&left, src_index)?;
269-
let left = this.read_scalar(&left)?.to_u8()?;
270-
271-
let right = this.project_index(&right, src_index)?;
272-
let right = this.read_scalar(&right)?.to_u8()?;
273-
274-
acc = acc.strict_add(left.abs_diff(right).into());
275-
}
276-
277-
this.write_scalar(Scalar::from_u64(acc.into()), &dest)?;
278-
}
248+
psadbw(this, left, right, dest)?
279249
}
280250
// Used to implement the _mm256_shuffle_epi8 intrinsic.
281251
// Shuffles bytes from `left` using `right` as pattern.

src/shims/x86/avx512.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use rustc_middle::ty::Ty;
33
use rustc_span::Symbol;
44
use rustc_target::callconv::FnAbi;
55

6+
use super::psadbw;
67
use crate::*;
78

89
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
@@ -78,6 +79,15 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
7879
this.write_scalar(Scalar::from_u32(r), &d_lane)?;
7980
}
8081
}
82+
// Used to implement the _mm512_sad_epu8 function.
83+
"psad.bw.512" => {
84+
this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
85+
86+
let [left, right] =
87+
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
88+
89+
psadbw(this, left, right, dest)?
90+
}
8191
_ => return interp_ok(EmulateItemResult::NotSupported),
8292
}
8393
interp_ok(EmulateItemResult::NeedsReturn)

src/shims/x86/mod.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,54 @@ fn mpsadbw<'tcx>(
10381038
interp_ok(())
10391039
}
10401040

1041+
/// Compute the absolute differences of packed unsigned 8-bit integers
1042+
/// in `left` and `right`, then horizontally sum each consecutive 8
1043+
/// differences to produce unsigned 16-bit integers, and pack
1044+
/// these unsigned 16-bit integers in the low 16 bits of 64-bit elements
1045+
/// in `dest`.
1046+
///
1047+
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_sad_epu8>
1048+
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_sad_epu8>
1049+
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_sad_epu8>
1050+
fn psadbw<'tcx>(
1051+
ecx: &mut crate::MiriInterpCx<'tcx>,
1052+
left: &OpTy<'tcx>,
1053+
right: &OpTy<'tcx>,
1054+
dest: &MPlaceTy<'tcx>,
1055+
) -> InterpResult<'tcx, ()> {
1056+
let (left, left_len) = ecx.project_to_simd(left)?;
1057+
let (right, right_len) = ecx.project_to_simd(right)?;
1058+
let (dest, dest_len) = ecx.project_to_simd(dest)?;
1059+
1060+
// fn psadbw(a: u8x16, b: u8x16) -> u64x2;
1061+
// fn psadbw(a: u8x32, b: u8x32) -> u64x4;
1062+
// fn vpsadbw(a: u8x64, b: u8x64) -> u64x8;
1063+
assert_eq!(left_len, right_len);
1064+
assert_eq!(left_len, left.layout.layout.size().bytes());
1065+
assert_eq!(dest_len, left_len.strict_div(8));
1066+
1067+
for i in 0..dest_len {
1068+
let dest = ecx.project_index(&dest, i)?;
1069+
1070+
let mut acc: u16 = 0;
1071+
for j in 0..8 {
1072+
let src_index = i.strict_mul(8).strict_add(j);
1073+
1074+
let left = ecx.project_index(&left, src_index)?;
1075+
let left = ecx.read_scalar(&left)?.to_u8()?;
1076+
1077+
let right = ecx.project_index(&right, src_index)?;
1078+
let right = ecx.read_scalar(&right)?.to_u8()?;
1079+
1080+
acc = acc.strict_add(left.abs_diff(right).into());
1081+
}
1082+
1083+
ecx.write_scalar(Scalar::from_u64(acc.into()), &dest)?;
1084+
}
1085+
1086+
interp_ok(())
1087+
}
1088+
10411089
/// Multiplies packed 16-bit signed integer values, truncates the 32-bit
10421090
/// product to the 18 most significant bits by right-shifting, and then
10431091
/// divides the 18-bit value by 2 (rounding to nearest) by first adding

src/shims/x86/sse2.rs

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi;
66

77
use super::{
88
FloatBinOp, ShiftOp, bin_op_simd_float_all, bin_op_simd_float_first, convert_float_to_int,
9-
packssdw, packsswb, packuswb, shift_simd_by_scalar,
9+
packssdw, packsswb, packuswb, psadbw, shift_simd_by_scalar,
1010
};
1111
use crate::*;
1212

@@ -37,41 +37,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
3737
// vectors.
3838
match unprefixed_name {
3939
// Used to implement the _mm_sad_epu8 function.
40-
// Computes the absolute differences of packed unsigned 8-bit integers in `a`
41-
// and `b`, then horizontally sum each consecutive 8 differences to produce
42-
// two unsigned 16-bit integers, and pack these unsigned 16-bit integers in
43-
// the low 16 bits of 64-bit elements returned.
44-
//
45-
// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_sad_epu8
4640
"psad.bw" => {
4741
let [left, right] =
4842
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
4943

50-
let (left, left_len) = this.project_to_simd(left)?;
51-
let (right, right_len) = this.project_to_simd(right)?;
52-
let (dest, dest_len) = this.project_to_simd(dest)?;
53-
54-
// left and right are u8x16, dest is u64x2
55-
assert_eq!(left_len, right_len);
56-
assert_eq!(left_len, 16);
57-
assert_eq!(dest_len, 2);
58-
59-
for i in 0..dest_len {
60-
let dest = this.project_index(&dest, i)?;
61-
62-
let mut res: u16 = 0;
63-
let n = left_len.strict_div(dest_len);
64-
for j in 0..n {
65-
let op_i = j.strict_add(i.strict_mul(n));
66-
let left = this.read_scalar(&this.project_index(&left, op_i)?)?.to_u8()?;
67-
let right =
68-
this.read_scalar(&this.project_index(&right, op_i)?)?.to_u8()?;
69-
70-
res = res.strict_add(left.abs_diff(right).into());
71-
}
72-
73-
this.write_scalar(Scalar::from_u64(res.into()), &dest)?;
74-
}
44+
psadbw(this, left, right, dest)?
7545
}
7646
// Used to implement the _mm_{sll,srl,sra}_epi{16,32,64} functions
7747
// (except _mm_sra_epi64, which is not available in SSE2).

tests/pass/shims/x86/intrinsics-x86-avx512.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,48 @@ fn main() {
1515
assert!(is_x86_feature_detected!("avx512vpopcntdq"));
1616

1717
unsafe {
18+
test_avx512();
1819
test_avx512bitalg();
1920
test_avx512vpopcntdq();
2021
test_avx512ternarylogic();
2122
}
2223
}
2324

25+
#[target_feature(enable = "avx512bw")]
26+
unsafe fn test_avx512() {
27+
#[target_feature(enable = "avx512bw")]
28+
unsafe fn test_mm512_sad_epu8() {
29+
let a = _mm512_set_epi8(
30+
71, 70, 69, 68, 67, 66, 65, 64, //
31+
55, 54, 53, 52, 51, 50, 49, 48, //
32+
47, 46, 45, 44, 43, 42, 41, 40, //
33+
39, 38, 37, 36, 35, 34, 33, 32, //
34+
31, 30, 29, 28, 27, 26, 25, 24, //
35+
23, 22, 21, 20, 19, 18, 17, 16, //
36+
15, 14, 13, 12, 11, 10, 9, 8, //
37+
7, 6, 5, 4, 3, 2, 1, 0, //
38+
);
39+
40+
// `d` is the absolute difference of every byte from the corresponding row in `a`, so each 64-bit lane of the result is `8 * d`.
41+
let b = _mm512_set_epi8(
42+
63, 62, 61, 60, 59, 58, 57, 56, // lane 7 (d = 8)
43+
62, 61, 60, 59, 58, 57, 56, 55, // lane 6 (d = 7)
44+
53, 52, 51, 50, 49, 48, 47, 46, // lane 5 (d = 6)
45+
44, 43, 42, 41, 40, 39, 38, 37, // lane 4 (d = 5)
46+
35, 34, 33, 32, 31, 30, 29, 28, // lane 3 (d = 4)
47+
26, 25, 24, 23, 22, 21, 20, 19, // lane 2 (d = 3)
48+
17, 16, 15, 14, 13, 12, 11, 10, // lane 1 (d = 2)
49+
8, 7, 6, 5, 4, 3, 2, 1, // lane 0 (d = 1)
50+
);
51+
52+
let r = _mm512_sad_epu8(a, b);
53+
let e = _mm512_set_epi64(64, 56, 48, 40, 32, 24, 16, 8);
54+
55+
assert_eq_m512i(r, e);
56+
}
57+
test_mm512_sad_epu8();
58+
}
59+
2460
// Some of the constants in the tests below are just bit patterns. They should not
2561
// be interpreted as integers; signedness does not make sense for them, but
2662
// __mXXXi happens to be defined in terms of signed integers.

0 commit comments

Comments
 (0)