Skip to content

Commit d54a63d

Browse files
committed
Test refactoring: complex power functions
1 parent 7a4f44c commit d54a63d

File tree

2 files changed

+233
-0
lines changed

2 files changed

+233
-0
lines changed

test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ set(XSIMD_TESTS
150150
test_batch_int.cpp
151151
test_complex_exponential.cpp
152152
test_complex_hyperbolic.cpp
153+
test_complex_power.cpp
153154
test_complex_trigonometric.cpp
154155
test_error_gamma.cpp
155156
test_exponential.cpp

test/test_complex_power.cpp

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
/***************************************************************************
2+
* Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and *
3+
* Martin Renou *
4+
* Copyright (c) QuantStack *
5+
* *
6+
* Distributed under the terms of the BSD 3-Clause License. *
7+
* *
8+
* The full license is in the file LICENSE, distributed with this software. *
9+
****************************************************************************/
10+
11+
#include "test_utils.hpp"
12+
13+
template <class B>
14+
class complex_power_test : public testing::Test
15+
{
16+
protected:
17+
18+
using batch_type = B;
19+
using real_batch_type = typename B::real_batch;
20+
using value_type = typename B::value_type;
21+
using real_value_type = typename value_type::value_type;
22+
static constexpr size_t size = B::size;
23+
using vector_type = std::vector<value_type>;
24+
using real_vector_type = std::vector<real_value_type>;
25+
26+
size_t nb_input;
27+
vector_type lhs_nn;
28+
vector_type lhs_pn;
29+
vector_type lhs_np;
30+
vector_type lhs_pp;
31+
vector_type rhs;
32+
vector_type expected;
33+
vector_type res;
34+
35+
complex_power_test()
36+
{
37+
nb_input = 10000 * size;
38+
lhs_nn.resize(nb_input);
39+
lhs_pn.resize(nb_input);
40+
lhs_np.resize(nb_input);
41+
lhs_pp.resize(nb_input);
42+
rhs.resize(nb_input);
43+
for (size_t i = 0; i < nb_input; ++i)
44+
{
45+
real_value_type real = (real_value_type(i) / 4 + real_value_type(1.2) * std::sqrt(real_value_type(i + 0.25)))/ 100;
46+
real_value_type imag = (real_value_type(i) / 7 + real_value_type(1.7) * std::sqrt(real_value_type(i + 0.37))) / 100;
47+
lhs_nn[i] = value_type(-real, -imag);
48+
lhs_pn[i] = value_type(real, -imag);
49+
lhs_np[i] = value_type(-real, imag);
50+
lhs_pp[i] = value_type(real, imag);
51+
rhs[i] = value_type(real_value_type(10.2) / (i + 2) + real_value_type(0.25),
52+
real_value_type(9.1) / (i + 3) + real_value_type(0.45));
53+
}
54+
expected.resize(nb_input);
55+
res.resize(nb_input);
56+
}
57+
58+
void test_abs()
59+
{
60+
real_vector_type real_expected(nb_input), real_res(nb_input);
61+
std::transform(lhs_np.cbegin(), lhs_np.cend(), real_expected.begin(),
62+
[](const value_type& v) { using std::abs; return abs(v); });
63+
batch_type in;
64+
real_batch_type out;
65+
for (size_t i = 0; i < nb_input; i += size)
66+
{
67+
detail::load_batch(in, lhs_np, i);
68+
out = abs(in);
69+
detail::store_batch(out, real_res, i);
70+
}
71+
size_t diff = detail::get_nb_diff(real_res, real_expected);
72+
EXPECT_EQ(diff, 0) << print_function_name("abs");
73+
}
74+
75+
void test_arg()
76+
{
77+
real_vector_type real_expected(nb_input), real_res(nb_input);
78+
std::transform(lhs_np.cbegin(), lhs_np.cend(), real_expected.begin(),
79+
[](const value_type& v) { using std::arg; return arg(v); });
80+
batch_type in;
81+
real_batch_type out;
82+
for (size_t i = 0; i < nb_input; i += size)
83+
{
84+
detail::load_batch(in, lhs_np, i);
85+
out = arg(in);
86+
detail::store_batch(out, real_res, i);
87+
}
88+
size_t diff = detail::get_nb_diff(real_res, real_expected);
89+
EXPECT_EQ(diff, 0) << print_function_name("arg");
90+
}
91+
92+
void test_pow()
93+
{
94+
test_conditional_pow<real_value_type>();
95+
}
96+
97+
void test_sqrt_nn()
98+
{
99+
std::transform(lhs_nn.cbegin(), lhs_nn.cend(), expected.begin(),
100+
[](const value_type& v) { using std::sqrt; return sqrt(v); });
101+
batch_type in, out;
102+
for (size_t i = 0; i < nb_input; i += size)
103+
{
104+
detail::load_batch(in, lhs_nn, i);
105+
out = sqrt(in);
106+
detail::store_batch(out, res, i);
107+
}
108+
size_t diff = detail::get_nb_diff(res, expected);
109+
EXPECT_EQ(diff, 0) << print_function_name("sqrt_nn");
110+
}
111+
112+
void test_sqrt_pn()
113+
{
114+
std::transform(lhs_pn.cbegin(), lhs_pn.cend(), expected.begin(),
115+
[](const value_type& v) { using std::sqrt; return sqrt(v); });
116+
batch_type in, out;
117+
for (size_t i = 0; i < nb_input; i += size)
118+
{
119+
detail::load_batch(in, lhs_pn, i);
120+
out = sqrt(in);
121+
detail::store_batch(out, res, i);
122+
}
123+
size_t diff = detail::get_nb_diff(res, expected);
124+
EXPECT_EQ(diff, 0) << print_function_name("sqrt_pn");
125+
}
126+
127+
void test_sqrt_np()
128+
{
129+
std::transform(lhs_np.cbegin(), lhs_np.cend(), expected.begin(),
130+
[](const value_type& v) { using std::sqrt; return sqrt(v); });
131+
batch_type in, out;
132+
for (size_t i = 0; i < nb_input; i += size)
133+
{
134+
detail::load_batch(in, lhs_np, i);
135+
out = sqrt(in);
136+
detail::store_batch(out, res, i);
137+
}
138+
size_t diff = detail::get_nb_diff(res, expected);
139+
EXPECT_EQ(diff, 0) << print_function_name("sqrt_nn");
140+
}
141+
142+
void test_sqrt_pp()
143+
{
144+
std::transform(lhs_pp.cbegin(), lhs_pp.cend(), expected.begin(),
145+
[](const value_type& v) { using std::sqrt; return sqrt(v); });
146+
batch_type in, out;
147+
for (size_t i = 0; i < nb_input; i += size)
148+
{
149+
detail::load_batch(in, lhs_pp, i);
150+
out = sqrt(in);
151+
detail::store_batch(out, res, i);
152+
}
153+
size_t diff = detail::get_nb_diff(res, expected);
154+
EXPECT_EQ(diff, 0) << print_function_name("sqrt_pp");
155+
}
156+
157+
private:
158+
159+
void test_pow_impl()
160+
{
161+
std::transform(lhs_np.cbegin(), lhs_np.cend(), rhs.cbegin(), expected.begin(),
162+
[](const value_type& l, const value_type& r) { using std::pow; return pow(l, r); });
163+
batch_type lhs_in, rhs_in, out;
164+
for (size_t i = 0; i < nb_input; i += size)
165+
{
166+
detail::load_batch(lhs_in, lhs_np, i);
167+
detail::load_batch(rhs_in, rhs, i);
168+
out = pow(lhs_in, rhs_in);
169+
detail::store_batch(out, res, i);
170+
}
171+
size_t diff = detail::get_nb_diff(res, expected);
172+
EXPECT_EQ(diff, 0) << print_function_name("pow");
173+
}
174+
175+
template <class T, typename std::enable_if<!std::is_same<T, float>::value, int>::type = 0>
176+
void test_conditional_pow()
177+
{
178+
test_pow_impl();
179+
}
180+
181+
template <class T, typename std::enable_if<std::is_same<T, float>::value, int>::type = 0>
182+
void test_conditional_pow()
183+
{
184+
185+
#if (XSIMD_X86_INSTR_SET >= XSIMD_X86_AVX512_VERSION) || (XSIMD_ARM_INSTR_SET >= XSIMD_ARM7_NEON_VERSION)
186+
#if DEBUG_ACCURACY
187+
test_pow_impl();
188+
#endif
189+
#else
190+
test_pow_impl();
191+
#endif
192+
}
193+
};
194+
195+
TYPED_TEST_SUITE(complex_power_test, batch_complex_types, simd_test_names);
196+
197+
TYPED_TEST(complex_power_test, abs)
198+
{
199+
this->test_abs();
200+
}
201+
202+
TYPED_TEST(complex_power_test, arg)
203+
{
204+
this->test_arg();
205+
}
206+
207+
TYPED_TEST(complex_power_test, pow)
208+
{
209+
this->test_pow();
210+
}
211+
212+
TYPED_TEST(complex_power_test, sqrt_nn)
213+
{
214+
this->test_sqrt_nn();
215+
}
216+
217+
218+
TYPED_TEST(complex_power_test, sqrt_pn)
219+
{
220+
this->test_sqrt_pn();
221+
}
222+
223+
TYPED_TEST(complex_power_test, sqrt_np)
224+
{
225+
this->test_sqrt_np();
226+
}
227+
228+
229+
TYPED_TEST(complex_power_test, sqrt_pp)
230+
{
231+
this->test_sqrt_pp();
232+
}

0 commit comments

Comments
 (0)