Skip to content

Commit 7a275ba

Browse files
Generic, simple implementation fox xsimd::expand
Also provide a specialization for avx512f. Related to #975
1 parent d21e8a1 commit 7a275ba

File tree

4 files changed

+166
-0
lines changed

4 files changed

+166
-0
lines changed

include/xsimd/arch/generic/xsimd_generic_memory.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,31 @@ namespace xsimd
6262
return swizzle(z, compress_mask);
6363
}
6464

65+
// expand
66+
namespace detail
67+
{
68+
template <class IT, class A, class I, size_t... Is>
69+
inline batch<IT, A> create_expand_swizzle_mask(I bitmask, ::xsimd::detail::index_sequence<Is...>)
70+
{
71+
batch<IT, A> swizzle_mask(IT(0));
72+
IT j = 0;
73+
(void)std::initializer_list<bool> { ((swizzle_mask = insert(swizzle_mask, j, index<Is>())), (j += ((bitmask >> Is) & 1u)), true)... };
74+
return swizzle_mask;
75+
}
76+
}
77+
78+
template <typename A, typename T>
79+
inline batch<T, A>
80+
expand(batch<T, A> const& x, batch_bool<T, A> const& mask,
81+
kernel::requires_arch<generic>) noexcept
82+
{
83+
constexpr std::size_t size = batch_bool<T, A>::size;
84+
auto bitmask = mask.mask();
85+
auto swizzle_mask = detail::create_expand_swizzle_mask<as_unsigned_integer_t<T>, A>(bitmask, ::xsimd::detail::make_index_sequence<size>());
86+
auto z = swizzle(x, swizzle_mask);
87+
return select(mask, z, batch<T, A>(T(0)));
88+
}
89+
6590
// extract_pair
6691
template <class A, class T>
6792
inline batch<T, A> extract_pair(batch<T, A> const& self, batch<T, A> const& other, std::size_t i, requires_arch<generic>) noexcept

include/xsimd/arch/xsimd_avx512f.hpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,38 @@ namespace xsimd
788788
return register_type(~self.data ^ other.data);
789789
}
790790

791+
// expand
792+
template <class A>
793+
inline batch<float, A> expand(batch<float, A> const& self, batch_bool<float, A> const& mask, requires_arch<avx512f>) noexcept
794+
{
795+
return _mm512_maskz_expand_ps(mask.mask(), self);
796+
}
797+
template <class A>
798+
inline batch<double, A> expand(batch<double, A> const& self, batch_bool<double, A> const& mask, requires_arch<avx512f>) noexcept
799+
{
800+
return _mm512_maskz_expand_pd(mask.mask(), self);
801+
}
802+
template <class A>
803+
inline batch<int32_t, A> expand(batch<int32_t, A> const& self, batch_bool<int32_t, A> const& mask, requires_arch<avx512f>) noexcept
804+
{
805+
return _mm512_maskz_expand_epi32(mask.mask(), self);
806+
}
807+
template <class A>
808+
inline batch<uint32_t, A> expand(batch<uint32_t, A> const& self, batch_bool<uint32_t, A> const& mask, requires_arch<avx512f>) noexcept
809+
{
810+
return _mm512_maskz_expand_epi32(mask.mask(), self);
811+
}
812+
template <class A>
813+
inline batch<int64_t, A> expand(batch<int64_t, A> const& self, batch_bool<int64_t, A> const& mask, requires_arch<avx512f>) noexcept
814+
{
815+
return _mm512_maskz_expand_epi64(mask.mask(), self);
816+
}
817+
template <class A>
818+
inline batch<uint64_t, A> expand(batch<uint64_t, A> const& self, batch_bool<uint64_t, A> const& mask, requires_arch<avx512f>) noexcept
819+
{
820+
return _mm512_maskz_expand_epi64(mask.mask(), self);
821+
}
822+
791823
// floor
792824
template <class A>
793825
inline batch<float, A> floor(batch<float, A> const& self, requires_arch<avx512f>) noexcept

include/xsimd/types/xsimd_api.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,19 @@ namespace xsimd
718718
return kernel::exp2<A>(x, A {});
719719
}
720720

