Skip to content

Commit 4a5ee2e

Browse files
Support basic operations on batch constant
Add support for &&, ||, ^ and ! for batch_bool_constant. Add support for +, -, *, / and unary - for batch_constant. Fix #930
1 parent 0f51f66 commit 4a5ee2e

File tree

2 files changed

+247
-3
lines changed

2 files changed

+247
-3
lines changed

include/xsimd/types/xsimd_batch_constant.hpp

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ namespace xsimd
2828
template <class batch_type, bool... Values>
2929
struct batch_bool_constant
3030
{
31+
32+
public:
3133
static constexpr std::size_t size = sizeof...(Values);
3234
using arch_type = typename batch_type::arch_type;
3335
using value_type = bool;
@@ -47,11 +49,67 @@ namespace xsimd
4749

4850
private:
4951
static constexpr int mask_helper(int acc) noexcept { return acc; }
52+
5053
template <class... Tys>
5154
static constexpr int mask_helper(int acc, int mask, Tys... masks) noexcept
5255
{
5356
return mask_helper(acc | mask, (masks << 1)...);
5457
}
58+
59+
struct logical_or
60+
{
61+
constexpr bool operator()(bool x, bool y) const { return x || y; }
62+
};
63+
struct logical_and
64+
{
65+
constexpr bool operator()(bool x, bool y) const { return x && y; }
66+
};
67+
struct logical_xor
68+
{
69+
constexpr bool operator()(bool x, bool y) const { return x ^ y; }
70+
};
71+
72+
template <class F, class SelfPack, class OtherPack, size_t... Indices>
73+
static constexpr batch_bool_constant<batch_type, F()(std::tuple_element<Indices, SelfPack>::type::value, std::tuple_element<Indices, OtherPack>::type::value)...>
74+
apply(detail::index_sequence<Indices...>)
75+
{
76+
return {};
77+
}
78+
79+
template <class F, bool... OtherValues>
80+
static constexpr auto apply(batch_bool_constant<batch_type, Values...>, batch_bool_constant<batch_type, OtherValues...>)
81+
-> decltype(apply<F, std::tuple<std::integral_constant<bool, Values>...>, std::tuple<std::integral_constant<bool, OtherValues>...>>(detail::make_index_sequence<sizeof...(Values)>()))
82+
{
83+
static_assert(sizeof...(Values) == sizeof...(OtherValues), "compatible constant batches");
84+
return apply<F, std::tuple<std::integral_constant<bool, Values>...>, std::tuple<std::integral_constant<bool, OtherValues>...>>(detail::make_index_sequence<sizeof...(Values)>());
85+
}
86+
87+
public:
88+
#define MAKE_BINARY_OP(OP, NAME) \
89+
template <bool... OtherValues> \
90+
constexpr auto operator OP(batch_bool_constant<batch_type, OtherValues...> other) const \
91+
->decltype(apply<NAME>(*this, other)) \
92+
{ \
93+
return apply<NAME>(*this, other); \
94+
}
95+
96+
MAKE_BINARY_OP(|, logical_or)
97+
MAKE_BINARY_OP(||, logical_or)
98+
MAKE_BINARY_OP(&, logical_and)
99+
MAKE_BINARY_OP(&&, logical_and)
100+
MAKE_BINARY_OP(^, logical_xor)
101+
102+
#undef MAKE_BINARY_OP
103+
104+
constexpr batch_bool_constant<batch_type, !Values...> operator!() const
105+
{
106+
return {};
107+
}
108+
109+
constexpr batch_bool_constant<batch_type, !Values...> operator~() const
110+
{
111+
return {};
112+
}
55113
};
56114

