Skip to content

Commit d54bfd8

Browse files
committed
remove batch_cast for batch_bool, implement select for batch_bool
1 parent 26e5e36 commit d54bfd8

File tree

4 files changed

+108
-34
lines changed

4 files changed

+108
-34
lines changed

include/xsimd/arch/common/xsimd_common_logical.hpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,35 @@ namespace xsimd
212212
res |= 1ul << i;
213213
return res;
214214
}
215+
216+
// select
217+
namespace detail
218+
{
219+
template <typename T, typename A>
220+
using is_batch_bool_register_same = std::is_same<typename batch_bool<T, A>::register_type, typename batch<T, A>::register_type>;
221+
}
222+
223+
template <class A, class T, typename std::enable_if<detail::is_batch_bool_register_same<T, A>::value, int>::type = 3>
224+
XSIMD_INLINE batch_bool<T, A> select(batch_bool<T, A> const& cond, batch_bool<T, A> const& true_br, batch_bool<T, A> const& false_br, requires_arch<common>)
225+
{
226+
using register_type = typename batch_bool<T, A>::register_type;
227+
// Do not cast, but rather reinterpret the masks as batches.
228+
const auto true_v = batch<T, A> { static_cast<register_type>(true_br) };
229+
const auto false_v = batch<T, A> { static_cast<register_type>(false_br) };
230+
return batch_bool<T, A> { select(cond, true_v, false_v) };
231+
}
232+
233+
template <class A, class T, typename std::enable_if<!detail::is_batch_bool_register_same<T, A>::value, int>::type = 3>
234+
XSIMD_INLINE batch_bool<T, A> select(batch_bool<T, A> const& cond, batch_bool<T, A> const& true_br, batch_bool<T, A> const& false_br, requires_arch<common>)
235+
{
236+
return (true_br & cond) | (bitwise_andnot(false_br, cond));
237+
}
238+
239+
template <class A, class T, bool... Values>
240+
XSIMD_INLINE batch_bool<T, A> select(batch_bool_constant<T, A, Values...> const&, batch_bool<T, A> const& true_br, batch_bool<T, A> const& false_br, requires_arch<common>)
241+
{
242+
return select<A>(batch_bool<T, A> { Values... }, true_br, false_br, A {});
243+
}
215244
}
216245
}
217246

include/xsimd/types/xsimd_api.hpp

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -264,19 +264,6 @@ namespace xsimd
264264
return kernel::batch_cast<A>(x, batch<T_out, A> {}, A {});
265265
}
266266

