Skip to content

Commit 5eb32ad

Browse files
committed
WIP
1 parent d190353 commit 5eb32ad

File tree

2 files changed

+300
-22
lines changed

2 files changed

+300
-22
lines changed

test/test_batch_bool.cpp

Lines changed: 289 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,94 @@
1212
#include "xsimd/xsimd.hpp"
1313
#ifndef XSIMD_NO_SUPPORTED_ARCHITECTURE
1414

15+
#include <array>
1516
#include <functional>
17+
#include <type_traits>
1618
#include <vector>
1719

1820
#include "test_utils.hpp"
1921

2022
namespace xsimd
2123
{
2224

25+
namespace test_detail
26+
{
27+
template <class T, std::size_t N>
28+
struct ct_mask_arch
29+
{
30+
static constexpr bool supported() noexcept { return true; }
31+
static constexpr bool available() noexcept { return true; }
32+
static constexpr std::size_t alignment() noexcept { return 0; }
33+
static constexpr bool requires_alignment() noexcept { return false; }
34+
static constexpr char const* name() noexcept { return "ct_mask_arch"; }
35+
};
36+
37+
template <class T, std::size_t N>
38+
struct ct_mask_register
39+
{
40+
std::array<T, N> data {};
41+
};
42+
43+
struct mask_all_false
44+
{
45+
static constexpr bool get(std::size_t, std::size_t) { return false; }
46+
};
47+
48+
struct mask_all_true
49+
{
50+
static constexpr bool get(std::size_t, std::size_t) { return true; }
51+
};
52+
53+
struct mask_prefix1
54+
{
55+
static constexpr bool get(std::size_t i, std::size_t) { return i < 1; }
56+
};
57+
58+
struct mask_suffix1
59+
{
60+
static constexpr bool get(std::size_t i, std::size_t n) { return i >= (n - 1); }
61+
};
62+
63+
struct mask_ends
64+
{
65+
static constexpr bool get(std::size_t i, std::size_t n)
66+
{
67+
return (i < 1) || (i >= (n - 1));
68+
}
69+
};
70+
71+
struct mask_interleaved
72+
{
73+
static constexpr bool get(std::size_t i, std::size_t) { return (i % 2) == 0; }
74+
};
75+
76+
template <class T>
77+
struct alternating_numeric
78+
{
79+
static constexpr T get(std::size_t i, std::size_t)
80+
{
81+
return (i % 2) ? T(2) : T(1);
82+
}
83+
};
84+
}
85+
86+
namespace types
87+
{
88+
template <class T, std::size_t N>
89+
struct simd_register<T, test_detail::ct_mask_arch<T, N>>
90+
{
91+
using register_type = test_detail::ct_mask_register<T, N>;
92+
register_type data;
93+
constexpr operator register_type() const noexcept { return data; }
94+
};
95+
96+
template <class T, std::size_t N>
97+
struct has_simd_register<T, test_detail::ct_mask_arch<T, N>> : std::true_type
98+
{
99+
};
100+
}
101+
102+
23103
int popcount(int v)
24104
{
25105
// from https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetKernighan
@@ -154,6 +234,26 @@ namespace xsimd
154234

155235
}
156236

237+
template <class ValueType, class Arch, std::size_t Size, bool Integral = std::is_integral<ValueType>::value>
238+
struct numeric_constant_helper
239+
{
240+
static void fill(std::array<ValueType, Size>&)
241+
{
242+
}
243+
};
244+
245+
template <class ValueType, class Arch, std::size_t Size>
246+
struct numeric_constant_helper<ValueType, Arch, Size, true>
247+
{
248+
static void fill(std::array<ValueType, Size>& out)
249+
{
250+
constexpr auto numeric_constant = xsimd::make_batch_constant<ValueType,
251+
xsimd::test_detail::alternating_numeric<ValueType>,
252+
Arch>();
253+
numeric_constant.as_batch().store_unaligned(out.data());
254+
}
255+
};
256+
157257
template <class T>
158258
struct batch_bool_test
159259
{
@@ -179,42 +279,108 @@ struct batch_bool_test
179279
: ((bits & 1u ? index : 0u) + sum_indices(bits >> 1, index + 1, remaining - 1));
180280
}
181281

182-
struct ct_false
282+
static constexpr uint32_t low_mask_bits(std::size_t width)
183283
{
184-
static constexpr bool get(std::size_t, std::size_t) { return false; }
185-
};
186-
struct ct_true
284+
return width == 0 ? 0u : (static_cast<uint32_t>(1u << width) - 1u);
285+
}
286+
287+
template <class Mask, class ValueType, bool Enable>
288+
struct splice_checker
187289
{
188-
static constexpr bool get(std::size_t, std::size_t) { return true; }
290+
static void run()
291+
{
292+
}
189293
};
190-
struct ct_prefix1
294+
295+
template <class Mask, class ValueType>
296+
struct splice_checker<Mask, ValueType, true>
191297
{
192-
static constexpr bool get(std::size_t i, std::size_t) { return i < 1; }
298+
static void run()
299+
{
300+
constexpr std::size_t begin = 1;
301+
constexpr std::size_t end = (Mask::size > 3 ? 3 : Mask::size);
302+
constexpr std::size_t length = (end > begin) ? (end - begin) : 0;
303+
using slice_arch = xsimd::test_detail::ct_mask_arch<ValueType, length>;
304+
constexpr auto slice = Mask::template splice<slice_arch, begin, end>();
305+
constexpr uint32_t src_mask = static_cast<uint32_t>(Mask::mask());
306+
constexpr uint32_t expected = (src_mask >> begin) & low_mask_bits(length);
307+
static_assert(static_cast<uint32_t>(slice.mask()) == expected, "splice mask expected");
308+
constexpr uint32_t slice_bits = static_cast<uint32_t>(slice.mask());
309+
constexpr uint32_t shifted_source = src_mask >> begin;
310+
static_assert((length == 0) || ((slice_bits & 1u) == (shifted_source & 1u)), "slice first bit matches");
311+
static_assert((length <= 1) || (((slice_bits >> (length - 1)) & 1u)
312+
== ((shifted_source >> (length - 1)) & 1u)),
313+
"slice last bit matches");
314+
}
193315
};
194-
struct ct_suffix1
316+
317+
template <class Mask, class ValueType, bool Enable>
318+
struct half_checker
195319
{
196-
static constexpr bool get(std::size_t i, std::size_t n) { return i >= (n - 1); }
320+
static void run()
321+
{
322+
}
197323
};
198-
struct ct_ends
324+
325+
template <class Mask, class ValueType>
326+
struct half_checker<Mask, ValueType, true>
199327
{
200-
static constexpr bool get(std::size_t i, std::size_t n) { return (i < 1) || (i >= (n - 1)); }
328+
static void run()
329+
{
330+
constexpr std::size_t total = Mask::size;
331+
constexpr std::size_t mid = total / 2;
332+
using lower_arch = xsimd::test_detail::ct_mask_arch<ValueType, mid>;
333+
using upper_arch = xsimd::test_detail::ct_mask_arch<ValueType, total - mid>;
334+
constexpr auto lower = Mask::template lower_half<lower_arch>();
335+
constexpr auto upper = Mask::template upper_half<upper_arch>();
336+
constexpr uint32_t source_mask = static_cast<uint32_t>(Mask::mask());
337+
static_assert(static_cast<uint32_t>(lower.mask()) == (source_mask & low_mask_bits(mid)),
338+
"lower_half mask matches");
339+
static_assert(static_cast<uint32_t>(upper.mask()) == ((source_mask >> mid) & low_mask_bits(total - mid)),
340+
"upper_half mask matches");
341+
constexpr auto lower_splice = Mask::template splice<lower_arch, 0, mid>();
342+
constexpr auto upper_splice = Mask::template splice<upper_arch, mid, total>();
343+
static_assert(lower.mask() == lower_splice.mask(), "lower_half equals splice");
344+
static_assert(upper.mask() == upper_splice.mask(), "upper_half equals splice");
345+
constexpr uint32_t lower_bits = static_cast<uint32_t>(lower.mask());
346+
constexpr uint32_t upper_bits = static_cast<uint32_t>(upper.mask());
347+
constexpr std::size_t upper_size = decltype(upper)::size;
348+
static_assert((mid == 0) || ((lower_bits & 1u) == (source_mask & 1u)), "lower first element");
349+
static_assert((mid <= 1) || (((lower_bits >> (mid - 1)) & 1u)
350+
== ((source_mask >> (mid - 1)) & 1u)),
351+
"lower last element");
352+
static_assert((upper_size == 0) || ((upper_bits & 1u) == ((source_mask >> mid) & 1u)),
353+
"upper first element");
354+
static_assert((upper_size <= 1) || (((upper_bits >> (upper_size - 1)) & 1u)
355+
== ((source_mask >> (total - 1)) & 1u)),
356+
"upper last element");
357+
}
201358
};
202359