57115
/**
@@ -88,6 +146,89 @@ namespace xsimd
88146
{
89147
return values[i];
90148
}
149+
150+
struct arithmetic_add
151+
{
152+
constexpr value_type operator()(value_type x, value_type y) const { return x + y; }
153+
};
154+
struct arithmetic_sub
155+
{
156+
constexpr value_type operator()(value_type x, value_type y) const { return x - y; }
157+
};
158+
struct arithmetic_mul
159+
{
160+
constexpr value_type operator()(value_type x, value_type y) const { return x * y; }
161+
};
162+
struct arithmetic_div
163+
{
164+
constexpr value_type operator()(value_type x, value_type y) const { return x / y; }
165+
};
166+
struct arithmetic_mod
167+
{
168+
constexpr value_type operator()(value_type x, value_type y) const { return x % y; }
169+
};
170+
struct binary_and
171+
{
172+
constexpr value_type operator()(value_type x, value_type y) const { return x & y; }
173+
};
174+
struct binary_or
175+
{
176+
constexpr value_type operator()(value_type x, value_type y) const { return x | y; }
177+
};
178+
struct binary_xor
179+
{
180+
constexpr value_type operator()(value_type x, value_type y) const { return x ^ y; }
181+
};
182+
183+
template <class F, class SelfPack, class OtherPack, size_t... Indices>
184+
static constexpr batch_constant<batch_type, F()(std::tuple_element<Indices, SelfPack>::type::value, std::tuple_element<Indices, OtherPack>::type::value)...>
185+
apply(detail::index_sequence<Indices...>)
186+
{
187+
return {};
188+
}
189+
190+
template <class F, value_type... OtherValues>
191+
static constexpr auto apply(batch_constant<batch_type, Values...>, batch_constant<batch_type, OtherValues...>)
192+
-> decltype(apply<F, std::tuple<std::integral_constant<value_type, Values>...>, std::tuple<std::integral_constant<value_type, OtherValues>...>>(detail::make_index_sequence<sizeof...(Values)>()))
193+
{
194+
static_assert(sizeof...(Values) == sizeof...(OtherValues), "compatible constant batches");
195+
return apply<F, std::tuple<std::integral_constant<value_type, Values>...>, std::tuple<std::integral_constant<value_type, OtherValues>...>>(detail::make_index_sequence<sizeof...(Values)>());
196+
}
197+
198+
public:
199+
#define MAKE_BINARY_OP(OP, NAME) \
200+
template <value_type... OtherValues> \
201+
constexpr auto operator OP(batch_constant<batch_type, OtherValues...> other) const \
202+
->decltype(apply<NAME>(*this, other)) \
203+
{ \
204+
return apply<NAME>(*this, other); \
205+
}
206+
207+
MAKE_BINARY_OP(+, arithmetic_add)
208+
MAKE_BINARY_OP(-, arithmetic_sub)
209+
MAKE_BINARY_OP(*, arithmetic_mul)
210+
MAKE_BINARY_OP(/, arithmetic_div)
211+
MAKE_BINARY_OP(%, arithmetic_mod)
212+
MAKE_BINARY_OP(&, binary_and)
213+
MAKE_BINARY_OP(|, binary_or)
214+
MAKE_BINARY_OP(^, binary_xor)
215+
216+
#undef MAKE_BINARY_OP
217+
218+
constexpr batch_constant<batch_type, (value_type)-Values...> operator-() const
219+
{
220+
return {};
221+
}
222+
223+
constexpr batch_constant<batch_type, (value_type) + Values...> operator+() const
224+
{
225+
return {};
226+
}
227+
228+
constexpr batch_constant<batch_type, (value_type)~Values...> operator~() const
229+
{
230+
return {};
231+
}
91232
};
92233

93234
namespace detail

test/test_batch_constant.cpp

Lines changed: 106 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,22 +64,69 @@ struct constant_batch_test
6464
CHECK_BATCH_EQ((batch_type)b, expected);
6565
}
6666

67+
template <value_type V>
6768
struct constant
6869
{
6970
static constexpr value_type get(size_t /*index*/, size_t /*size*/)
7071
{
71-
return 3;
72+
return V;
7273
}
7374
};
7475

7576
void test_init_from_constant() const
7677
{
7778
array_type expected;
78-
std::fill(expected.begin(), expected.end(), constant::get(0, 0));
79-
constexpr auto b = xsimd::make_batch_constant<batch_type, constant>();
79+
std::fill(expected.begin(), expected.end(), constant<3>::get(0, 0));
80+
constexpr auto b = xsimd::make_batch_constant<batch_type, constant<3>>();
8081
INFO("batch(value_type)");
8182
CHECK_BATCH_EQ((batch_type)b, expected);
8283
}
84+
85+
void test_ops() const
86+
{
87+
constexpr auto n12 = xsimd::make_batch_constant<batch_type, constant<12>>();
88+
constexpr auto n3 = xsimd::make_batch_constant<batch_type, constant<3>>();
89+
90+
constexpr auto n12_add_n3 = n12 + n3;
91+
constexpr auto n15 = xsimd::make_batch_constant<batch_type, constant<15>>();
92+
static_assert(std::is_same<decltype(n12_add_n3), decltype(n15)>::value, "n12 + n3 == n15");
93+
94+
constexpr auto n12_sub_n3 = n12 - n3;
95+
constexpr auto n9 = xsimd::make_batch_constant<batch_type, constant<9>>();
96+
static_assert(std::is_same<decltype(n12_sub_n3), decltype(n9)>::value, "n12 - n3 == n9");
97+
98+
constexpr auto n12_mul_n3 = n12 * n3;
99+
constexpr auto n36 = xsimd::make_batch_constant<batch_type, constant<36>>();
100+
static_assert(std::is_same<decltype(n12_mul_n3), decltype(n36)>::value, "n12 * n3 == n36");
101+
102+
constexpr auto n12_div_n3 = n12 / n3;
103+
constexpr auto n4 = xsimd::make_batch_constant<batch_type, constant<4>>();
104+
static_assert(std::is_same<decltype(n12_div_n3), decltype(n4)>::value, "n12 / n3 == n4");
105+
106+
constexpr auto n12_mod_n3 = n12 % n3;
107+
constexpr auto n0 = xsimd::make_batch_constant<batch_type, constant<0>>();
108+
static_assert(std::is_same<decltype(n12_mod_n3), decltype(n0)>::value, "n12 % n3 == n0");
109+
110+
constexpr auto n12_land_n3 = n12 & n3;
111+
static_assert(std::is_same<decltype(n12_land_n3), decltype(n0)>::value, "n12 & n3 == n0");
112+
113+
constexpr auto n12_lor_n3 = n12 | n3;
114+
static_assert(std::is_same<decltype(n12_lor_n3), decltype(n15)>::value, "n12 | n3 == n15");
115+
116+
constexpr auto n12_lxor_n3 = n12 ^ n3;
117+
static_assert(std::is_same<decltype(n12_lxor_n3), decltype(n15)>::value, "n12 ^ n3 == n15");
118+
119+
constexpr auto n12_uadd = +n12;
120+
static_assert(std::is_same<decltype(n12_uadd), decltype(n12)>::value, "+n12 == n12");
121+
122+
constexpr auto n12_inv = ~n12;
123+
constexpr auto n12_inv_ = xsimd::make_batch_constant<batch_type, constant<(value_type)~12>>();
124+
static_assert(std::is_same<decltype(n12_inv), decltype(n12_inv_)>::value, "~n12 == n12_inv");
125+
126+
constexpr auto n12_usub = -n12;
127+
constexpr auto n12_usub_ = xsimd::make_batch_constant<batch_type, constant<(value_type)-12>>();
128+
static_assert(std::is_same<decltype(n12_inv), decltype(n12_inv_)>::value, "-n12 == n12_usub");
129+
}
83130
};
84131

