Skip to content

Commit a5c75dc

Browse files
TDeckingAmanieu
authored andcommitted
Update fma.rs
1 parent d2b1a07 commit a5c75dc

File tree

1 file changed

+52
-60
lines changed
  • crates/core_arch/src/x86

1 file changed

+52
-60
lines changed

crates/core_arch/src/x86/fma.rs

Lines changed: 52 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,11 @@ pub unsafe fn _mm256_fmadd_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
8383
#[cfg_attr(test, assert_instr(vfmadd))]
8484
#[stable(feature = "simd_x86", since = "1.27.0")]
8585
pub unsafe fn _mm_fmadd_sd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
86-
vfmaddsd(a, b, c)
86+
simd_insert!(
87+
a,
88+
0,
89+
_mm_cvtsd_f64(a).mul_add(_mm_cvtsd_f64(b), _mm_cvtsd_f64(c))
90+
)
8791
}
8892

8993
/// Multiplies the lower single-precision (32-bit) floating-point elements in
@@ -97,7 +101,11 @@ pub unsafe fn _mm_fmadd_sd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
97101
#[cfg_attr(test, assert_instr(vfmadd))]
98102
#[stable(feature = "simd_x86", since = "1.27.0")]
99103
pub unsafe fn _mm_fmadd_ss(a: __m128, b: __m128, c: __m128) -> __m128 {
100-
vfmaddss(a, b, c)
104+
simd_insert!(
105+
a,
106+
0,
107+
_mm_cvtss_f32(a).mul_add(_mm_cvtss_f32(b), _mm_cvtss_f32(c))
108+
)
101109
}
102110

103111
/// Multiplies packed double-precision (64-bit) floating-point elements in `a`
@@ -161,7 +169,7 @@ pub unsafe fn _mm256_fmaddsub_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
161169
#[cfg_attr(test, assert_instr(vfmsub))]
162170
#[stable(feature = "simd_x86", since = "1.27.0")]
163171
pub unsafe fn _mm_fmsub_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
164-
vfmsubpd(a, b, c)
172+
simd_fma(a, b, simd_neg(c))
165173
}
166174

167175
/// Multiplies packed double-precision (64-bit) floating-point elements in `a`
@@ -173,7 +181,7 @@ pub unsafe fn _mm_fmsub_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
173181
#[cfg_attr(test, assert_instr(vfmsub))]
174182
#[stable(feature = "simd_x86", since = "1.27.0")]
175183
pub unsafe fn _mm256_fmsub_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
176-
vfmsubpd256(a, b, c)
184+
simd_fma(a, b, simd_neg(c))
177185
}
178186

179187
/// Multiplies packed single-precision (32-bit) floating-point elements in `a`
@@ -185,7 +193,7 @@ pub unsafe fn _mm256_fmsub_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
185193
#[cfg_attr(test, assert_instr(vfmsub213ps))]
186194
#[stable(feature = "simd_x86", since = "1.27.0")]
187195
pub unsafe fn _mm_fmsub_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
188-
vfmsubps(a, b, c)
196+
simd_fma(a, b, simd_neg(c))
189197
}
190198

191199
/// Multiplies packed single-precision (32-bit) floating-point elements in `a`
@@ -197,7 +205,7 @@ pub unsafe fn _mm_fmsub_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
197205
#[cfg_attr(test, assert_instr(vfmsub213ps))]
198206
#[stable(feature = "simd_x86", since = "1.27.0")]
199207
pub unsafe fn _mm256_fmsub_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
200-
vfmsubps256(a, b, c)
208+
simd_fma(a, b, simd_neg(c))
201209
}
202210

203211
/// Multiplies the lower double-precision (64-bit) floating-point elements in
@@ -211,7 +219,11 @@ pub unsafe fn _mm256_fmsub_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
211219
#[cfg_attr(test, assert_instr(vfmsub))]
212220
#[stable(feature = "simd_x86", since = "1.27.0")]
213221
pub unsafe fn _mm_fmsub_sd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
214-
vfmsubsd(a, b, c)
222+
simd_insert!(
223+
a,
224+
0,
225+
_mm_cvtsd_f64(a).mul_add(_mm_cvtsd_f64(b), -_mm_cvtsd_f64(c))
226+
)
215227
}
216228

