Commit 3fa3c9b
1. Adds a new masked API with compile-time masks (store_masked and load_masked)
2. General use-case optimizations
3. New tests
4. x86 kernels
1 parent ac8a93f commit 3fa3c9b

File tree

13 files changed, +2185 -11 lines


docs/source/api/data_transfer.rst

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@ Data transfer
 From memory:
 
 +---------------------------------------+----------------------------------------------------+
-| :cpp:func:`load` | load values from memory |
+| :cpp:func:`load` | load values from memory (optionally masked) |
 +---------------------------------------+----------------------------------------------------+
 | :cpp:func:`load_aligned` | load values from aligned memory |
 +---------------------------------------+----------------------------------------------------+
@@ -30,7 +30,7 @@ From a scalar:
 To memory:
 
 +---------------------------------------+----------------------------------------------------+
-| :cpp:func:`store` | store values to memory |
+| :cpp:func:`store` | store values to memory (optionally masked) |
 +---------------------------------------+----------------------------------------------------+
 | :cpp:func:`store_aligned` | store values to aligned memory |
 +---------------------------------------+----------------------------------------------------+
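The "(optionally masked)" wording refers to the compile-time-masked overloads introduced by this commit: inactive lanes are loaded as zero and are never written back on store. A minimal scalar sketch of that semantics (illustrative names only, not the library API):

    #include <cstddef>

    // Scalar model of the masked data-transfer semantics documented above:
    // inactive lanes load as zero, and stores never touch inactive lanes.
    template <std::size_t N>
    void load_masked_model(const float* mem, const bool (&mask)[N], float (&out)[N])
    {
        for (std::size_t i = 0; i < N; ++i)
            out[i] = mask[i] ? mem[i] : 0.0f; // masked-off lane -> zero
    }

    template <std::size_t N>
    void store_masked_model(float* mem, const float (&src)[N], const bool (&mask)[N])
    {
        for (std::size_t i = 0; i < N; ++i)
            if (mask[i])
                mem[i] = src[i]; // masked-off lane -> memory left untouched
    }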

include/xsimd/arch/common/xsimd_common_memory.hpp

Lines changed: 159 additions & 3 deletions
@@ -12,13 +12,13 @@
 #ifndef XSIMD_COMMON_MEMORY_HPP
 #define XSIMD_COMMON_MEMORY_HPP
 
+#include "../../types/xsimd_batch_constant.hpp"
+#include "./xsimd_common_details.hpp"
 #include <algorithm>
+#include <array>
 #include <complex>
 #include <stdexcept>
 
-#include "../../types/xsimd_batch_constant.hpp"
-#include "./xsimd_common_details.hpp"
-
 namespace xsimd
 {
     template <typename T, class A, T... Values>
@@ -348,6 +348,162 @@ namespace xsimd
         return detail::load_unaligned<A>(mem, cvt, common {}, detail::conversion_type<A, T_in, T_out> {});
     }
 
+    template <class A, class T>
+    XSIMD_INLINE batch<T, A> load(T const* mem, aligned_mode, requires_arch<A>) noexcept
+    {
+        return load_aligned<A>(mem, convert<T> {}, A {});
+    }
+
+    template <class A, class T>
+    XSIMD_INLINE batch<T, A> load(T const* mem, unaligned_mode, requires_arch<A>) noexcept
+    {
+        return load_unaligned<A>(mem, convert<T> {}, A {});
+    }
+
+    template <class A, class T_in, class T_out, bool... Values, class alignment>
+    XSIMD_INLINE batch<T_out, A>
+    load_masked(T_in const* mem, batch_bool_constant<T_out, A, Values...>, convert<T_out>, alignment, requires_arch<common>) noexcept
+    {
+        constexpr std::size_t size = batch<T_out, A>::size;
+        alignas(A::alignment()) std::array<T_out, size> buffer {};
+        constexpr std::array<bool, size> mask { Values... };
+
+        for (std::size_t i = 0; i < size; ++i)
+            buffer[i] = mask[i] ? static_cast<T_out>(mem[i]) : T_out(0);
+
+        return batch<T_out, A>::load(buffer.data(), aligned_mode {});
+    }
+
+    template <class A, class T_in, class T_out, bool... Values, class alignment>
+    XSIMD_INLINE void
+    store_masked(T_out* mem, batch<T_in, A> const& src, batch_bool_constant<T_in, A, Values...>, alignment, requires_arch<common>) noexcept
+    {
+        constexpr std::size_t size = batch<T_in, A>::size;
+        constexpr std::array<bool, size> mask { Values... };
+
+        for (std::size_t i = 0; i < size; ++i)
+            if (mask[i])
+            {
+                mem[i] = static_cast<T_out>(src.get(i));
+            }
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE batch<int32_t, A> load_masked(int32_t const* mem,
+                                               batch_bool_constant<int32_t, A, Values...>,
+                                               convert<int32_t>,
+                                               Mode,
+                                               requires_arch<A>) noexcept
+    {
+        const auto f = load_masked<A>(reinterpret_cast<const float*>(mem),
+                                      batch_bool_constant<float, A, Values...> {},
+                                      convert<float> {},
+                                      Mode {},
+                                      A {});
+        return bitwise_cast<int32_t>(f);
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE batch<uint32_t, A> load_masked(uint32_t const* mem,
+                                                batch_bool_constant<uint32_t, A, Values...>,
+                                                convert<uint32_t>,
+                                                Mode,
+                                                requires_arch<A>) noexcept
+    {
+        const auto f = load_masked<A>(reinterpret_cast<const float*>(mem),
+                                      batch_bool_constant<float, A, Values...> {},
+                                      convert<float> {},
+                                      Mode {},
+                                      A {});
+        return bitwise_cast<uint32_t>(f);
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE batch<int64_t, A> load_masked(int64_t const* mem,
+                                               batch_bool_constant<int64_t, A, Values...>,
+                                               convert<int64_t>,
+                                               Mode,
+                                               requires_arch<A>) noexcept
+    {
+        const auto d = load_masked<A>(reinterpret_cast<const double*>(mem),
+                                      batch_bool_constant<double, A, Values...> {},
+                                      convert<double> {},
+                                      Mode {},
+                                      A {});
+        return bitwise_cast<int64_t>(d);
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE batch<uint64_t, A> load_masked(uint64_t const* mem,
+                                                batch_bool_constant<uint64_t, A, Values...>,
+                                                convert<uint64_t>,
+                                                Mode,
+                                                requires_arch<A>) noexcept
+    {
+        const auto d = load_masked<A>(reinterpret_cast<const double*>(mem),
+                                      batch_bool_constant<double, A, Values...> {},
+                                      convert<double> {},
+                                      Mode {},
+                                      A {});
+        return bitwise_cast<uint64_t>(d);
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE void store_masked(int32_t* mem,
+                                   batch<int32_t, A> const& src,
+                                   batch_bool_constant<int32_t, A, Values...>,
+                                   Mode,
+                                   requires_arch<A>) noexcept
+    {
+        store_masked<A>(reinterpret_cast<float*>(mem),
+                        bitwise_cast<float>(src),
+                        batch_bool_constant<float, A, Values...> {},
+                        Mode {},
+                        A {});
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE void store_masked(uint32_t* mem,
+                                   batch<uint32_t, A> const& src,
+                                   batch_bool_constant<uint32_t, A, Values...>,
+                                   Mode,
+                                   requires_arch<A>) noexcept
+    {
+        store_masked<A>(reinterpret_cast<float*>(mem),
+                        bitwise_cast<float>(src),
+                        batch_bool_constant<float, A, Values...> {},
+                        Mode {},
+                        A {});
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE void store_masked(int64_t* mem,
+                                   batch<int64_t, A> const& src,
+                                   batch_bool_constant<int64_t, A, Values...>,
+                                   Mode,
+                                   requires_arch<A>) noexcept
+    {
+        store_masked<A>(reinterpret_cast<double*>(mem),
+                        bitwise_cast<double>(src),
+                        batch_bool_constant<double, A, Values...> {},
+                        Mode {},
+                        A {});
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE void store_masked(uint64_t* mem,
+                                   batch<uint64_t, A> const& src,
+                                   batch_bool_constant<uint64_t, A, Values...>,
+                                   Mode,
+                                   requires_arch<A>) noexcept
+    {
+        store_masked<A>(reinterpret_cast<double*>(mem),
+                        bitwise_cast<double>(src),
+                        batch_bool_constant<double, A, Values...> {},
+                        Mode {},
+                        A {});
+    }
+
     // rotate_right
     template <size_t N, class A, class T>
     XSIMD_INLINE batch<T, A> rotate_right(batch<T, A> const& self, requires_arch<common>) noexcept
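For reference, here is a sketch of how these generic kernels could be exercised directly, using the batch_bool_constant<T, A, bools...> spelling from this hunk. The public load/store entry points that eventually dispatch to these kernels are not shown in this excerpt, so treat the call below as an illustration under assumptions (arch choice, kernel namespace, tag spellings), not the intended user-facing API:

    #include <xsimd/xsimd.hpp>

    // Illustrative only: a 4-lane float mask keeping lanes 0 and 2 on a 128-bit arch.
    // With the generic fallback above, the masked-off lanes of the result are 0.0f.
    xsimd::batch<float, xsimd::sse2> demo_masked_load(const float* in)
    {
        using arch = xsimd::sse2; // assumption: any arch with 4 float lanes behaves the same here
        using mask_t = xsimd::batch_bool_constant<float, arch, true, false, true, false>;
        return xsimd::kernel::load_masked<arch>(in, mask_t {}, xsimd::kernel::convert<float> {},
                                                xsimd::unaligned_mode {}, arch {});
    }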

include/xsimd/arch/xsimd_avx.hpp

Lines changed: 154 additions & 0 deletions
@@ -865,6 +865,160 @@ namespace xsimd
         return _mm256_loadu_pd(mem);
     }
 
+    // load_masked
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE batch<float, A> load_masked(float const* mem,
+                                             batch_bool_constant<float, A, Values...> mask,
+                                             convert<float>,
+                                             Mode,
+                                             requires_arch<avx>) noexcept
+    {
+        constexpr auto mb = batch_bool_constant<float, A, Values...>::mask();
+        XSIMD_IF_CONSTEXPR(mask.none())
+        {
+            return _mm256_setzero_ps();
+        }
+        else XSIMD_IF_CONSTEXPR(mask.all())
+        {
+            return load<A>(mem, Mode {});
+        }
+        // confined to lower 128-bit half (4 lanes) → forward to SSE2
+        else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 4)
+        {
+            constexpr auto mlo = mask.template lower_half<sse4_2>();
+            const auto lo = load_masked(mem, mlo, convert<float> {}, Mode {}, sse4_2 {});
+            return _mm256_zextps128_ps256(lo);
+        }
+        // confined to upper 128-bit half (4 lanes) → forward to SSE2
+        else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 4)
+        {
+            constexpr auto mhi = mask.template upper_half<sse4_2>();
+            const auto hi = load_masked(mem + 4, mhi, convert<float> {}, Mode {}, sse4_2 {});
+            return _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 1);
+        }
+        else
+        {
+            // crossing 128-bit boundary → use 256-bit masked load
+            return _mm256_maskload_ps(mem, from_mask(batch_bool<int32_t, A> {}, mb, avx {}));
+        }
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE batch<double, A> load_masked(double const* mem,
+                                              batch_bool_constant<double, A, Values...> mask,
+                                              convert<double>,
+                                              Mode,
+                                              requires_arch<avx>) noexcept
+    {
+        constexpr auto mb = batch_bool_constant<double, A, Values...>::mask();
+        XSIMD_IF_CONSTEXPR(mask.none())
+        {
+            return _mm256_setzero_pd();
+        }
+        else XSIMD_IF_CONSTEXPR(mask.all())
+        {
+            return load<A>(mem, Mode {});
+        }
+        // confined to lower 128-bit half (2 lanes) → forward to SSE2
+        else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 2)
+        {
+            constexpr auto mlo = mask.template lower_half<sse4_2>();
+            const auto lo = load_masked(mem, mlo, convert<double> {}, Mode {}, sse4_2 {});
+            return _mm256_zextpd128_pd256(lo);
+        }
+        // confined to upper 128-bit half (2 lanes) → forward to SSE2
+        else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 2)
+        {
+            constexpr auto mhi = mask.template upper_half<sse4_2>();
+            const auto hi = load_masked(mem + 2, mhi, convert<double> {}, Mode {}, sse4_2 {});
+            return _mm256_insertf128_pd(_mm256_setzero_pd(), hi, 1);
+        }
+        else
+        {
+            // crossing 128-bit boundary → use 256-bit masked load
+            return _mm256_maskload_pd(mem, from_mask(batch_bool<int64_t, A> {}, mb, avx {}));
+        }
+    }
+
+    // store_masked
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE void store_masked(float* mem,
+                                   batch<float, A> const& src,
+                                   batch_bool_constant<float, A, Values...> mask,
+                                   Mode,
+                                   requires_arch<avx>) noexcept
+    {
+        constexpr auto mb = batch_bool_constant<float, A, Values...>::mask();
+        XSIMD_IF_CONSTEXPR(mask.none())
+        {
+            return;
+        }
+        else XSIMD_IF_CONSTEXPR(mask.all())
+        {
+            src.store(mem, Mode {});
+        }
+        // confined to lower 128-bit half (4 lanes) → forward to SSE2
+        else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 4)
+        {
+            auto mlo = mask.template lower_half<sse4_2>();
+            batch<float, sse4_2> lo(_mm256_castps256_ps128(src));
+            store_masked<sse4_2>(mem, lo, mlo, Mode {}, sse4_2 {});
+        }
+        // confined to upper 128-bit half (4 lanes) → forward to SSE2
+        else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 4)
+        {
+            auto mhi = mask.template upper_half<sse4_2>();
+            batch<float, sse4_2> hi(_mm256_extractf128_ps(src, 1));
+            store_masked<sse4_2>(mem + 4, hi, mhi, Mode {}, sse4_2 {});
+        }
+        else
+        {
+            __m256i m256 = from_mask(batch_bool<int32_t, A> {}, mb, avx {});
+            _mm256_maskstore_ps(mem, m256, src);
+        }
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE void store_masked(double* mem,
+                                   batch<double, A> const& src,
+                                   batch_bool_constant<double, A, Values...> mask,
+                                   Mode,
+                                   requires_arch<avx>) noexcept
+    {
+        constexpr auto mb = batch_bool_constant<double, A, Values...>::mask();
+        XSIMD_IF_CONSTEXPR(mask.none())
+        {
+            return;
+        }
+        else XSIMD_IF_CONSTEXPR(mask.all())
+        {
+            src.store(mem, Mode {});
+        }
+        // confined to lower 128-bit half (2 lanes) → forward to SSE2
+        else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 2)
+        {
+            auto mlo = mask.template lower_half<sse2>();
+            batch<double, sse2> lo(_mm256_castpd256_pd128(src));
+            store_masked<sse2>(mem, lo, mlo, Mode {}, sse4_2 {});
+        }
+        // confined to upper 128-bit half (2 lanes) → forward to SSE2
+        else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 2)
+        {
+            auto mhi = mask.template upper_half<sse2>();
+            batch<double, sse2> hi(_mm256_extractf128_pd(src, 1));
+            store_masked<sse2>(mem + 2, hi, mhi, Mode {}, sse4_2 {});
+        }
+        else
+        {
+            __m256i m256 = from_mask(batch_bool<int64_t, A> {}, mb, avx {});
+            _mm256_maskstore_pd(mem, m256, src);
+        }
+    }
+
     // lt
     template <class A>
     XSIMD_INLINE batch_bool<float, A> lt(batch<float, A> const& self, batch<float, A> const& other, requires_arch<avx>) noexcept
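The branch structure in these AVX kernels picks one of four paths from the compile-time mask: nothing to do, a full-width load/store, a single 128-bit (SSE) kernel when every active lane sits in one half, or a genuine 256-bit masked intrinsic when the mask crosses the half boundary. A standalone model of that decision, with hand-rolled lane counters standing in for batch_bool_constant::countl_zero()/countr_zero() (the none()/all() fast paths are elided):

    #include <cstddef>
    #include <cstdio>

    // countr_zero counts inactive lanes from the low end of the mask,
    // countl_zero from the high end, mirroring the checks in the AVX kernels above.
    template <std::size_t N>
    constexpr std::size_t countr_zero(const bool (&m)[N])
    {
        std::size_t n = 0;
        while (n < N && !m[n])
            ++n;
        return n;
    }

    template <std::size_t N>
    constexpr std::size_t countl_zero(const bool (&m)[N])
    {
        std::size_t n = 0;
        while (n < N && !m[N - 1 - n])
            ++n;
        return n;
    }

    int main()
    {
        // 8 float lanes in an AVX batch; each 128-bit half holds 4 lanes.
        constexpr bool mask[8] = { true, true, true, false, false, false, false, false };
        if (countl_zero(mask) >= 4)
            std::puts("active lanes confined to the lower 128-bit half -> forward to the 128-bit kernel");
        else if (countr_zero(mask) >= 4)
            std::puts("active lanes confined to the upper 128-bit half -> forward to the 128-bit kernel");
        else
            std::puts("mask crosses the 128-bit boundary -> use _mm256_maskload_ps / _mm256_maskstore_ps");
    }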
