Skip to content

Commit cadcd8f

Browse files
AntoinePrvserge-sans-paille
authored andcommitted
Add batch_constant comparison and shift operators
1 parent 7f3e01c commit cadcd8f

File tree

3 files changed

+110
-7
lines changed

3 files changed

+110
-7
lines changed

include/xsimd/types/xsimd_batch.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ namespace xsimd
296296
static constexpr std::size_t size = sizeof(types::simd_register<T, A>) / sizeof(T); ///< Number of scalar elements in this batch.
297297

298298
using value_type = bool; ///< Type of the scalar elements within this batch.
299+
using operand_type = T;
299300
using arch_type = A; ///< SIMD Architecture abstracted by this batch.
300301
using register_type = typename base_type::register_type; ///< SIMD register type abstracted by this batch.
301302
using batch_type = batch<T, A>; ///< Associated batch type this batch represents logical operations for.

include/xsimd/types/xsimd_batch_constant.hpp

Lines changed: 77 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#ifndef XSIMD_BATCH_CONSTANT_HPP
1313
#define XSIMD_BATCH_CONSTANT_HPP
1414

15+
#include <cstddef>
16+
1517
#include "./xsimd_batch.hpp"
1618
#include "./xsimd_utils.hpp"
1719

@@ -31,6 +33,7 @@ namespace xsimd
3133
using batch_type = batch_bool<T, A>;
3234
static constexpr std::size_t size = sizeof...(Values);
3335
using value_type = bool;
36+
using operand_type = T;
3437
static_assert(sizeof...(Values) == batch_type::size, "consistent batch size");
3538

3639
public:
@@ -44,7 +47,7 @@ namespace xsimd
4447
*/
4548
constexpr operator batch_type() const noexcept { return as_batch_bool(); }
4649

47-
constexpr bool get(size_t i) const noexcept
50+
constexpr bool get(std::size_t i) const noexcept
4851
{
4952
return std::array<value_type, size> { { Values... } }[i];
5053
}
@@ -76,7 +79,7 @@ namespace xsimd
7679
constexpr bool operator()(bool x, bool y) const { return x ^ y; }
7780
};
7881