217229
/// Multiplies the lower single-precision (32-bit) floating-point elements in
@@ -225,7 +237,11 @@ pub unsafe fn _mm_fmsub_sd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
225237
#[cfg_attr(test, assert_instr(vfmsub))]
226238
#[stable(feature = "simd_x86", since = "1.27.0")]
227239
pub unsafe fn _mm_fmsub_ss(a: __m128, b: __m128, c: __m128) -> __m128 {
228-
vfmsubss(a, b, c)
240+
simd_insert!(
241+
a,
242+
0,
243+
_mm_cvtss_f32(a).mul_add(_mm_cvtss_f32(b), -_mm_cvtss_f32(c))
244+
)
229245
}
230246

231247
/// Multiplies packed double-precision (64-bit) floating-point elements in `a`
@@ -289,7 +305,7 @@ pub unsafe fn _mm256_fmsubadd_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
289305
#[cfg_attr(test, assert_instr(vfnmadd))]
290306
#[stable(feature = "simd_x86", since = "1.27.0")]
291307
pub unsafe fn _mm_fnmadd_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
292-
vfnmaddpd(a, b, c)
308+
simd_fma(simd_neg(a), b, c)
293309
}
294310

295311
/// Multiplies packed double-precision (64-bit) floating-point elements in `a`
@@ -301,7 +317,7 @@ pub unsafe fn _mm_fnmadd_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
301317
#[cfg_attr(test, assert_instr(vfnmadd))]
302318
#[stable(feature = "simd_x86", since = "1.27.0")]
303319
pub unsafe fn _mm256_fnmadd_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
304-
vfnmaddpd256(a, b, c)
320+
simd_fma(simd_neg(a), b, c)
305321
}
306322

307323
/// Multiplies packed single-precision (32-bit) floating-point elements in `a`
@@ -313,7 +329,7 @@ pub unsafe fn _mm256_fnmadd_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
313329
#[cfg_attr(test, assert_instr(vfnmadd))]
314330
#[stable(feature = "simd_x86", since = "1.27.0")]
315331
pub unsafe fn _mm_fnmadd_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
316-
vfnmaddps(a, b, c)
332+
simd_fma(simd_neg(a), b, c)
317333
}
318334

319335
/// Multiplies packed single-precision (32-bit) floating-point elements in `a`
@@ -325,7 +341,7 @@ pub unsafe fn _mm_fnmadd_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
325341
#[cfg_attr(test, assert_instr(vfnmadd))]
326342
#[stable(feature = "simd_x86", since = "1.27.0")]
327343
pub unsafe fn _mm256_fnmadd_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
328-
vfnmaddps256(a, b, c)
344+
simd_fma(simd_neg(a), b, c)
329345
}
330346

331347
/// Multiplies the lower double-precision (64-bit) floating-point elements in
@@ -339,7 +355,11 @@ pub unsafe fn _mm256_fnmadd_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
339355
#[cfg_attr(test, assert_instr(vfnmadd))]
340356
#[stable(feature = "simd_x86", since = "1.27.0")]
341357
pub unsafe fn _mm_fnmadd_sd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
342-
vfnmaddsd(a, b, c)
358+
simd_insert!(
359+
a,
360+
0,
361+
_mm_cvtsd_f64(a).mul_add(-_mm_cvtsd_f64(b), _mm_cvtsd_f64(c))
362+
)
343363
}
344364

345365
/// Multiplies the lower single-precision (32-bit) floating-point elements in
@@ -353,7 +373,11 @@ pub unsafe fn _mm_fnmadd_sd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
353373
#[cfg_attr(test, assert_instr(vfnmadd))]
354374
#[stable(feature = "simd_x86", since = "1.27.0")]
355375
pub unsafe fn _mm_fnmadd_ss(a: __m128, b: __m128, c: __m128) -> __m128 {
356-
vfnmaddss(a, b, c)
376+
simd_insert!(
377+
a,
378+
0,
379+
_mm_cvtss_f32(a).mul_add(-_mm_cvtss_f32(b), _mm_cvtss_f32(c))
380+
)
357381
}
358382

