Skip to content

Commit 7081195

Browse files
kalenedraelserge-sans-paille
authored and committed
Add select for batch_bool
1 parent 345b076 commit 7081195

File tree

3 files changed

+108
-14
lines changed

3 files changed

+108
-14
lines changed

include/xsimd/arch/common/xsimd_common_logical.hpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,35 @@ namespace xsimd
212212
res |= 1ul << i;
213213
return res;
214214
}
215+
216+
// select
217+
namespace detail
218+
{
219+
template <typename T, typename A>
220+
using is_batch_bool_register_same = std::is_same<typename batch_bool<T, A>::register_type, typename batch<T, A>::register_type>;
221+
}
222+
223+
template <class A, class T, typename std::enable_if<detail::is_batch_bool_register_same<T, A>::value, int>::type = 3>
224+
XSIMD_INLINE batch_bool<T, A> select(batch_bool<T, A> const& cond, batch_bool<T, A> const& true_br, batch_bool<T, A> const& false_br, requires_arch<common>)
225+
{
226+
using register_type = typename batch_bool<T, A>::register_type;
227+
// Do not cast, but rather reinterpret the masks as batches.
228+
const auto true_v = batch<T, A> { static_cast<register_type>(true_br) };
229+
const auto false_v = batch<T, A> { static_cast<register_type>(false_br) };
230+
return batch_bool<T, A> { select(cond, true_v, false_v) };
231+
}
232+
233+
template <class A, class T, typename std::enable_if<!detail::is_batch_bool_register_same<T, A>::value, int>::type = 3>
234+
XSIMD_INLINE batch_bool<T, A> select(batch_bool<T, A> const& cond, batch_bool<T, A> const& true_br, batch_bool<T, A> const& false_br, requires_arch<common>)
235+
{
236+
return (true_br & cond) | (bitwise_andnot(false_br, cond));
237+
}
238+
239+
template <class A, class T, bool... Values>
240+
XSIMD_INLINE batch_bool<T, A> select(batch_bool_constant<T, A, Values...> const& cond, batch_bool<T, A> const& true_br, batch_bool<T, A> const& false_br, requires_arch<common>)
241+
{
242+
return (true_br & cond) | (false_br & ~cond);
243+
}
215244
}
216245
}
217246

include/xsimd/types/xsimd_api.hpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2138,6 +2138,27 @@ namespace xsimd
21382138
return kernel::select<A>(cond, true_br, false_br, A {});
21392139
}
21402140

