Skip to content

Commit 822b4bb

Browse files
authored
Merge pull request #445 from serge-sans-paille/feature/support-fma-complex
Support and test fma, fms, fnma and fnms for complex numbers
2 parents b3a075f + 072594c commit 822b4bb

File tree

3 files changed

+93
-0
lines changed

3 files changed

+93
-0
lines changed

include/xsimd/math/xsimd_math_complex.hpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -815,6 +815,34 @@ namespace xsimd
815815
return hypot(z.real(), z.imag());
816816
}
817817

818+
// Fused multiply-add for complex batches, assembled from the real-valued
// ::xsimd::fma / ::xsimd::fms primitives applied to the real and imaginary
// parts of the operands.
// NOTE(review): assuming fma(x, y, z) == x*y + z and fms(x, y, z) == x*y - z
// on real batches, the result is
//   re = a.re*b.re - a.im*b.im + c.re
//   im = a.re*b.im + a.im*b.re + c.im
// i.e. a * b + c -- confirm against the real-batch definitions.
static batch_type fma(const batch_type& a, const batch_type& b, const batch_type& c)
{
    // Inner fused terms: the imaginary-part products combined with c.
    const real_batch inner_re = ::xsimd::fms(a.imag(), b.imag(), c.real());
    const real_batch inner_im = ::xsimd::fma(a.imag(), b.real(), c.imag());

    // Outer fused terms: the real-part products combined with the inner terms.
    return {::xsimd::fms(a.real(), b.real(), inner_re),
            ::xsimd::fma(a.real(), b.imag(), inner_im)};
}
824+
825+
// Fused multiply-subtract for complex batches, assembled from the real-valued
// ::xsimd::fma / ::xsimd::fms primitives applied to the real and imaginary
// parts of the operands.
// NOTE(review): assuming fma(x, y, z) == x*y + z and fms(x, y, z) == x*y - z
// on real batches, the result is
//   re = a.re*b.re - a.im*b.im - c.re
//   im = a.re*b.im + a.im*b.re - c.im
// i.e. a * b - c -- confirm against the real-batch definitions.
static batch_type fms(const batch_type& a, const batch_type& b, const batch_type& c)
{
    // Inner fused terms: the imaginary-part products combined with c.
    const real_batch inner_re = ::xsimd::fma(a.imag(), b.imag(), c.real());
    const real_batch inner_im = ::xsimd::fms(a.imag(), b.real(), c.imag());

    // Outer fused terms: the real-part products combined with the inner terms.
    return {::xsimd::fms(a.real(), b.real(), inner_re),
            ::xsimd::fma(a.real(), b.imag(), inner_im)};
}
831+
832+
// Fused negated multiply-add for complex batches: evaluates the same fused
// expression as the complex multiply-subtract and negates both components.
// NOTE(review): assuming fma(x, y, z) == x*y + z and fms(x, y, z) == x*y - z
// on real batches, the result is
//   re = -(a.re*b.re - a.im*b.im - c.re)
//   im = -(a.re*b.im + a.im*b.re - c.im)
// i.e. -(a * b) + c -- confirm against the real-batch definitions.
static batch_type fnma(const batch_type& a, const batch_type& b, const batch_type& c)
{
    // Inner fused terms: the imaginary-part products combined with c.
    const real_batch inner_re = ::xsimd::fma(a.imag(), b.imag(), c.real());
    const real_batch inner_im = ::xsimd::fms(a.imag(), b.real(), c.imag());

    // Un-negated components (these are the a * b - c components).
    const real_batch pos_re = ::xsimd::fms(a.real(), b.real(), inner_re);
    const real_batch pos_im = ::xsimd::fma(a.real(), b.imag(), inner_im);

    return {-pos_re, -pos_im};
}
838+
839+
// Fused negated multiply-subtract for complex batches: evaluates the same
// fused expression as the complex multiply-add and negates both components.
// NOTE(review): assuming fma(x, y, z) == x*y + z and fms(x, y, z) == x*y - z
// on real batches, the result is
//   re = -(a.re*b.re - a.im*b.im + c.re)
//   im = -(a.re*b.im + a.im*b.re + c.im)
// i.e. -(a * b) - c -- confirm against the real-batch definitions.
static batch_type fnms(const batch_type& a, const batch_type& b, const batch_type& c)
{
    // Inner fused terms: the imaginary-part products combined with c.
    const real_batch inner_re = ::xsimd::fms(a.imag(), b.imag(), c.real());
    const real_batch inner_im = ::xsimd::fma(a.imag(), b.real(), c.imag());

    // Un-negated components (these are the a * b + c components).
    const real_batch pos_re = ::xsimd::fms(a.real(), b.real(), inner_re);
    const real_batch pos_im = ::xsimd::fma(a.real(), b.imag(), inner_im);

    return {-pos_re, -pos_im};
}
845+
818846
static batch_type sqrt(const batch_type& z)
819847
{
820848
using rvt = typename real_batch::value_type;

include/xsimd/math/xsimd_scalar.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,31 @@ namespace xsimd
375375
return std::fma(a, b, c);
376376
}
377377