359383
/// Multiplies packed double-precision (64-bit) floating-point elements in `a`
@@ -366,7 +390,7 @@ pub unsafe fn _mm_fnmadd_ss(a: __m128, b: __m128, c: __m128) -> __m128 {
366390
#[cfg_attr(test, assert_instr(vfnmsub))]
367391
#[stable(feature = "simd_x86", since = "1.27.0")]
368392
pub unsafe fn _mm_fnmsub_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
369-
vfnmsubpd(a, b, c)
393+
simd_fma(simd_neg(a), b, simd_neg(c))
370394
}
371395

372396
/// Multiplies packed double-precision (64-bit) floating-point elements in `a`
@@ -379,7 +403,7 @@ pub unsafe fn _mm_fnmsub_pd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
379403
#[cfg_attr(test, assert_instr(vfnmsub))]
380404
#[stable(feature = "simd_x86", since = "1.27.0")]
381405
pub unsafe fn _mm256_fnmsub_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
382-
vfnmsubpd256(a, b, c)
406+
simd_fma(simd_neg(a), b, simd_neg(c))
383407
}
384408

385409
/// Multiplies packed single-precision (32-bit) floating-point elements in `a`
@@ -392,7 +416,7 @@ pub unsafe fn _mm256_fnmsub_pd(a: __m256d, b: __m256d, c: __m256d) -> __m256d {
392416
#[cfg_attr(test, assert_instr(vfnmsub))]
393417
#[stable(feature = "simd_x86", since = "1.27.0")]
394418
pub unsafe fn _mm_fnmsub_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
395-
vfnmsubps(a, b, c)
419+
simd_fma(simd_neg(a), b, simd_neg(c))
396420
}
397421

398422
/// Multiplies packed single-precision (32-bit) floating-point elements in `a`
@@ -405,7 +429,7 @@ pub unsafe fn _mm_fnmsub_ps(a: __m128, b: __m128, c: __m128) -> __m128 {
405429
#[cfg_attr(test, assert_instr(vfnmsub))]
406430
#[stable(feature = "simd_x86", since = "1.27.0")]
407431
pub unsafe fn _mm256_fnmsub_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
408-
vfnmsubps256(a, b, c)
432+
simd_fma(simd_neg(a), b, simd_neg(c))
409433
}
410434

411435
/// Multiplies the lower double-precision (64-bit) floating-point elements in
@@ -420,7 +444,11 @@ pub unsafe fn _mm256_fnmsub_ps(a: __m256, b: __m256, c: __m256) -> __m256 {
420444
#[cfg_attr(test, assert_instr(vfnmsub))]
421445
#[stable(feature = "simd_x86", since = "1.27.0")]
422446
pub unsafe fn _mm_fnmsub_sd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
423-
vfnmsubsd(a, b, c)
447+
simd_insert!(
448+
a,
449+
0,
450+
_mm_cvtsd_f64(a).mul_add(-_mm_cvtsd_f64(b), -_mm_cvtsd_f64(c))
451+
)
424452
}
425453