203360
static void run()
204361
{
205362
using value_type = typename B::value_type;
206363
using arch_type = typename B::arch_type;
207-
constexpr auto m_zero = xsimd::make_batch_bool_constant<value_type, ct_false, arch_type>();
208-
constexpr auto m_one = xsimd::make_batch_bool_constant<value_type, ct_true, arch_type>();
209-
constexpr auto m_prefix = xsimd::make_batch_bool_constant<value_type, ct_prefix1, arch_type>();
210-
constexpr auto m_suffix = xsimd::make_batch_bool_constant<value_type, ct_suffix1, arch_type>();
211-
constexpr auto m_ends = xsimd::make_batch_bool_constant<value_type, ct_ends, arch_type>();
364+
constexpr auto m_zero = xsimd::make_batch_bool_constant<value_type, xsimd::test_detail::mask_all_false, arch_type>();
365+
constexpr auto m_one = xsimd::make_batch_bool_constant<value_type, xsimd::test_detail::mask_all_true, arch_type>();
366+
constexpr auto m_prefix = xsimd::make_batch_bool_constant<value_type, xsimd::test_detail::mask_prefix1, arch_type>();
367+
constexpr auto m_suffix = xsimd::make_batch_bool_constant<value_type, xsimd::test_detail::mask_suffix1, arch_type>();
368+
constexpr auto m_ends = xsimd::make_batch_bool_constant<value_type, xsimd::test_detail::mask_ends, arch_type>();
369+
constexpr auto m_interleaved = xsimd::make_batch_bool_constant<value_type, xsimd::test_detail::mask_interleaved, arch_type>();
212370

213371
static_assert((m_zero | m_one).mask() == m_one.mask(), "0|1 == 1");
214372
static_assert((m_zero & m_one).mask() == m_zero.mask(), "0&1 == 0");
215373
static_assert((m_zero ^ m_zero).mask() == m_zero.mask(), "0^0 == 0");
216374
static_assert((m_one ^ m_one).mask() == m_zero.mask(), "1^1 == 0");
217375

376+
static_assert((!m_zero).mask() == m_one.mask(), "!0 == 1");
377+
static_assert((~m_zero).mask() == m_one.mask(), "~0 == 1");
378+
static_assert((!m_one).mask() == m_zero.mask(), "!1 == 0");
379+
static_assert((~m_one).mask() == m_zero.mask(), "~1 == 0");
380+
381+
static_assert(((m_prefix && m_suffix).mask()) == (m_prefix & m_suffix).mask(), "&& consistent");
382+
static_assert(((m_prefix || m_suffix).mask()) == (m_prefix | m_suffix).mask(), "|| consistent");
383+
218384
static_assert((m_prefix | m_suffix).mask() == m_ends.mask(), "prefix|suffix == ends");
219385
static_assert(B::size == 1 || (m_prefix & m_suffix).mask() == m_zero.mask(), "prefix&suffix == 0 when size>1");
220386

@@ -223,28 +389,54 @@ struct batch_bool_test
223389
static_assert(!m_zero.all(), "zero mask all");
224390
static_assert(m_zero.countr_zero() == B::size, "zero mask trailing zeros");
225391
static_assert(m_zero.countl_zero() == B::size, "zero mask leading zeros");
392+
static_assert(m_zero.is_empty(), "zero mask empty");
393+
static_assert(!m_zero.is_full(), "zero mask not full");
226394

227395
static_assert(m_one.all(), "all mask all");
228396
static_assert(m_one.any(), "all mask any");
229397
static_assert(!m_one.none(), "all mask none");
230398
static_assert(m_one.countr_zero() == 0, "all mask trailing zeros");
231399
static_assert(m_one.countl_zero() == 0, "all mask leading zeros");
400+
static_assert(!m_one.is_empty(), "all mask not empty");
401+
static_assert(m_one.is_full(), "all mask full");
402+
403+
constexpr uint32_t prefix_bits = static_cast<uint32_t>(m_prefix.mask());
404+
constexpr uint32_t suffix_bits = static_cast<uint32_t>(m_suffix.mask());
405+
constexpr uint32_t ends_bits_mask = static_cast<uint32_t>(m_ends.mask());
406+
407+
static_assert((B::size == 0) || ((prefix_bits & 1u) != 0u), "prefix first element set");
408+
static_assert((B::size <= 1) || ((prefix_bits & (1u << 1)) == 0u), "prefix second element cleared");
409+
410+
static_assert((B::size == 0) || (((suffix_bits >> (B::size - 1)) & 1u) != 0u), "suffix last element set");
411+
static_assert((B::size <= 1) || ((suffix_bits & 1u) == 0u), "suffix first element cleared");
412+
413+
static_assert((B::size == 0) || ((ends_bits_mask & 1u) != 0u), "ends first element set");
414+
static_assert((B::size == 0) || (((ends_bits_mask >> (B::size - 1)) & 1u) != 0u), "ends last element set");
415+
static_assert((B::size <= 2) || (((ends_bits_mask >> 1) & 1u) == 0u), "ends interior element cleared");
416+
417+
static_assert(std::is_same<decltype(m_prefix.as_batch_bool()), typename B::batch_bool_type>::value,
418+
"as_batch_bool type");
419+
static_assert(std::is_same<decltype(static_cast<typename B::batch_bool_type>(m_prefix)), typename B::batch_bool_type>::value,
420+
"conversion operator type");
232421

233422
// splice API is validated indirectly via arch-specific masked implementations.
234423

235424
constexpr std::size_t prefix_zero = m_prefix.countr_zero();
236425
constexpr std::size_t prefix_one = m_prefix.countr_one();
237426
static_assert(prefix_zero == 0, "prefix mask zero leading zeros from LSB");
238-
static_assert(prefix_one == 1 || B::size == 0, "prefix mask trailing ones count");
427+
static_assert((B::size == 0 ? prefix_one == 0 : prefix_one == 1), "prefix mask trailing ones count");
239428

240429
constexpr std::size_t suffix_zero = m_suffix.countl_zero();
241430
constexpr std::size_t suffix_one = m_suffix.countl_one();
242-
static_assert(suffix_zero == 0 || B::size == 0, "suffix mask leading zeros count");
243-
static_assert(suffix_one == 1 || B::size == 0, "suffix mask trailing ones count");
431+
static_assert(suffix_zero == 0, "suffix mask leading zeros count");
432+
static_assert((B::size == 0 ? suffix_one == 0 : suffix_one == 1), "suffix mask trailing ones count");
433+
434+
splice_checker<decltype(m_interleaved), value_type, (B::size > 1)>::run();
435+
half_checker<decltype(m_ends), value_type, (B::size > 0 && (B::size % 2 == 0))>::run();
244436

245437
constexpr std::size_t ends_bits = m_ends.truncated_mask();
246438
constexpr std::size_t ends_sum = sum_indices(ends_bits, 0u, B::size);
247-
static_assert((B::size <= 1 && ends_sum == 0) || (B::size > 1 && ends_sum == (B::size - 1)), "ends index sum coverage");
439+
static_assert((B::size <= 1) ? (ends_sum == 0) : (ends_sum == (B::size - 1)), "ends index sum coverage");
248440
}
249441
};
250442