2141+
/**
2142+
* @ingroup batch_bool_logical
2143+
*
2144+
* Ternary operator for conditions: selects values from the masks \c true_br or \c false_br
2145+
* depending on the boolean values in the batch \c cond. Equivalent to
2146+
* \code{.cpp}
2147+
* for(std::size_t i = 0; i < N; ++i)
2148+
* res[i] = cond[i] ? true_br[i] : false_br[i];
2149+
* \endcode
2150+
* @param cond batch condition.
2151+
* @param true_br batch values for truthy condition.
2152+
* @param false_br batch values for falsy condition.
2153+
* @return the result of the selection.
2154+
*/
2155+
template <class T, class A>
2156+
XSIMD_INLINE batch_bool<T, A> select(batch_bool<T, A> const& cond, batch_bool<T, A> const& true_br, batch_bool<T, A> const& false_br) noexcept
2157+
{
2158+
detail::static_check_supported_config<T, A>();
2159+
return kernel::select<A>(cond, true_br, false_br, A {});
2160+
}
2161+
21412162
/**
21422163
* @ingroup batch_cond
21432164
*
@@ -2180,6 +2201,27 @@ namespace xsimd
21802201
return kernel::select<A>(cond, true_br, false_br, A {});
21812202
}
21822203

2204+
/**
2205+
* @ingroup batch_cond
2206+
*
2207+
* Ternary operator for mask batches: selects values from the masks \c true_br or \c false_br
2208+
* depending on the boolean values in the constant batch \c cond. Equivalent to
2209+
* \code{.cpp}
2210+
* for(std::size_t i = 0; i < N; ++i)
2211+
* res[i] = cond[i] ? true_br[i] : false_br[i];
2212+
* \endcode
2213+
* @param cond constant batch condition.
2214+
* @param true_br batch values for truthy condition.
2215+
* @param false_br batch values for falsy condition.
2216+
* @return the result of the selection.
2217+
*/
2218+
template <class T, class A, bool... Values>
2219+
XSIMD_INLINE batch_bool<T, A> select(batch_bool_constant<T, A, Values...> const& cond, batch_bool<T, A> const& true_br, batch_bool<T, A> const& false_br) noexcept
2220+
{
2221+
detail::static_check_supported_config<T, A>();
2222+
return kernel::select<A>(cond, true_br, false_br, A {});
2223+
}
2224+
21832225
/**
21842226
* @ingroup batch_data_transfer
21852227
*

test/test_select.cpp

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,26 @@ template <class B>
1818
struct select_test
1919
{
2020
using batch_type = B;
21+
using batch_bool_type = typename B::batch_bool_type;
2122
using value_type = typename B::value_type;
2223
using arch_type = typename B::arch_type;
2324
static constexpr size_t size = B::size;
24-
using vector_type = std::vector<value_type>;
25+
static constexpr size_t nb_input = size * 10000;
26+
using vector_type = std::array<value_type, nb_input>;
27+
using vector_bool_type = std::array<bool, nb_input>;
2528

26-
size_t nb_input;
2729
vector_type lhs_input;
2830
vector_type rhs_input;
2931
vector_type expected;
3032
vector_type res;
3133

34+
vector_bool_type lhs_input_b;
35+
vector_bool_type rhs_input_b;
36+
vector_bool_type expected_b;
37+
vector_bool_type res_b;
38+
3239
select_test()
3340
{
34-
nb_input = size * 10000;
35-
lhs_input.resize(nb_input);
36-
rhs_input.resize(nb_input);
3741
auto clamp = [](double v)
3842
{
3943
return static_cast<value_type>(std::min(v, static_cast<double>(std::numeric_limits<value_type>::max())));
@@ -42,28 +46,37 @@ struct select_test
4246
{
4347
lhs_input[i] = clamp(i / 4 + 1.2 * std::sqrt(i + 0.25));
4448
rhs_input[i] = clamp(10.2 / (i + 2) + 0.25);
49+
lhs_input_b[i] = (int)lhs_input[i] % 2;
50+
rhs_input_b[i] = (int)rhs_input[i] % 2;
4551
}
46-
expected.resize(nb_input);
47-
res.resize(nb_input);
4852
}
4953

5054
void test_select_dynamic()
5155
{
5256
for (size_t i = 0; i < nb_input; ++i)
5357
{
5458
expected[i] = lhs_input[i] > value_type(3) ? lhs_input[i] : rhs_input[i];
59+
expected_b[i] = lhs_input[i] > value_type(3) ? lhs_input_b[i] : rhs_input_b[i];
5560
}
5661

57-
batch_type lhs_in, rhs_in, out;
62+
batch_type lhs_in, rhs_in;
63+
batch_bool_type lhs_in_b, rhs_in_b;
5864
for (size_t i = 0; i < nb_input; i += size)
5965
{
6066
detail::load_batch(lhs_in, lhs_input, i);
6167
detail::load_batch(rhs_in, rhs_input, i);
62-
out = xsimd::select(lhs_in > value_type(3), lhs_in, rhs_in);
68+
const auto out = xsimd::select(lhs_in > value_type(3), lhs_in, rhs_in);
6369
detail::store_batch(out, res, i);
70+
71+
detail::load_batch(lhs_in_b, lhs_input_b, i);
72+
detail::load_batch(rhs_in_b, rhs_input_b, i);
73+
const auto out_b = xsimd::select(lhs_in > value_type(3), lhs_in_b, rhs_in_b);
74+
detail::store_batch(out_b, res_b, i);
6475
}
6576
size_t diff = detail::get_nb_diff(res, expected);
77+
size_t diff_b = detail::get_nb_diff(res_b, expected_b);
6678
CHECK_EQ(diff, 0);
79+
CHECK_EQ(diff_b, 0);
6780
}
6881
struct pattern
6982
{
@@ -77,25 +90,35 @@ struct select_test
7790
for (size_t i = 0; i < nb_input; ++i)
7891
{
7992
expected[i] = mask.get(i % size) ? lhs_input[i] : rhs_input[i];
93+
expected_b[i] = mask.get(i % size) ? lhs_input_b[i] : rhs_input_b[i];
8094
}
8195

82-
batch_type lhs_in, rhs_in, out;
96+
batch_type lhs_in, rhs_in;
97+
batch_bool_type lhs_in_b, rhs_in_b;
8398
for (size_t i = 0; i < nb_input; i += size)
8499
{
85100
detail::load_batch(lhs_in, lhs_input, i);
86101
detail::load_batch(rhs_in, rhs_input, i);
87-
out = xsimd::select(mask, lhs_in, rhs_in);
102+
const auto out = xsimd::select(mask, lhs_in, rhs_in);
88103
detail::store_batch(out, res, i);
104+
105+
detail::load_batch(lhs_in_b, lhs_input_b, i);
106+
detail::load_batch(rhs_in_b, rhs_input_b, i);
107+
const auto out_b = xsimd::select(mask, lhs_in_b, rhs_in_b);
108+
detail::store_batch(out_b, res_b, i);
89109
}
90110
size_t diff = detail::get_nb_diff(res, expected);
111+
size_t diff_b = detail::get_nb_diff(res_b, expected_b);
91112
CHECK_EQ(diff, 0);
113+
CHECK_EQ(diff_b, 0);
92114
}
93115
};
94116

95117
TEST_CASE_TEMPLATE("[select]", B, BATCH_TYPES)
96118
{
97-
select_test<B> Test;
98-
SUBCASE("select_dynamic") { Test.test_select_dynamic(); }
99-
SUBCASE("select_static") { Test.test_select_static(); }
119+
// Allocate on heap to avoid stack overflow from excessively large object.
120+
std::unique_ptr<select_test<B>> Test { new select_test<B> };
121+
SUBCASE("select_dynamic") { Test->test_select_dynamic(); }
122+
SUBCASE("select_static") { Test->test_select_static(); }
100123
}
101124
#endif

0 commit comments

Comments
 (0)