721+
/**
722+
* @ingroup batch_data_transfer
723+
*
724+
* Load contiguous elements from \c x and place them in slots selected by \c
725+
* mask, zeroing the other slots
726+
*/
727+
template <class T, class A>
728+
inline batch<T, A> expand(batch<T, A> const& x, batch_bool<T, A> const& mask) noexcept
729+
{
730+
detail::static_check_supported_config<T, A>();
731+
return kernel::expand<A>(x, mask, A {});
732+
}
733+
721734
/**
722735
* @ingroup batch_math
723736
*

test/test_shuffle.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,102 @@ TEST_CASE_TEMPLATE("[compress]", B, BATCH_FLOAT_TYPES, xsimd::batch<uint32_t>, x
368368
// }
369369
}
370370

371+
template <class B>
372+
struct expand_test
373+
{
374+
using batch_type = B;
375+
using value_type = typename B::value_type;
376+
using mask_batch_type = typename B::batch_bool_type;
377+
378+
static constexpr size_t size = B::size;
379+
std::array<value_type, size> input;
380+
std::array<bool, size> mask;
381+
std::array<value_type, size> expected;
382+
383+
expand_test()
384+
{
385+
for (size_t i = 0; i < size; ++i)
386+
{
387+
input[i] = i;
388+
}
389+
}
390+
391+
void full()
392+
{
393+
std::fill(mask.begin(), mask.end(), true);
394+
395+
for (size_t i = 0; i < size; ++i)
396+
expected[i] = input[i];
397+
398+
auto b = xsimd::expand(
399+
batch_type::load_unaligned(input.data()),
400+
mask_batch_type::load_unaligned(mask.data()));
401+
CHECK_BATCH_EQ(b, expected);
402+
}
403+
404+
void empty()
405+
{
406+
std::fill(mask.begin(), mask.end(), false);
407+
408+
for (size_t i = 0; i < size; ++i)
409+
expected[i] = 0;
410+
411+
auto b = xsimd::expand(
412+
batch_type::load_unaligned(input.data()),
413+
mask_batch_type::load_unaligned(mask.data()));
414+
CHECK_BATCH_EQ(b, expected);
415+
}
416+
417+
void interleave()
418+
{
419+
for (size_t i = 0; i < size; ++i)
420+
mask[i] = i % 2 == 0;
421+
422+
for (size_t i = 0, j = 0; i < size; ++i)
423+
expected[i] = mask[i] ? input[j++] : 0;
424+
425+
auto b = xsimd::expand(
426+
batch_type::load_unaligned(input.data()),
427+
mask_batch_type::load_unaligned(mask.data()));
428+
CHECK_BATCH_EQ(b, expected);
429+
}
430+
431+
void generic()
432+
{
433+
for (size_t i = 0; i < size; ++i)
434+
mask[i] = i % 3 == 0;
435+
436+
for (size_t i = 0, j = 0; i < size; ++i)
437+
expected[i] = mask[i] ? input[j++] : 0;
438+
439+
auto b = xsimd::expand(
440+
batch_type::load_unaligned(input.data()),
441+
mask_batch_type::load_unaligned(mask.data()));
442+
CHECK_BATCH_EQ(b, expected);
443+
}
444+
};
445+
446+
TEST_CASE_TEMPLATE("[expand]", B, BATCH_FLOAT_TYPES, xsimd::batch<uint32_t>, xsimd::batch<int32_t>, xsimd::batch<uint64_t>, xsimd::batch<int64_t>)
447+
{
448+
expand_test<B> Test;
449+
SUBCASE("empty")
450+
{
451+
Test.empty();
452+
}
453+
SUBCASE("full")
454+
{
455+
Test.full();
456+
}
457+
SUBCASE("interleave")
458+
{
459+
Test.interleave();
460+
}
461+
SUBCASE("generic")
462+
{
463+
Test.generic();
464+
}
465+
}
466+
371467
template <class B>
372468
struct shuffle_test
373469
{

0 commit comments

Comments
 (0)