85132
TEST_CASE_TEMPLATE("[constant batch]", B, BATCH_INT_TYPES)
@@ -93,6 +140,11 @@ TEST_CASE_TEMPLATE("[constant batch]", B, BATCH_INT_TYPES)
93140
}
94141

95142
SUBCASE("init_from_constant") { Test.test_init_from_constant(); }
143+
144+
SUBCASE("operators")
145+
{
146+
Test.test_ops();
147+
}
96148
}
97149

98150
template <class B>
@@ -144,6 +196,53 @@ struct constant_bool_batch_test
144196
INFO("batch_bool_constant(value_type)");
145197
CHECK_BATCH_EQ((batch_bool_type)b, expected);
146198
}
199+
200+
struct inv_split
201+
{
202+
static constexpr bool get(size_t index, size_t size)
203+
{
204+
return !split().get(index, size);
205+
}
206+
};
207+
208+
template <bool Val>
209+
struct constant
210+
{
211+
static constexpr bool get(size_t /*index*/, size_t /*size*/)
212+
{
213+
return Val;
214+
}
215+
};
216+
217+
void test_ops() const
218+
{
219+
constexpr auto all_true = xsimd::make_batch_bool_constant<batch_type, constant<true>>();
220+
constexpr auto all_false = xsimd::make_batch_bool_constant<batch_type, constant<false>>();
221+
222+
constexpr auto x = xsimd::make_batch_bool_constant<batch_type, split>();
223+
constexpr auto y = xsimd::make_batch_bool_constant<batch_type, inv_split>();
224+
225+
constexpr auto x_or_y = x | y;
226+
static_assert(std::is_same<decltype(x_or_y), decltype(all_true)>::value, "x | y == true");
227+
228+
constexpr auto x_lor_y = x || y;
229+
static_assert(std::is_same<decltype(x_lor_y), decltype(all_true)>::value, "x || y == true");
230+
231+
constexpr auto x_and_y = x & y;
232+
static_assert(std::is_same<decltype(x_and_y), decltype(all_false)>::value, "x & y == false");
233+
234+
constexpr auto x_land_y = x && y;
235+
static_assert(std::is_same<decltype(x_land_y), decltype(all_false)>::value, "x && y == false");
236+
237+
constexpr auto x_xor_y = x ^ y;
238+
static_assert(std::is_same<decltype(x_xor_y), decltype(all_true)>::value, "x ^ y == true");
239+
240+
constexpr auto not_x = !x;
241+
static_assert(std::is_same<decltype(not_x), decltype(y)>::value, "!x == y");
242+
243+
constexpr auto inv_x = ~x;
244+
static_assert(std::is_same<decltype(inv_x), decltype(y)>::value, "~x == y");
245+
}
147246
};
148247

149248
TEST_CASE_TEMPLATE("[constant bool batch]", B, BATCH_INT_TYPES)
@@ -155,5 +254,9 @@ TEST_CASE_TEMPLATE("[constant bool batch]", B, BATCH_INT_TYPES)
155254
{
156255
Test.test_init_from_generator_split();
157256
}
257+
SUBCASE("operators")
258+
{
259+
Test.test_ops();
260+
}
158261
}
159262
#endif

0 commit comments

Comments
 (0)