@@ -621,6 +813,7 @@ struct batch_bool_test
621813
{
622814
auto bool_g = xsimd::get_bool<typename T::batch_bool_type> {};
623815
using batch_t = typename T::batch_bool_type;
816+
using arch_type = typename batch_type::arch_type;
624817

625818
auto check_mask = [&](batch_t const& m, const char*)
626819
{
@@ -706,6 +899,82 @@ struct batch_bool_test
706899
check_mask(bool_g.half, "half");
707900
check_mask(bool_g.ihalf, "ihalf");
708901
check_mask(bool_g.interspersed, "interspersed");
902+
903+
{
904+
constexpr auto zero_constant = xsimd::make_batch_bool_constant<value_type, xsimd::test_detail::mask_all_false, arch_type>();
905+
constexpr auto full_constant = xsimd::make_batch_bool_constant<value_type, xsimd::test_detail::mask_all_true, arch_type>();
906+
constexpr auto prefix_constant = xsimd::make_batch_bool_constant<value_type, xsimd::test_detail::mask_prefix1, arch_type>();
907+
constexpr auto suffix_constant = xsimd::make_batch_bool_constant<value_type, xsimd::test_detail::mask_suffix1, arch_type>();
908+
constexpr auto interleaved_constant = xsimd::make_batch_bool_constant<value_type, xsimd::test_detail::mask_interleaved, arch_type>();
909+
910+
CHECK(zero_constant.is_empty());
911+
CHECK_FALSE(zero_constant.is_full());
912+
CHECK(full_constant.is_full());
913+
CHECK_FALSE(full_constant.is_empty());
914+
915+
for (std::size_t i = 0; i < size; ++i)
916+
{
917+
CHECK_EQ(zero_constant.get(i), false);
918+
CHECK_EQ(full_constant.get(i), true);
919+
CHECK_EQ(prefix_constant.get(i), i == 0);
920+
CHECK_EQ(suffix_constant.get(i), (i + 1) == size);
921+
CHECK_EQ(interleaved_constant.get(i), (i % 2) == 0);
922+
}
923+
}
924+
925+
{
926+
constexpr auto ends_constant = xsimd::make_batch_bool_constant<value_type, xsimd::test_detail::mask_ends, arch_type>();
927+
bool_array_type converted {};
928+
ends_constant.as_batch_bool().store_unaligned(converted.data());
929+
for (std::size_t i = 0; i < size; ++i)
930+
{
931+
CHECK_EQ(converted[i], ends_constant.get(i));
932+
}
933+
934+
batch_bool_type via_conversion = ends_constant;
935+
via_conversion.store_unaligned(converted.data());
936+
for (std::size_t i = 0; i < size; ++i)
937+
{
938+
CHECK_EQ(converted[i], ends_constant.get(i));
939+
}
940+
}
941+
942+
{
943+
constexpr auto interleaved_constant = xsimd::make_batch_bool_constant<value_type, xsimd::test_detail::mask_interleaved, arch_type>();
944+
constexpr std::size_t begin = 1;
945+
constexpr std::size_t end = (size > 3 ? 3 : size);
946+
constexpr std::size_t length = (end > begin) ? (end - begin) : 0;
947+
using slice_arch = xsimd::test_detail::ct_mask_arch<value_type, length>;
948+
constexpr auto slice = interleaved_constant.template splice<slice_arch, begin, end>();
949+
for (std::size_t i = 0; i < length; ++i)
950+
{
951+
CHECK_EQ(slice.get(i), interleaved_constant.get(begin + i));
952+
}
953+
954+
constexpr std::size_t mid = size / 2;
955+
using lower_arch = xsimd::test_detail::ct_mask_arch<value_type, mid>;
956+
using upper_arch = xsimd::test_detail::ct_mask_arch<value_type, size - mid>;
957+
constexpr auto lower = interleaved_constant.template lower_half<lower_arch>();
958+
constexpr auto upper = interleaved_constant.template upper_half<upper_arch>();
959+
for (std::size_t i = 0; i < mid; ++i)
960+
{
961+
CHECK_EQ(lower.get(i), interleaved_constant.get(i));
962+
}
963+
for (std::size_t i = 0; i < (size - mid); ++i)
964+
{
965+
CHECK_EQ(upper.get(i), interleaved_constant.get(mid + i));
966+
}
967+
}
968+
969+
array_type numeric_values {};
970+
numeric_constant_helper<value_type, arch_type, size>::fill(numeric_values);
971+
if (std::is_integral<value_type>::value)
972+
{
973+
for (std::size_t i = 0; i < size; ++i)
974+
{
975+
CHECK_EQ(numeric_values[i], static_cast<value_type>(xsimd::test_detail::alternating_numeric<value_type>::get(i, size)));
976+
}
977+
}
709978
}
710979

711980
private:

0 commit comments

Comments
 (0)