79-
template <class F, class SelfPack, class OtherPack, size_t... Indices>
82+
template <class F, class SelfPack, class OtherPack, std::size_t... Indices>
8083
static constexpr batch_bool_constant<T, A, F()(std::tuple_element<Indices, SelfPack>::type::value, std::tuple_element<Indices, OtherPack>::type::value)...>
8184
apply(detail::index_sequence<Indices...>)
8285
{
@@ -88,7 +91,7 @@ namespace xsimd
8891
-> decltype(apply<F, std::tuple<std::integral_constant<bool, Values>...>, std::tuple<std::integral_constant<bool, OtherValues>...>>(detail::make_index_sequence<sizeof...(Values)>()))
8992
{
9093
static_assert(sizeof...(Values) == sizeof...(OtherValues), "compatible constant batches");
91-
return apply<F, std::tuple<std::integral_constant<bool, Values>...>, std::tuple<std::integral_constant<bool, OtherValues>...>>(detail::make_index_sequence<sizeof...(Values)>());
94+
return {};
9295
}
9396

9497
public:
@@ -148,13 +151,13 @@ namespace xsimd
148151
/**
149152
* @brief Get the @p i th element of this @p batch_constant
150153
*/
151-
constexpr T get(size_t i) const noexcept
154+
constexpr T get(std::size_t i) const noexcept
152155
{
153156
return get(i, std::array<T, size> { Values... });
154157
}
155158

156159
private:
157-
constexpr T get(size_t i, std::array<T, size> const& values) const noexcept
160+
constexpr T get(std::size_t i, std::array<T, size> const& values) const noexcept
158161
{
159162
return values[i];
160163
}
@@ -191,8 +194,16 @@ namespace xsimd
191194
{
192195
constexpr T operator()(T x, T y) const { return x ^ y; }
193196
};
197+
struct binary_rshift
198+
{
199+
constexpr T operator()(T x, T y) const { return x >> y; }
200+
};
201+
struct binary_lshift
202+
{
203+
constexpr T operator()(T x, T y) const { return x << y; }
204+
};
194205

195-
template <class F, class SelfPack, class OtherPack, size_t... Indices>
206+
template <class F, class SelfPack, class OtherPack, std::size_t... Indices>
196207
static constexpr batch_constant<T, A, F()(std::tuple_element<Indices, SelfPack>::type::value, std::tuple_element<Indices, OtherPack>::type::value)...>
197208
apply(detail::index_sequence<Indices...>)
198209
{
@@ -204,7 +215,7 @@ namespace xsimd
204215
-> decltype(apply<F, std::tuple<std::integral_constant<T, Values>...>, std::tuple<std::integral_constant<T, OtherValues>...>>(detail::make_index_sequence<sizeof...(Values)>()))
205216
{
206217
static_assert(sizeof...(Values) == sizeof...(OtherValues), "compatible constant batches");
207-
return apply<F, std::tuple<std::integral_constant<T, Values>...>, std::tuple<std::integral_constant<T, OtherValues>...>>(detail::make_index_sequence<sizeof...(Values)>());
218+
return {};
208219
}
209220

210221
public:
@@ -224,9 +235,68 @@ namespace xsimd
224235
MAKE_BINARY_OP(&, binary_and)
225236
MAKE_BINARY_OP(|, binary_or)
226237
MAKE_BINARY_OP(^, binary_xor)
238+
MAKE_BINARY_OP(<<, binary_lshift)
239+
MAKE_BINARY_OP(>>, binary_rshift)
227240

228241
#undef MAKE_BINARY_OP
229242

243+
struct boolean_eq
244+
{
245+
constexpr bool operator()(T x, T y) const { return x == y; }
246+
};
247+
struct boolean_ne
248+
{
249+
constexpr bool operator()(T x, T y) const { return x != y; }
250+
};
251+
struct boolean_gt
252+
{
253+
constexpr bool operator()(T x, T y) const { return x > y; }
254+
};
255+
struct boolean_ge
256+
{
257+
constexpr bool operator()(T x, T y) const { return x >= y; }
258+
};
259+
struct boolean_lt
260+
{
261+
constexpr bool operator()(T x, T y) const { return x < y; }
262+
};
263+
struct boolean_le
264+
{
265+
constexpr bool operator()(T x, T y) const { return x <= y; }
266+
};
267+
268+
template <class F, class SelfPack, class OtherPack, std::size_t... Indices>
269+
static constexpr batch_bool_constant<T, A, F()(std::tuple_element<Indices, SelfPack>::type::value, std::tuple_element<Indices, OtherPack>::type::value)...>
270+
apply_bool(detail::index_sequence<Indices...>)
271+
{
272+
return {};
273+
}
274+
275+
template <class F, T... OtherValues>
276+
static constexpr auto apply_bool(batch_constant<T, A, Values...>, batch_constant<T, A, OtherValues...>)
277+
-> decltype(apply_bool<F, std::tuple<std::integral_constant<T, Values>...>, std::tuple<std::integral_constant<T, OtherValues>...>>(detail::make_index_sequence<sizeof...(Values)>()))
278+
{
279+
static_assert(sizeof...(Values) == sizeof...(OtherValues), "compatible constant batches");
280+
return {};
281+
}
282+
283+
#define MAKE_BINARY_BOOL_OP(OP, NAME) \
284+
template <T... OtherValues> \
285+
constexpr auto operator OP(batch_constant<T, A, OtherValues...> other) const \
286+
-> decltype(apply_bool<NAME>(*this, other)) \
287+
{ \
288+
return {}; \
289+
}
290+
291+
MAKE_BINARY_BOOL_OP(==, boolean_eq)
292+
MAKE_BINARY_BOOL_OP(!=, boolean_ne)
293+
MAKE_BINARY_BOOL_OP(<, boolean_lt)
294+
MAKE_BINARY_BOOL_OP(<=, boolean_le)
295+
MAKE_BINARY_BOOL_OP(>, boolean_gt)
296+
MAKE_BINARY_BOOL_OP(>=, boolean_ge)
297+
298+
#undef MAKE_BINARY_BOOL_OP
299+
230300
constexpr batch_constant<T, A, (T)-Values...> operator-() const
231301
{
232302
return {};

test/test_batch_constant.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,14 @@ struct constant_batch_test
136136
constexpr auto n12_lxor_n3 = n12 ^ n3;
137137
static_assert(std::is_same<decltype(n12_lxor_n3), decltype(n15)>::value, "n12 ^ n3 == n15");
138138

139+
constexpr auto n96 = xsimd::make_batch_constant<value_type, constant<96>, arch_type>();
140+
constexpr auto n12_lshift_n3 = n12 << n3;
141+
static_assert(std::is_same<decltype(n12_lshift_n3), decltype(n96)>::value, "n12 << n3 == n96");
142+
143+
constexpr auto n1 = xsimd::make_batch_constant<value_type, constant<1>, arch_type>();
144+
constexpr auto n12_rshift_n3 = n12 >> n3;
145+
static_assert(std::is_same<decltype(n12_rshift_n3), decltype(n1)>::value, "n12 >> n3 == n1");
146+
139147
constexpr auto n12_uadd = +n12;
140148
static_assert(std::is_same<decltype(n12_uadd), decltype(n12)>::value, "+n12 == n12");
141149

@@ -146,6 +154,30 @@ struct constant_batch_test
146154
constexpr auto n12_usub = -n12;
147155
constexpr auto n12_usub_ = xsimd::make_batch_constant<value_type, constant<(value_type)-12>, arch_type>();
148156
static_assert(std::is_same<decltype(n12_usub), decltype(n12_usub_)>::value, "-n12 == n12_usub");
157+
158+
// comparison operators
159+
using true_batch_type = decltype(xsimd::make_batch_bool_constant<value_type, true, arch_type>());
160+
using false_batch_type = decltype(xsimd::make_batch_bool_constant<value_type, false, arch_type>());
161+
162+
static_assert(std::is_same<typename decltype(n12 == n12)::operand_type, typename decltype(n12)::value_type>::value, "same type");
163+
164+
static_assert(std::is_same<decltype(n12 == n12), true_batch_type>::value, "n12 == n12");
165+
static_assert(std::is_same<decltype(n12 == n3), false_batch_type>::value, "n12 == n3");
166+
167+
static_assert(std::is_same<decltype(n12 != n12), false_batch_type>::value, "n12 != n12");
168+
static_assert(std::is_same<decltype(n12 != n3), true_batch_type>::value, "n12 != n3");
169+
170+
static_assert(std::is_same<decltype(n12 < n12), false_batch_type>::value, "n12 < n12");
171+
static_assert(std::is_same<decltype(n12 < n3), false_batch_type>::value, "n12 < n3");
172+
173+
static_assert(std::is_same<decltype(n12 > n12), false_batch_type>::value, "n12 > n12");
174+
static_assert(std::is_same<decltype(n12 > n3), true_batch_type>::value, "n12 > n3");
175+
176+
static_assert(std::is_same<decltype(n12 <= n12), true_batch_type>::value, "n12 <= n12");
177+
static_assert(std::is_same<decltype(n12 <= n3), false_batch_type>::value, "n12 <= n3");
178+
179+
static_assert(std::is_same<decltype(n12 >= n12), true_batch_type>::value, "n12 >= n12");
180+
static_assert(std::is_same<decltype(n12 >= n3), true_batch_type>::value, "n12 >= n3");
149181
}
150182
};
151183

0 commit comments

Comments
 (0)