426454
/// Multiplies the lower single-precision (32-bit) floating-point elements in
@@ -435,15 +463,15 @@ pub unsafe fn _mm_fnmsub_sd(a: __m128d, b: __m128d, c: __m128d) -> __m128d {
435463
#[cfg_attr(test, assert_instr(vfnmsub))]
436464
#[stable(feature = "simd_x86", since = "1.27.0")]
437465
pub unsafe fn _mm_fnmsub_ss(a: __m128, b: __m128, c: __m128) -> __m128 {
438-
vfnmsubss(a, b, c)
466+
simd_insert!(
467+
a,
468+
0,
469+
_mm_cvtss_f32(a).mul_add(-_mm_cvtss_f32(b), -_mm_cvtss_f32(c))
470+
)
439471
}
440472

441473
#[allow(improper_ctypes)]
442474
extern "C" {
443-
#[link_name = "llvm.x86.fma.vfmadd.sd"]
444-
fn vfmaddsd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
445-
#[link_name = "llvm.x86.fma.vfmadd.ss"]
446-
fn vfmaddss(a: __m128, b: __m128, c: __m128) -> __m128;
447475
#[link_name = "llvm.x86.fma.vfmaddsub.pd"]
448476
fn vfmaddsubpd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
449477
#[link_name = "llvm.x86.fma.vfmaddsub.pd.256"]
@@ -452,18 +480,6 @@ extern "C" {
452480
fn vfmaddsubps(a: __m128, b: __m128, c: __m128) -> __m128;
453481
#[link_name = "llvm.x86.fma.vfmaddsub.ps.256"]
454482
fn vfmaddsubps256(a: __m256, b: __m256, c: __m256) -> __m256;
455-
#[link_name = "llvm.x86.fma.vfmsub.pd"]
456-
fn vfmsubpd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
457-
#[link_name = "llvm.x86.fma.vfmsub.pd.256"]
458-
fn vfmsubpd256(a: __m256d, b: __m256d, c: __m256d) -> __m256d;
459-
#[link_name = "llvm.x86.fma.vfmsub.ps"]
460-
fn vfmsubps(a: __m128, b: __m128, c: __m128) -> __m128;
461-
#[link_name = "llvm.x86.fma.vfmsub.ps.256"]
462-
fn vfmsubps256(a: __m256, b: __m256, c: __m256) -> __m256;
463-
#[link_name = "llvm.x86.fma.vfmsub.sd"]
464-
fn vfmsubsd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
465-
#[link_name = "llvm.x86.fma.vfmsub.ss"]
466-
fn vfmsubss(a: __m128, b: __m128, c: __m128) -> __m128;
467483
#[link_name = "llvm.x86.fma.vfmsubadd.pd"]
468484
fn vfmsubaddpd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
469485
#[link_name = "llvm.x86.fma.vfmsubadd.pd.256"]
@@ -472,30 +488,6 @@ extern "C" {
472488
fn vfmsubaddps(a: __m128, b: __m128, c: __m128) -> __m128;
473489
#[link_name = "llvm.x86.fma.vfmsubadd.ps.256"]
474490
fn vfmsubaddps256(a: __m256, b: __m256, c: __m256) -> __m256;
475-
#[link_name = "llvm.x86.fma.vfnmadd.pd"]
476-
fn vfnmaddpd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
477-
#[link_name = "llvm.x86.fma.vfnmadd.pd.256"]
478-
fn vfnmaddpd256(a: __m256d, b: __m256d, c: __m256d) -> __m256d;
479-
#[link_name = "llvm.x86.fma.vfnmadd.ps"]
480-
fn vfnmaddps(a: __m128, b: __m128, c: __m128) -> __m128;
481-
#[link_name = "llvm.x86.fma.vfnmadd.ps.256"]
482-
fn vfnmaddps256(a: __m256, b: __m256, c: __m256) -> __m256;
483-
#[link_name = "llvm.x86.fma.vfnmadd.sd"]
484-
fn vfnmaddsd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
485-
#[link_name = "llvm.x86.fma.vfnmadd.ss"]
486-
fn vfnmaddss(a: __m128, b: __m128, c: __m128) -> __m128;
487-
#[link_name = "llvm.x86.fma.vfnmsub.pd"]
488-
fn vfnmsubpd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
489-
#[link_name = "llvm.x86.fma.vfnmsub.pd.256"]
490-
fn vfnmsubpd256(a: __m256d, b: __m256d, c: __m256d) -> __m256d;
491-
#[link_name = "llvm.x86.fma.vfnmsub.ps"]
492-
fn vfnmsubps(a: __m128, b: __m128, c: __m128) -> __m128;
493-
#[link_name = "llvm.x86.fma.vfnmsub.ps.256"]
494-
fn vfnmsubps256(a: __m256, b: __m256, c: __m256) -> __m256;
495-
#[link_name = "llvm.x86.fma.vfnmsub.sd"]
496-
fn vfnmsubsd(a: __m128d, b: __m128d, c: __m128d) -> __m128d;
497-
#[link_name = "llvm.x86.fma.vfnmsub.ss"]
498-
fn vfnmsubss(a: __m128, b: __m128, c: __m128) -> __m128;
499491
}
500492

501493
#[cfg(test)]

0 commit comments

Comments
 (0)