267-
/**
268-
* @ingroup batch_conversion
269-
*
270-
* Perform a static_cast from \c T_in to \c T_out on \c x.
271-
* @param x batch of \c T_in
272-
* @return \c x cast to \c T_out
273-
*/
274-
template <class T_out, class T_in, class A>
275-
XSIMD_INLINE batch_bool<T_out, A> batch_cast(batch_bool<T_in, A> const& x) noexcept
276-
{
277-
return batch_bool_cast<T_out>(x);
278-
}
279-
280267
/**
281268
* @ingroup batch_miscellaneous
282269
*
@@ -2112,6 +2099,27 @@ namespace xsimd
21122099
return kernel::select<A>(cond, true_br, false_br, A {});
21132100
}
21142101

2102+
/**
2103+
* @ingroup batch_bool_logical
2104+
*
2105+
* Ternary operator for conditions: selects values from the batches \c true_br or \c false_br
2106+
* depending on the boolean values in the constant batch \c cond. Equivalent to
2107+
* \code{.cpp}
2108+
* for(std::size_t i = 0; i < N; ++i)
2109+
* res[i] = cond[i] ? true_br[i] : false_br[i];
2110+
* \endcode
2111+
* @param cond batch condition.
2112+
* @param true_br batch values for truthy condition.
2113+
* @param false_br batch value for falsy condition.
2114+
* @return the result of the selection.
2115+
*/
2116+
template <class T, class A>
2117+
XSIMD_INLINE batch_bool<T, A> select(batch_bool<T, A> const& cond, batch_bool<T, A> const& true_br, batch_bool<T, A> const& false_br) noexcept
2118+
{
2119+
detail::static_check_supported_config<T, A>();
2120+
return kernel::select<A>(cond, true_br, false_br, A {});
2121+
}
2122+
21152123
/**
21162124
* @ingroup batch_cond
21172125
*
@@ -2154,6 +2162,27 @@ namespace xsimd
21542162
return kernel::select<A>(cond, true_br, false_br, A {});
21552163
}
21562164

2165+
/**
2166+
* @ingroup batch_cond
2167+
*
2168+
* Ternary operator for mask batches: selects values from the masks \c true_br or \c false_br
2169+
* depending on the boolean values in the constant batch \c cond. Equivalent to
2170+
* \code{.cpp}
2171+
* for(std::size_t i = 0; i < N; ++i)
2172+
* res[i] = cond[i] ? true_br[i] : false_br[i];
2173+
* \endcode
2174+
* @param cond constant batch condition.
2175+
* @param true_br batch values for truthy condition.
2176+
* @param false_br batch value for falsy condition.
2177+
* @return the result of the selection.
2178+
*/
2179+
template <class T, class A, bool... Values>
2180+
XSIMD_INLINE batch_bool<T, A> select(batch_bool_constant<T, A, Values...> const& cond, batch_bool<T, A> const& true_br, batch_bool<T, A> const& false_br) noexcept
2181+
{
2182+
detail::static_check_supported_config<T, A>();
2183+
return kernel::select<A>(cond, true_br, false_br, A {});
2184+
}
2185+
21572186
/**
21582187
* @ingroup batch_data_transfer
21592188
*

test/test_batch_cast.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -354,22 +354,16 @@ struct batch_cast_test
354354
using B_common_out = xsimd::batch_bool<T_out>;
355355

356356
B_common_in all_true_in(true);
357-
B_common_out all_true_res0 = xsimd::batch_bool_cast<T_out>(all_true_in);
358-
B_common_out all_true_res1 = xsimd::batch_cast<T_out>(all_true_in);
357+
B_common_out all_true_res = xsimd::batch_bool_cast<T_out>(all_true_in);
359358
INFO(name);
360-
CHECK_SCALAR_EQ(all_true_res0.get(0), true);
361-
CHECK_SCALAR_EQ(all_true_res1.get(0), true);
359+
CHECK_SCALAR_EQ(all_true_res.get(0), true);
362360
CHECK_SCALAR_EQ(xsimd::batch_bool_cast<B_out>(true), true);
363-
CHECK_SCALAR_EQ(xsimd::batch_cast<B_out>(true), true);
364361

365362
B_common_in all_false_in(false);
366-
B_common_out all_false_res0 = xsimd::batch_bool_cast<T_out>(all_false_in);
367-
B_common_out all_false_res1 = xsimd::batch_cast<T_out>(all_false_in);
363+
B_common_out all_false_res = xsimd::batch_bool_cast<T_out>(all_false_in);
368364
INFO(name);
369-
CHECK_SCALAR_EQ(all_false_res0.get(0), false);
370-
CHECK_SCALAR_EQ(all_false_res1.get(0), false);
365+
CHECK_SCALAR_EQ(all_false_res.get(0), false);
371366
CHECK_SCALAR_EQ(xsimd::batch_bool_cast<B_out>(false), false);
372-
CHECK_SCALAR_EQ(xsimd::batch_cast<B_out>(false), false);
373367
}
374368
};
375369

test/test_select.cpp

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,26 @@ template <class B>
1818
struct select_test
1919
{
2020
using batch_type = B;
21+
using batch_bool_type = typename B::batch_bool_type;
2122
using value_type = typename B::value_type;
2223
using arch_type = typename B::arch_type;
2324
static constexpr size_t size = B::size;
24-
using vector_type = std::vector<value_type>;
25+
static constexpr size_t nb_input = size * 10000;
26+
using vector_type = std::array<value_type, nb_input>;
27+
using vector_bool_type = std::array<bool, nb_input>;
2528

26-
size_t nb_input;
2729
vector_type lhs_input;
2830
vector_type rhs_input;
2931
vector_type expected;
3032
vector_type res;
3133

34+
vector_bool_type lhs_input_b;
35+
vector_bool_type rhs_input_b;
36+
vector_bool_type expected_b;
37+
vector_bool_type res_b;
38+
3239
select_test()
3340
{
34-
nb_input = size * 10000;
35-
lhs_input.resize(nb_input);
36-
rhs_input.resize(nb_input);
3741
auto clamp = [](double v)
3842
{
3943
return static_cast<value_type>(std::min(v, static_cast<double>(std::numeric_limits<value_type>::max())));
@@ -42,28 +46,37 @@ struct select_test
4246
{
4347
lhs_input[i] = clamp(i / 4 + 1.2 * std::sqrt(i + 0.25));
4448
rhs_input[i] = clamp(10.2 / (i + 2) + 0.25);
49+
lhs_input_b[i] = (int)lhs_input[i] % 2;
50+
rhs_input_b[i] = (int)rhs_input[i] % 2;
4551
}
46-
expected.resize(nb_input);
47-
res.resize(nb_input);
4852
}
4953

5054
void test_select_dynamic()
5155
{
5256
for (size_t i = 0; i < nb_input; ++i)
5357
{
5458
expected[i] = lhs_input[i] > value_type(3) ? lhs_input[i] : rhs_input[i];
59+
expected_b[i] = lhs_input[i] > value_type(3) ? lhs_input_b[i] : rhs_input_b[i];
5560
}
5661

57-
batch_type lhs_in, rhs_in, out;
62+
batch_type lhs_in, rhs_in;
63+
batch_bool_type lhs_in_b, rhs_in_b;
5864
for (size_t i = 0; i < nb_input; i += size)
5965
{
6066
detail::load_batch(lhs_in, lhs_input, i);
6167
detail::load_batch(rhs_in, rhs_input, i);
62-
out = xsimd::select(lhs_in > value_type(3), lhs_in, rhs_in);
68+
const auto out = xsimd::select(lhs_in > value_type(3), lhs_in, rhs_in);
6369
detail::store_batch(out, res, i);
70+
71+
detail::load_batch(lhs_in_b, lhs_input_b, i);
72+
detail::load_batch(rhs_in_b, rhs_input_b, i);
73+
const auto out_b = xsimd::select(lhs_in > value_type(3), lhs_in_b, rhs_in_b);
74+
detail::store_batch(out_b, res_b, i);
6475
}
6576
size_t diff = detail::get_nb_diff(res, expected);
77+
size_t diff_b = detail::get_nb_diff(res_b, expected_b);
6678
CHECK_EQ(diff, 0);
79+
CHECK_EQ(diff_b, 0);
6780
}
6881
struct pattern
6982
{
@@ -77,18 +90,27 @@ struct select_test
7790
for (size_t i = 0; i < nb_input; ++i)
7891
{
7992
expected[i] = mask.get(i % size) ? lhs_input[i] : rhs_input[i];
93+
expected_b[i] = mask.get(i % size) ? lhs_input_b[i] : rhs_input_b[i];
8094
}
8195

82-
batch_type lhs_in, rhs_in, out;
96+
batch_type lhs_in, rhs_in;
97+
batch_bool_type lhs_in_b, rhs_in_b;
8398
for (size_t i = 0; i < nb_input; i += size)
8499
{
85100
detail::load_batch(lhs_in, lhs_input, i);
86101
detail::load_batch(rhs_in, rhs_input, i);
87-
out = xsimd::select(mask, lhs_in, rhs_in);
102+
const auto out = xsimd::select(mask, lhs_in, rhs_in);
88103
detail::store_batch(out, res, i);
104+
105+
detail::load_batch(lhs_in_b, lhs_input_b, i);
106+
detail::load_batch(rhs_in_b, rhs_input_b, i);
107+
const auto out_b = xsimd::select(mask, lhs_in_b, rhs_in_b);
108+
detail::store_batch(out_b, res_b, i);
89109
}
90110
size_t diff = detail::get_nb_diff(res, expected);
111+
size_t diff_b = detail::get_nb_diff(res_b, expected_b);
91112
CHECK_EQ(diff, 0);
113+
CHECK_EQ(diff_b, 0);
92114
}
93115
};
94116

0 commit comments

Comments
 (0)