Commit 49c9cda

1. Adds a new masked API with compile-time masks (store_masked and load_masked)
2. General use-case optimization
3. New tests
4. x86 kernels
1 parent e808799 commit 49c9cda

15 files changed: +2460 -35 lines
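
For orientation, a minimal usage sketch of the new compile-time masked API. It follows the kernel signatures introduced in this commit; the enclosing xsimd::kernel namespace, the convert<float> tag, and calling the kernels directly (rather than through a public wrapper that is not part of this diff) are assumptions.

    #include <xsimd/xsimd.hpp>

    using arch = xsimd::avx;
    // Compile-time mask: keep the first 5 of 8 float lanes, drop the rest.
    using mask5 = xsimd::batch_bool_constant<float, arch,
                                             true, true, true, true, true, false, false, false>;

    xsimd::batch<float, arch> load_head(const float* p)
    {
        // Inactive lanes come back as zero and are never read from memory
        // (see the common fallback further down).
        return xsimd::kernel::load_masked(p, mask5 {}, xsimd::kernel::convert<float> {},
                                          xsimd::unaligned_mode {}, arch {});
    }

    void store_head(float* p, xsimd::batch<float, arch> const& v)
    {
        // Only lanes whose mask entry is true are written back.
        xsimd::kernel::store_masked(p, v, mask5 {}, xsimd::unaligned_mode {}, arch {});
    }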

docs/source/api/data_transfer.rst

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@ Data transfer
 From memory:
 
 +---------------------------------------+----------------------------------------------------+
-| :cpp:func:`load` | load values from memory |
+| :cpp:func:`load` | load values from memory (optionally masked) |
 +---------------------------------------+----------------------------------------------------+
 | :cpp:func:`load_aligned` | load values from aligned memory |
 +---------------------------------------+----------------------------------------------------+
@@ -30,7 +30,7 @@ From a scalar:
 To memory:
 
 +---------------------------------------+----------------------------------------------------+
-| :cpp:func:`store` | store values to memory |
+| :cpp:func:`store` | store values to memory (optionally masked) |
 +---------------------------------------+----------------------------------------------------+
 | :cpp:func:`store_aligned` | store values to aligned memory |
 +---------------------------------------+----------------------------------------------------+

include/xsimd/arch/common/xsimd_common_arithmetic.hpp

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 #include <type_traits>
 
 #include "./xsimd_common_details.hpp"
+#include "../../types/xsimd_batch_constant.hpp"
 
 namespace xsimd
 {

include/xsimd/arch/common/xsimd_common_memory.hpp

Lines changed: 99 additions & 3 deletions
@@ -12,13 +12,13 @@
 #ifndef XSIMD_COMMON_MEMORY_HPP
 #define XSIMD_COMMON_MEMORY_HPP
 
+#include "../../types/xsimd_batch_constant.hpp"
+#include "./xsimd_common_details.hpp"
 #include <algorithm>
+#include <array>
 #include <complex>
 #include <stdexcept>
 
-#include "../../types/xsimd_batch_constant.hpp"
-#include "./xsimd_common_details.hpp"
-
 namespace xsimd
 {
     template <typename T, class A, T... Values>
@@ -348,6 +348,102 @@ namespace xsimd
         return detail::load_unaligned<A>(mem, cvt, common {}, detail::conversion_type<A, T_in, T_out> {});
     }
 
+    template <class A, class T>
+    XSIMD_INLINE batch<T, A> load(T const* mem, aligned_mode, requires_arch<A>) noexcept
+    {
+        return load_aligned<A>(mem, convert<T> {}, A {});
+    }
+
+    template <class A, class T>
+    XSIMD_INLINE batch<T, A> load(T const* mem, unaligned_mode, requires_arch<A>) noexcept
+    {
+        return load_unaligned<A>(mem, convert<T> {}, A {});
+    }
+
+    template <class A, class T_in, class T_out, bool... Values, class alignment>
+    XSIMD_INLINE batch<T_out, A>
+    load_masked(T_in const* mem, batch_bool_constant<T_out, A, Values...>, convert<T_out>, alignment, requires_arch<common>) noexcept
+    {
+        constexpr std::size_t size = batch<T_out, A>::size;
+        alignas(A::alignment()) std::array<T_out, size> buffer {};
+        constexpr std::array<bool, size> mask { Values... };
+
+        for (std::size_t i = 0; i < size; ++i)
+            buffer[i] = mask[i] ? static_cast<T_out>(mem[i]) : T_out(0);
+
+        return batch<T_out, A>::load(buffer.data(), aligned_mode {});
+    }
+
+    template <class A, class T_in, class T_out, bool... Values, class alignment>
+    XSIMD_INLINE void
+    store_masked(T_out* mem, batch<T_in, A> const& src, batch_bool_constant<T_in, A, Values...>, alignment, requires_arch<common>) noexcept
+    {
+        constexpr std::size_t size = batch<T_in, A>::size;
+        constexpr std::array<bool, size> mask { Values... };
+
+        for (std::size_t i = 0; i < size; ++i)
+            if (mask[i])
+            {
+                mem[i] = static_cast<T_out>(src.get(i));
+            }
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE batch<int32_t, A> load_masked(int32_t const* mem, batch_bool_constant<int32_t, A, Values...>, convert<int32_t>, Mode, requires_arch<A>) noexcept
+    {
+        const auto f = load_masked<A>(reinterpret_cast<const float*>(mem), batch_bool_constant<float, A, Values...> {}, convert<float> {}, Mode {}, A {});
+        return bitwise_cast<int32_t>(f);
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE batch<uint32_t, A> load_masked(uint32_t const* mem, batch_bool_constant<uint32_t, A, Values...>, convert<uint32_t>, Mode, requires_arch<A>) noexcept
+    {
+        const auto f = load_masked<A>(reinterpret_cast<const float*>(mem), batch_bool_constant<float, A, Values...> {}, convert<float> {}, Mode {}, A {});
+        return bitwise_cast<uint32_t>(f);
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE typename std::enable_if<has_simd_register<double, A>::value, batch<int64_t, A>>::type
+    load_masked(int64_t const* mem, batch_bool_constant<int64_t, A, Values...>, convert<int64_t>, Mode, requires_arch<A>) noexcept
+    {
+        const auto d = load_masked<A>(reinterpret_cast<const double*>(mem), batch_bool_constant<double, A, Values...> {}, convert<double> {}, Mode {}, A {});
+        return bitwise_cast<int64_t>(d);
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE typename std::enable_if<has_simd_register<double, A>::value, batch<uint64_t, A>>::type
+    load_masked(uint64_t const* mem, batch_bool_constant<uint64_t, A, Values...>, convert<uint64_t>, Mode, requires_arch<A>) noexcept
+    {
+        const auto d = load_masked<A>(reinterpret_cast<const double*>(mem), batch_bool_constant<double, A, Values...> {}, convert<double> {}, Mode {}, A {});
+        return bitwise_cast<uint64_t>(d);
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE void store_masked(int32_t* mem, batch<int32_t, A> const& src, batch_bool_constant<int32_t, A, Values...>, Mode, requires_arch<A>) noexcept
+    {
+        store_masked<A>(reinterpret_cast<float*>(mem), bitwise_cast<float>(src), batch_bool_constant<float, A, Values...> {}, Mode {}, A {});
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE void store_masked(uint32_t* mem, batch<uint32_t, A> const& src, batch_bool_constant<uint32_t, A, Values...>, Mode, requires_arch<A>) noexcept
+    {
+        store_masked<A>(reinterpret_cast<float*>(mem), bitwise_cast<float>(src), batch_bool_constant<float, A, Values...> {}, Mode {}, A {});
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE typename std::enable_if<has_simd_register<double, A>::value, void>::type
+    store_masked(int64_t* mem, batch<int64_t, A> const& src, batch_bool_constant<int64_t, A, Values...>, Mode, requires_arch<A>) noexcept
+    {
+        store_masked<A>(reinterpret_cast<double*>(mem), bitwise_cast<double>(src), batch_bool_constant<double, A, Values...> {}, Mode {}, A {});
+    }
+
+    template <class A, bool... Values, class Mode>
+    XSIMD_INLINE typename std::enable_if<has_simd_register<double, A>::value, void>::type
+    store_masked(uint64_t* mem, batch<uint64_t, A> const& src, batch_bool_constant<uint64_t, A, Values...>, Mode, requires_arch<A>) noexcept
+    {
+        store_masked<A>(reinterpret_cast<double*>(mem), bitwise_cast<double>(src), batch_bool_constant<double, A, Values...> {}, Mode {}, A {});
+    }
+
     // rotate_right
     template <size_t N, class A, class T>
     XSIMD_INLINE batch<T, A> rotate_right(batch<T, A> const& self, requires_arch<common>) noexcept
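
The common-architecture fallback above fixes the semantics every specialized kernel has to reproduce: on a masked load, inactive lanes are not read and come back as zero; on a masked store, inactive lanes are not written. The 32- and 64-bit integer overloads simply reinterpret the data and reuse the float/double kernels via bitwise_cast. A scalar reference of those semantics, using hypothetical helper names, looks like this:

    #include <cstddef>

    // Lane-by-lane reference for load_masked: inactive lanes are skipped on the
    // read side and produced as zero in the result.
    template <class T, std::size_t N>
    void masked_load_ref(const T* src, const bool (&mask)[N], T (&dst)[N])
    {
        for (std::size_t i = 0; i < N; ++i)
            dst[i] = mask[i] ? src[i] : T(0);
    }

    // Lane-by-lane reference for store_masked: inactive destination lanes are
    // left untouched.
    template <class T, std::size_t N>
    void masked_store_ref(T* dst, const bool (&mask)[N], const T (&src)[N])
    {
        for (std::size_t i = 0; i < N; ++i)
            if (mask[i])
                dst[i] = src[i];
    }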

include/xsimd/arch/xsimd_avx.hpp

Lines changed: 175 additions & 9 deletions
@@ -36,20 +36,35 @@ namespace xsimd
 
     namespace detail
     {
-        XSIMD_INLINE void split_avx(__m256i val, __m128i& low, __m128i& high) noexcept
+        XSIMD_INLINE __m128i lower_half(__m256i self) noexcept
         {
-            low = _mm256_castsi256_si128(val);
-            high = _mm256_extractf128_si256(val, 1);
+            return _mm256_castsi256_si128(self);
         }
-        XSIMD_INLINE void split_avx(__m256 val, __m128& low, __m128& high) noexcept
+        XSIMD_INLINE __m128 lower_half(__m256 self) noexcept
         {
-            low = _mm256_castps256_ps128(val);
-            high = _mm256_extractf128_ps(val, 1);
+            return _mm256_castps256_ps128(self);
         }
-        XSIMD_INLINE void split_avx(__m256d val, __m128d& low, __m128d& high) noexcept
+        XSIMD_INLINE __m128d lower_half(__m256d self) noexcept
         {
-            low = _mm256_castpd256_pd128(val);
-            high = _mm256_extractf128_pd(val, 1);
+            return _mm256_castpd256_pd128(self);
+        }
+        XSIMD_INLINE __m128i upper_half(__m256i self) noexcept
+        {
+            return _mm256_extractf128_si256(self, 1);
+        }
+        XSIMD_INLINE __m128 upper_half(__m256 self) noexcept
+        {
+            return _mm256_extractf128_ps(self, 1);
+        }
+        XSIMD_INLINE __m128d upper_half(__m256d self) noexcept
+        {
+            return _mm256_extractf128_pd(self, 1);
+        }
+        template <class Full, class Half>
+        XSIMD_INLINE void split_avx(Full val, Half& low, Half& high) noexcept
+        {
+            low = lower_half(val);
+            high = upper_half(val);
         }
         XSIMD_INLINE __m256i merge_sse(__m128i low, __m128i high) noexcept
         {
@@ -63,6 +78,17 @@ namespace xsimd
         {
             return _mm256_insertf128_pd(_mm256_castpd128_pd256(low), high, 1);
         }
+        template <class T>
+        XSIMD_INLINE batch<T, sse4_2> lower_half(batch<T, avx> const& self) noexcept
+        {
+            return lower_half(self);
+        }
+        template <class T>
+        XSIMD_INLINE batch<T, sse4_2> upper_half(batch<T, avx> const& self) noexcept
+        {
+            return upper_half(self);
+        }
+
         template <class F>
         XSIMD_INLINE __m256i fwd_to_sse(F f, __m256i self) noexcept
         {
@@ -865,6 +891,146 @@ namespace xsimd
             return _mm256_loadu_pd(mem);
         }
 
+        // load_masked
+        template <class A, bool... Values, class Mode>
+        XSIMD_INLINE batch<float, A> load_masked(float const* mem, batch_bool_constant<float, A, Values...> mask, convert<float>, Mode, requires_arch<avx>) noexcept
+        {
+            XSIMD_IF_CONSTEXPR(mask.none())
+            {
+                return _mm256_setzero_ps();
+            }
+            else XSIMD_IF_CONSTEXPR(mask.all())
+            {
+                return load<A>(mem, Mode {});
+            }
+            // confined to lower 128-bit half (4 lanes) → forward to SSE
+            else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 4)
+            {
+                constexpr auto mlo = mask.template lower_half<sse4_2>();
+                const auto lo = load_masked(mem, mlo, convert<float> {}, Mode {}, sse4_2 {});
+                return batch<float, A>(detail::merge_sse(lo, batch<float, sse4_2>(0.f)));
+            }
+            // confined to upper 128-bit half (4 lanes) → forward to SSE
+            else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 4)
+            {
+                constexpr auto mhi = mask.template upper_half<sse4_2>();
+                const auto hi = load_masked(mem + 4, mhi, convert<float> {}, Mode {}, sse4_2 {});
+                return batch<float, A>(detail::merge_sse(batch<float, sse4_2>(0.f), hi));
+            }
+            else
+            {
+                // crossing 128-bit boundary → use 256-bit masked load
+                return _mm256_maskload_ps(mem, mask.as_batch());
+            }
+        }
+
+        template <class A, bool... Values, class Mode>
+        XSIMD_INLINE batch<double, A> load_masked(double const* mem,
+                                                  batch_bool_constant<double, A, Values...> mask,
+                                                  convert<double>,
+                                                  Mode,
+                                                  requires_arch<avx>) noexcept
+        {
+            XSIMD_IF_CONSTEXPR(mask.none())
+            {
+                return _mm256_setzero_pd();
+            }
+            else XSIMD_IF_CONSTEXPR(mask.all())
+            {
+                return load<A>(mem, Mode {});
+            }
+            // confined to lower 128-bit half (2 lanes) → forward to SSE
+            else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 2)
+            {
+                constexpr auto mlo = mask.template lower_half<sse4_2>();
+                const auto lo = load_masked(mem, mlo, convert<double> {}, Mode {}, sse4_2 {});
+                return batch<double, A>(detail::merge_sse(lo, batch<double, sse4_2>(0.0)));
+            }
+            // confined to upper 128-bit half (2 lanes) → forward to SSE
+            else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 2)
+            {
+                constexpr auto mhi = mask.template upper_half<sse4_2>();
+                const auto hi = load_masked(mem + 2, mhi, convert<double> {}, Mode {}, sse4_2 {});
+                return batch<double, A>(detail::merge_sse(batch<double, sse4_2>(0.0), hi));
+            }
+            else
+            {
+                // crossing 128-bit boundary → use 256-bit masked load
+                return _mm256_maskload_pd(mem, mask.as_batch());
+            }
+        }
+
+        // store_masked
+        template <class A, bool... Values, class Mode>
+        XSIMD_INLINE void store_masked(float* mem,
+                                       batch<float, A> const& src,
+                                       batch_bool_constant<float, A, Values...> mask,
+                                       Mode,
+                                       requires_arch<avx>) noexcept
+        {
+            XSIMD_IF_CONSTEXPR(mask.none())
+            {
+                return;
+            }
+            else XSIMD_IF_CONSTEXPR(mask.all())
+            {
+                src.store(mem, Mode {});
+            }
+            // confined to lower 128-bit half (4 lanes) → forward to SSE
+            else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 4)
+            {
+                constexpr auto mlo = mask.template lower_half<sse4_2>();
+                const batch<float, sse4_2> lo(_mm256_castps256_ps128(src));
+                store_masked<sse4_2>(mem, lo, mlo, Mode {}, sse4_2 {});
+            }
+            // confined to upper 128-bit half (4 lanes) → forward to SSE
+            else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 4)
+            {
+                constexpr auto mhi = mask.template upper_half<sse4_2>();
+                const batch<float, sse4_2> hi(_mm256_extractf128_ps(src, 1));
+                store_masked<sse4_2>(mem + 4, hi, mhi, Mode {}, sse4_2 {});
+            }
+            else
+            {
+                _mm256_maskstore_ps(mem, mask.as_batch(), src);
+            }
+        }
+
+        template <class A, bool... Values, class Mode>
+        XSIMD_INLINE void store_masked(double* mem,
+                                       batch<double, A> const& src,
+                                       batch_bool_constant<double, A, Values...> mask,
+                                       Mode,
+                                       requires_arch<avx>) noexcept
+        {
+            XSIMD_IF_CONSTEXPR(mask.none())
+            {
+                return;
+            }
+            else XSIMD_IF_CONSTEXPR(mask.all())
+            {
+                src.store(mem, Mode {});
+            }
+            // confined to lower 128-bit half (2 lanes) → forward to SSE
+            else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 2)
+            {
+                constexpr auto mlo = mask.template lower_half<sse2>();
+                const batch<double, sse2> lo(_mm256_castpd256_pd128(src));
+                store_masked<sse2>(mem, lo, mlo, Mode {}, sse4_2 {});
+            }
+            // confined to upper 128-bit half (2 lanes) → forward to SSE
+            else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 2)
+            {
+                constexpr auto mhi = mask.template upper_half<sse2>();
+                const batch<double, sse2> hi(_mm256_extractf128_pd(src, 1));
+                store_masked<sse2>(mem + 2, hi, mhi, Mode {}, sse4_2 {});
+            }
+            else
+            {
+                _mm256_maskstore_pd(mem, mask.as_batch(), src);
+            }
+        }
+
         // lt
         template <class A>
         XSIMD_INLINE batch_bool<float, A> lt(batch<float, A> const& self, batch<float, A> const& other, requires_arch<avx>) noexcept
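
The AVX kernels choose a code path from the mask at compile time: an all-false mask costs nothing (zero register or no store), an all-true mask degenerates to a plain load/store, a mask confined to one 128-bit half is forwarded to the SSE kernel on that half, and only a mask whose active lanes cross the 128-bit boundary falls back to the 256-bit maskload/maskstore intrinsics. The sketch below, illustrative only, spells out which branch three different float masks take, assuming the lane-to-bit mapping implied by the comments in the kernel above:

    #include <xsimd/xsimd.hpp>

    using arch = xsimd::avx;

    // Lanes {1,1,1,0,0,0,0,0}: countl_zero() >= 4, confined to the lower half
    // → forwarded to the sse4_2 load_masked, merged with a zero upper half.
    using low3 = xsimd::batch_bool_constant<float, arch,
                                            true, true, true, false, false, false, false, false>;

    // Lanes {0,0,0,0,0,1,1,1}: countr_zero() >= 4, confined to the upper half
    // → forwarded to sse4_2 on mem + 4, merged with a zero lower half.
    using high3 = xsimd::batch_bool_constant<float, arch,
                                             false, false, false, false, false, true, true, true>;

    // Lanes {1,1,1,1,1,1,0,0}: the active lanes cross the 128-bit boundary
    // → falls through to _mm256_maskload_ps / _mm256_maskstore_ps.
    using six = xsimd::batch_bool_constant<float, arch,
                                           true, true, true, true, true, true, false, false>;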
