Skip to content

Commit 1373f8b

Browse files
committed
Use SIMD intrinsics for f16 operations
1 parent 1b82319 commit 1373f8b

File tree

2 files changed

+90
-18
lines changed

2 files changed

+90
-18
lines changed

crates/core_arch/src/x86/avx512fp16.rs

Lines changed: 80 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1615,7 +1615,7 @@ pub fn _mm_maskz_add_round_sh<const ROUNDING: i32>(k: __mmask8, a: __m128h, b: _
16151615
#[cfg_attr(test, assert_instr(vaddsh))]
16161616
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
16171617
pub fn _mm_add_sh(a: __m128h, b: __m128h) -> __m128h {
1618-
_mm_add_round_sh::<_MM_FROUND_CUR_DIRECTION>(a, b)
1618+
unsafe { simd_insert!(a, 0, _mm_cvtsh_h(a) + _mm_cvtsh_h(b)) }
16191619
}
16201620

16211621
/// Add the lower half-precision (16-bit) floating-point elements in a and b, store the result in the
@@ -1628,7 +1628,16 @@ pub fn _mm_add_sh(a: __m128h, b: __m128h) -> __m128h {
16281628
#[cfg_attr(test, assert_instr(vaddsh))]
16291629
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
16301630
pub fn _mm_mask_add_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
1631-
_mm_mask_add_round_sh::<_MM_FROUND_CUR_DIRECTION>(src, k, a, b)
1631+
unsafe {
1632+
let extractsrc: f16 = simd_extract!(src, 0);
1633+
let mut add: f16 = extractsrc;
1634+
if (k & 0b00000001) != 0 {
1635+
let extracta: f16 = simd_extract!(a, 0);
1636+
let extractb: f16 = simd_extract!(b, 0);
1637+
add = extracta + extractb;
1638+
}
1639+
simd_insert!(a, 0, add)
1640+
}
16321641
}
16331642

16341643
/// Add the lower half-precision (16-bit) floating-point elements in a and b, store the result in the
@@ -1641,7 +1650,15 @@ pub fn _mm_mask_add_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m
16411650
#[cfg_attr(test, assert_instr(vaddsh))]
16421651
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
16431652
pub fn _mm_maskz_add_sh(k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
1644-
_mm_maskz_add_round_sh::<_MM_FROUND_CUR_DIRECTION>(k, a, b)
1653+
unsafe {
1654+
let mut add: f16 = 0.;
1655+
if (k & 0b00000001) != 0 {
1656+
let extracta: f16 = simd_extract!(a, 0);
1657+
let extractb: f16 = simd_extract!(b, 0);
1658+
add = extracta + extractb;
1659+
}
1660+
simd_insert!(a, 0, add)
1661+
}
16451662
}
16461663

16471664
/// Subtract packed half-precision (16-bit) floating-point elements in b from a, and store the results in dst.
@@ -1927,7 +1944,7 @@ pub fn _mm_maskz_sub_round_sh<const ROUNDING: i32>(k: __mmask8, a: __m128h, b: _
19271944
#[cfg_attr(test, assert_instr(vsubsh))]
19281945
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
19291946
pub fn _mm_sub_sh(a: __m128h, b: __m128h) -> __m128h {
1930-
_mm_sub_round_sh::<_MM_FROUND_CUR_DIRECTION>(a, b)
1947+
unsafe { simd_insert!(a, 0, _mm_cvtsh_h(a) - _mm_cvtsh_h(b)) }
19311948
}
19321949

19331950
/// Subtract the lower half-precision (16-bit) floating-point elements in b from a, store the result in the
@@ -1940,7 +1957,16 @@ pub fn _mm_sub_sh(a: __m128h, b: __m128h) -> __m128h {
19401957
#[cfg_attr(test, assert_instr(vsubsh))]
19411958
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
19421959
pub fn _mm_mask_sub_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
1943-
_mm_mask_sub_round_sh::<_MM_FROUND_CUR_DIRECTION>(src, k, a, b)
1960+
unsafe {
1961+
let extractsrc: f16 = simd_extract!(src, 0);
1962+
let mut add: f16 = extractsrc;
1963+
if (k & 0b00000001) != 0 {
1964+
let extracta: f16 = simd_extract!(a, 0);
1965+
let extractb: f16 = simd_extract!(b, 0);
1966+
add = extracta - extractb;
1967+
}
1968+
simd_insert!(a, 0, add)
1969+
}
19441970
}
19451971

19461972
/// Subtract the lower half-precision (16-bit) floating-point elements in b from a, store the result in the
@@ -1953,7 +1979,15 @@ pub fn _mm_mask_sub_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m
19531979
#[cfg_attr(test, assert_instr(vsubsh))]
19541980
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
19551981
pub fn _mm_maskz_sub_sh(k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
1956-
_mm_maskz_sub_round_sh::<_MM_FROUND_CUR_DIRECTION>(k, a, b)
1982+
unsafe {
1983+
let mut add: f16 = 0.;
1984+
if (k & 0b00000001) != 0 {
1985+
let extracta: f16 = simd_extract!(a, 0);
1986+
let extractb: f16 = simd_extract!(b, 0);
1987+
add = extracta - extractb;
1988+
}
1989+
simd_insert!(a, 0, add)
1990+
}
19571991
}
19581992

19591993
/// Multiply packed half-precision (16-bit) floating-point elements in a and b, and store the results in dst.
@@ -2239,7 +2273,7 @@ pub fn _mm_maskz_mul_round_sh<const ROUNDING: i32>(k: __mmask8, a: __m128h, b: _
22392273
#[cfg_attr(test, assert_instr(vmulsh))]
22402274
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
22412275
pub fn _mm_mul_sh(a: __m128h, b: __m128h) -> __m128h {
2242-
_mm_mul_round_sh::<_MM_FROUND_CUR_DIRECTION>(a, b)
2276+
unsafe { simd_insert!(a, 0, _mm_cvtsh_h(a) * _mm_cvtsh_h(b)) }
22432277
}
22442278

22452279
/// Multiply the lower half-precision (16-bit) floating-point elements in a and b, store the result in the
@@ -2252,7 +2286,16 @@ pub fn _mm_mul_sh(a: __m128h, b: __m128h) -> __m128h {
22522286
#[cfg_attr(test, assert_instr(vmulsh))]
22532287
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
22542288
pub fn _mm_mask_mul_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
2255-
_mm_mask_mul_round_sh::<_MM_FROUND_CUR_DIRECTION>(src, k, a, b)
2289+
unsafe {
2290+
let extractsrc: f16 = simd_extract!(src, 0);
2291+
let mut add: f16 = extractsrc;
2292+
if (k & 0b00000001) != 0 {
2293+
let extracta: f16 = simd_extract!(a, 0);
2294+
let extractb: f16 = simd_extract!(b, 0);
2295+
add = extracta * extractb;
2296+
}
2297+
simd_insert!(a, 0, add)
2298+
}
22562299
}
22572300

22582301
/// Multiply the lower half-precision (16-bit) floating-point elements in a and b, store the result in the
@@ -2265,7 +2308,15 @@ pub fn _mm_mask_mul_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m
22652308
#[cfg_attr(test, assert_instr(vmulsh))]
22662309
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
22672310
pub fn _mm_maskz_mul_sh(k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
2268-
_mm_maskz_mul_round_sh::<_MM_FROUND_CUR_DIRECTION>(k, a, b)
2311+
unsafe {
2312+
let mut add: f16 = 0.;
2313+
if (k & 0b00000001) != 0 {
2314+
let extracta: f16 = simd_extract!(a, 0);
2315+
let extractb: f16 = simd_extract!(b, 0);
2316+
add = extracta * extractb;
2317+
}
2318+
simd_insert!(a, 0, add)
2319+
}
22692320
}
22702321

22712322
/// Divide packed half-precision (16-bit) floating-point elements in a by b, and store the results in dst.
@@ -2551,7 +2602,7 @@ pub fn _mm_maskz_div_round_sh<const ROUNDING: i32>(k: __mmask8, a: __m128h, b: _
25512602
#[cfg_attr(test, assert_instr(vdivsh))]
25522603
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
25532604
pub fn _mm_div_sh(a: __m128h, b: __m128h) -> __m128h {
2554-
_mm_div_round_sh::<_MM_FROUND_CUR_DIRECTION>(a, b)
2605+
unsafe { simd_insert!(a, 0, _mm_cvtsh_h(a) / _mm_cvtsh_h(b)) }
25552606
}
25562607

25572608
/// Divide the lower half-precision (16-bit) floating-point elements in a by b, store the result in the
@@ -2564,7 +2615,16 @@ pub fn _mm_div_sh(a: __m128h, b: __m128h) -> __m128h {
25642615
#[cfg_attr(test, assert_instr(vdivsh))]
25652616
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
25662617
pub fn _mm_mask_div_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
2567-
_mm_mask_div_round_sh::<_MM_FROUND_CUR_DIRECTION>(src, k, a, b)
2618+
unsafe {
2619+
let extractsrc: f16 = simd_extract!(src, 0);
2620+
let mut add: f16 = extractsrc;
2621+
if (k & 0b00000001) != 0 {
2622+
let extracta: f16 = simd_extract!(a, 0);
2623+
let extractb: f16 = simd_extract!(b, 0);
2624+
add = extracta / extractb;
2625+
}
2626+
simd_insert!(a, 0, add)
2627+
}
25682628
}
25692629

25702630
/// Divide the lower half-precision (16-bit) floating-point elements in a by b, store the result in the
@@ -2577,7 +2637,15 @@ pub fn _mm_mask_div_sh(src: __m128h, k: __mmask8, a: __m128h, b: __m128h) -> __m
25772637
#[cfg_attr(test, assert_instr(vdivsh))]
25782638
#[unstable(feature = "stdarch_x86_avx512_f16", issue = "127213")]
25792639
pub fn _mm_maskz_div_sh(k: __mmask8, a: __m128h, b: __m128h) -> __m128h {
2580-
_mm_maskz_div_round_sh::<_MM_FROUND_CUR_DIRECTION>(k, a, b)
2640+
unsafe {
2641+
let mut add: f16 = 0.;
2642+
if (k & 0b00000001) != 0 {
2643+
let extracta: f16 = simd_extract!(a, 0);
2644+
let extractb: f16 = simd_extract!(b, 0);
2645+
add = extracta / extractb;
2646+
}
2647+
simd_insert!(a, 0, add)
2648+
}
25812649
}
25822650

25832651
/// Multiply packed complex numbers in a and b, and store the results in dst. Each complex number is

crates/core_arch/src/x86/f16c.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,13 @@
33
//! [F16C intrinsics]: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=fp16&expand=1769
44
55
use crate::core_arch::{simd::*, x86::*};
6+
use crate::intrinsics::simd::*;
67

78
#[cfg(test)]
89
use stdarch_test::assert_instr;
910

1011
#[allow(improper_ctypes)]
1112
unsafe extern "unadjusted" {
12-
#[link_name = "llvm.x86.vcvtph2ps.128"]
13-
fn llvm_vcvtph2ps_128(a: i16x8) -> f32x4;
14-
#[link_name = "llvm.x86.vcvtph2ps.256"]
15-
fn llvm_vcvtph2ps_256(a: i16x8) -> f32x8;
1613
#[link_name = "llvm.x86.vcvtps2ph.128"]
1714
fn llvm_vcvtps2ph_128(a: f32x4, rounding: i32) -> i16x8;
1815
#[link_name = "llvm.x86.vcvtps2ph.256"]
@@ -29,7 +26,11 @@ unsafe extern "unadjusted" {
2926
#[cfg_attr(test, assert_instr("vcvtph2ps"))]
3027
#[stable(feature = "x86_f16c_intrinsics", since = "1.68.0")]
3128
pub fn _mm_cvtph_ps(a: __m128i) -> __m128 {
32-
unsafe { transmute(llvm_vcvtph2ps_128(transmute(a))) }
29+
unsafe {
30+
let a: f16x8 = transmute(a);
31+
let a: f16x4 = simd_shuffle!(a, a, [0, 1, 2, 3]);
32+
simd_cast(a)
33+
}
3334
}
3435

3536
/// Converts the 8 x 16-bit half-precision float values in the 128-bit vector
@@ -41,7 +42,10 @@ pub fn _mm_cvtph_ps(a: __m128i) -> __m128 {
4142
#[cfg_attr(test, assert_instr("vcvtph2ps"))]
4243
#[stable(feature = "x86_f16c_intrinsics", since = "1.68.0")]
4344
pub fn _mm256_cvtph_ps(a: __m128i) -> __m256 {
44-
unsafe { transmute(llvm_vcvtph2ps_256(transmute(a))) }
45+
unsafe {
46+
let a: f16x8 = transmute(a);
47+
simd_cast(a)
48+
}
4549
}
4650

4751
/// Converts the 4 x 32-bit float values in the 128-bit vector `a` into 4 x

0 commit comments

Comments (0)