378+
namespace detail
{
    // Shared kernel behind the scalar complex fma overloads.
    // C is any complex-like type exposing real() / imag() and constructible
    // from a {real, imaginary} pair. The result is built from the scalar
    // fma / fms helpers applied component-wise.
    // NOTE(review): assuming fma(x, y, z) == x*y + z and fms(x, y, z) == x*y - z,
    // this evaluates a * b + c -- confirm against the scalar fms definition.
    template <class C>
    inline C fma_complex_scalar_impl(const C& a, const C& b, const C& c)
    {
        // Inner fused terms: the imaginary-part products combined with c.
        const auto re_tail = fms(a.imag(), b.imag(), c.real());
        const auto im_tail = fma(a.imag(), b.real(), c.imag());

        // Outer fused terms: the real-part products combined with the tails.
        return {fms(a.real(), b.real(), re_tail),
                fma(a.real(), b.imag(), im_tail)};
    }
}
387+
388+
template <class T>
389+
inline std::complex<T> fma(const std::complex<T>& a, const std::complex<T>& b, const std::complex<T>& c)
390+
{
391+
return detail::fma_complex_scalar_impl(a, b, c);
392+
}
393+
394+
395+
#ifdef XSIMD_ENABLE_XTL_COMPLEX
// Scalar fma overload for xtl::xcomplex, compiled only when
// XSIMD_ENABLE_XTL_COMPLEX is defined; forwards to the shared complex
// kernel in namespace detail.
template <class T, bool i3ec>
inline xtl::xcomplex<T, T, i3ec> fma(const xtl::xcomplex<T, T, i3ec>& a,
                                     const xtl::xcomplex<T, T, i3ec>& b,
                                     const xtl::xcomplex<T, T, i3ec>& c)
{
    using complex_type = xtl::xcomplex<T, T, i3ec>;
    return detail::fma_complex_scalar_impl<complex_type>(a, b, c);
}
#endif
402+
378403
inline void sincos(float val, float&s, float& c)
379404
{
380405
#if defined(__APPLE__)

test/test_batch_complex.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,42 @@ class batch_complex_test : public testing::Test
468468
}
469469
}
470470

471+
void test_fused_operations() const
472+
{
473+
// fma
474+
{
475+
array_type expected;
476+
std::transform(lhs.cbegin(), lhs.cend(), rhs.begin(), expected.begin(),
477+
[](const value_type& l, const value_type& r) { return l * r + r; });
478+
batch_type res = xsimd::fma(batch_lhs(), batch_rhs(), batch_rhs());
479+
EXPECT_BATCH_EQ(res, expected) << print_function_name("fma");
480+
}
481+
// fms
482+
{
483+
array_type expected;
484+
std::transform(lhs.cbegin(), lhs.cend(), rhs.begin(), expected.begin(),
485+
[](const value_type& l, const value_type& r) { return l * r - r; });
486+
batch_type res = fms(batch_lhs(), batch_rhs(), batch_rhs());
487+
EXPECT_BATCH_EQ(res, expected) << print_function_name("fms");
488+
}
489+
// fnma
490+
{
491+
array_type expected;
492+
std::transform(lhs.cbegin(), lhs.cend(), rhs.begin(), expected.begin(),
493+
[](const value_type& l, const value_type& r) { return -l * r + r; });
494+
batch_type res = fnma(batch_lhs(), batch_rhs(), batch_rhs());
495+
EXPECT_BATCH_EQ(res, expected) << print_function_name("fnma");
496+
}
497+
// fnms
498+
{
499+
array_type expected;
500+
std::transform(lhs.cbegin(), lhs.cend(), rhs.begin(), expected.begin(),
501+
[](const value_type& l, const value_type& r) { return -l * r - r; });
502+
batch_type res = fnms(batch_lhs(), batch_rhs(), batch_rhs());
503+
EXPECT_BATCH_EQ(res, expected) << print_function_name("fnms");
504+
}
505+
}
506+
471507
private:
472508

473509
batch_type batch_lhs() const
@@ -522,3 +558,7 @@ TYPED_TEST(batch_complex_test, horizontal_operations)
522558
this->test_horizontal_operations();
523559
}
524560

561+
// Typed-test entry point: runs the fixture's fused-operation checks for each
// instantiated batch type.
TYPED_TEST(batch_complex_test, fused_operations)
{
    auto& fixture = *this;
    fixture.test_fused_operations();
}

0 commit comments

Comments
 (0)