diff --git a/docs/source/api/data_transfer.rst b/docs/source/api/data_transfer.rst index 5691073ec..2ce62e010 100644 --- a/docs/source/api/data_transfer.rst +++ b/docs/source/api/data_transfer.rst @@ -10,7 +10,7 @@ Data transfer From memory: +---------------------------------------+----------------------------------------------------+ -| :cpp:func:`load` | load values from memory | +| :cpp:func:`load` | load values from memory (optionally masked) | +---------------------------------------+----------------------------------------------------+ | :cpp:func:`load_aligned` | load values from aligned memory | +---------------------------------------+----------------------------------------------------+ @@ -30,7 +30,7 @@ From a scalar: To memory: +---------------------------------------+----------------------------------------------------+ -| :cpp:func:`store` | store values to memory | +| :cpp:func:`store` | store values to memory (optionally masked) | +---------------------------------------+----------------------------------------------------+ | :cpp:func:`store_aligned` | store values to aligned memory | +---------------------------------------+----------------------------------------------------+ diff --git a/include/xsimd/arch/common/xsimd_common_arithmetic.hpp b/include/xsimd/arch/common/xsimd_common_arithmetic.hpp index cdeb9125b..e42e870a5 100644 --- a/include/xsimd/arch/common/xsimd_common_arithmetic.hpp +++ b/include/xsimd/arch/common/xsimd_common_arithmetic.hpp @@ -16,6 +16,7 @@ #include #include +#include "../../types/xsimd_batch_constant.hpp" #include "./xsimd_common_details.hpp" namespace xsimd diff --git a/include/xsimd/arch/common/xsimd_common_memory.hpp b/include/xsimd/arch/common/xsimd_common_memory.hpp index 4ad148a6f..473bfcccf 100644 --- a/include/xsimd/arch/common/xsimd_common_memory.hpp +++ b/include/xsimd/arch/common/xsimd_common_memory.hpp @@ -13,6 +13,7 @@ #define XSIMD_COMMON_MEMORY_HPP #include +#include #include #include @@ -348,6 +349,102 @@ namespace xsimd return detail::load_unaligned(mem, cvt, common {}, detail::conversion_type {}); } + template + XSIMD_INLINE batch load(T const* mem, aligned_mode, requires_arch) noexcept + { + return load_aligned(mem, convert {}, A {}); + } + + template + XSIMD_INLINE batch load(T const* mem, unaligned_mode, requires_arch) noexcept + { + return load_unaligned(mem, convert {}, A {}); + } + + template + XSIMD_INLINE batch + load_masked(T_in const* mem, batch_bool_constant, convert, alignment, requires_arch) noexcept + { + constexpr std::size_t size = batch::size; + alignas(A::alignment()) std::array buffer {}; + constexpr std::array mask { Values... }; + + for (std::size_t i = 0; i < size; ++i) + buffer[i] = mask[i] ? static_cast(mem[i]) : T_out(0); + + return batch::load(buffer.data(), aligned_mode {}); + } + + template + XSIMD_INLINE void + store_masked(T_out* mem, batch const& src, batch_bool_constant, alignment, requires_arch) noexcept + { + constexpr std::size_t size = batch::size; + constexpr std::array mask { Values... 
}; + + for (std::size_t i = 0; i < size; ++i) + if (mask[i]) + { + mem[i] = static_cast(src.get(i)); + } + } + + template + XSIMD_INLINE batch load_masked(int32_t const* mem, batch_bool_constant, convert, Mode, requires_arch) noexcept + { + const auto f = load_masked(reinterpret_cast(mem), batch_bool_constant {}, convert {}, Mode {}, A {}); + return bitwise_cast(f); + } + + template + XSIMD_INLINE batch load_masked(uint32_t const* mem, batch_bool_constant, convert, Mode, requires_arch) noexcept + { + const auto f = load_masked(reinterpret_cast(mem), batch_bool_constant {}, convert {}, Mode {}, A {}); + return bitwise_cast(f); + } + + template + XSIMD_INLINE typename std::enable_if::value, batch>::type + load_masked(int64_t const* mem, batch_bool_constant, convert, Mode, requires_arch) noexcept + { + const auto d = load_masked(reinterpret_cast(mem), batch_bool_constant {}, convert {}, Mode {}, A {}); + return bitwise_cast(d); + } + + template + XSIMD_INLINE typename std::enable_if::value, batch>::type + load_masked(uint64_t const* mem, batch_bool_constant, convert, Mode, requires_arch) noexcept + { + const auto d = load_masked(reinterpret_cast(mem), batch_bool_constant {}, convert {}, Mode {}, A {}); + return bitwise_cast(d); + } + + template + XSIMD_INLINE void store_masked(int32_t* mem, batch const& src, batch_bool_constant, Mode, requires_arch) noexcept + { + store_masked(reinterpret_cast(mem), bitwise_cast(src), batch_bool_constant {}, Mode {}, A {}); + } + + template + XSIMD_INLINE void store_masked(uint32_t* mem, batch const& src, batch_bool_constant, Mode, requires_arch) noexcept + { + store_masked(reinterpret_cast(mem), bitwise_cast(src), batch_bool_constant {}, Mode {}, A {}); + } + + template + XSIMD_INLINE typename std::enable_if::value, void>::type + store_masked(int64_t* mem, batch const& src, batch_bool_constant, Mode, requires_arch) noexcept + { + store_masked(reinterpret_cast(mem), bitwise_cast(src), batch_bool_constant {}, Mode {}, A {}); + } + + template + XSIMD_INLINE typename std::enable_if::value, void>::type + store_masked(uint64_t* mem, batch const& src, batch_bool_constant, Mode, requires_arch) noexcept + { + store_masked(reinterpret_cast(mem), bitwise_cast(src), batch_bool_constant {}, Mode {}, A {}); + } + // rotate_right template XSIMD_INLINE batch rotate_right(batch const& self, requires_arch) noexcept diff --git a/include/xsimd/arch/xsimd_avx.hpp b/include/xsimd/arch/xsimd_avx.hpp index 9d93be071..0bd1facbb 100644 --- a/include/xsimd/arch/xsimd_avx.hpp +++ b/include/xsimd/arch/xsimd_avx.hpp @@ -36,20 +36,35 @@ namespace xsimd namespace detail { - XSIMD_INLINE void split_avx(__m256i val, __m128i& low, __m128i& high) noexcept + XSIMD_INLINE __m128i lower_half(__m256i self) noexcept { - low = _mm256_castsi256_si128(val); - high = _mm256_extractf128_si256(val, 1); + return _mm256_castsi256_si128(self); } - XSIMD_INLINE void split_avx(__m256 val, __m128& low, __m128& high) noexcept + XSIMD_INLINE __m128 lower_half(__m256 self) noexcept { - low = _mm256_castps256_ps128(val); - high = _mm256_extractf128_ps(val, 1); + return _mm256_castps256_ps128(self); } - XSIMD_INLINE void split_avx(__m256d val, __m128d& low, __m128d& high) noexcept + XSIMD_INLINE __m128d lower_half(__m256d self) noexcept { - low = _mm256_castpd256_pd128(val); - high = _mm256_extractf128_pd(val, 1); + return _mm256_castpd256_pd128(self); + } + XSIMD_INLINE __m128i upper_half(__m256i self) noexcept + { + return _mm256_extractf128_si256(self, 1); + } + XSIMD_INLINE __m128 upper_half(__m256 self) 
noexcept + { + return _mm256_extractf128_ps(self, 1); + } + XSIMD_INLINE __m128d upper_half(__m256d self) noexcept + { + return _mm256_extractf128_pd(self, 1); + } + template + XSIMD_INLINE void split_avx(Full val, Half& low, Half& high) noexcept + { + low = lower_half(val); + high = upper_half(val); } XSIMD_INLINE __m256i merge_sse(__m128i low, __m128i high) noexcept { @@ -63,6 +78,17 @@ namespace xsimd { return _mm256_insertf128_pd(_mm256_castpd128_pd256(low), high, 1); } + template + XSIMD_INLINE batch lower_half(batch const& self) noexcept + { + return lower_half(self); + } + template + XSIMD_INLINE batch upper_half(batch const& self) noexcept + { + return upper_half(self); + } + template XSIMD_INLINE __m256i fwd_to_sse(F f, __m256i self) noexcept { @@ -865,6 +891,134 @@ namespace xsimd return _mm256_loadu_pd(mem); } + // load_masked + template + XSIMD_INLINE batch load_masked(float const* mem, batch_bool_constant mask, convert, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return _mm256_setzero_ps(); + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + return load(mem, Mode {}); + } + // confined to lower 128-bit half (4 lanes) → forward to SSE2 + else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 4) + { + constexpr auto mlo = mask.template lower_half(); + const auto lo = load_masked(mem, mlo, convert {}, Mode {}, sse4_2 {}); + return batch(detail::merge_sse(lo, batch(0.f))); + } + // confined to upper 128-bit half (4 lanes) → forward to SSE2 + else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 4) + { + constexpr auto mhi = mask.template upper_half(); + const auto hi = load_masked(mem + 4, mhi, convert {}, Mode {}, sse4_2 {}); + return batch(detail::merge_sse(batch(0.f), hi)); + } + else + { + // crossing 128-bit boundary → use 256-bit masked load + return _mm256_maskload_ps(mem, mask.as_batch()); + } + } + + template + XSIMD_INLINE batch load_masked(double const* mem, batch_bool_constant mask, convert, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return _mm256_setzero_pd(); + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + return load(mem, Mode {}); + } + // confined to lower 128-bit half (2 lanes) → forward to SSE2 + else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 2) + { + constexpr auto mlo = mask.template lower_half(); + const auto lo = load_masked(mem, mlo, convert {}, Mode {}, sse4_2 {}); + return batch(detail::merge_sse(lo, batch(0.0))); + } + // confined to upper 128-bit half (2 lanes) → forward to SSE2 + else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 2) + { + constexpr auto mhi = mask.template upper_half(); + const auto hi = load_masked(mem + 2, mhi, convert {}, Mode {}, sse4_2 {}); + return batch(detail::merge_sse(batch(0.0), hi)); + } + else + { + // crossing 128-bit boundary → use 256-bit masked load + return _mm256_maskload_pd(mem, mask.as_batch()); + } + } + + // store_masked + template + XSIMD_INLINE void store_masked(float* mem, batch const& src, batch_bool_constant mask, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return; + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + src.store(mem, Mode {}); + } + // confined to lower 128-bit half (4 lanes) → forward to SSE2 + else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 4) + { + constexpr auto mlo = mask.template lower_half(); + const batch lo(_mm256_castps256_ps128(src)); + store_masked(mem, lo, mlo, Mode {}, sse4_2 {}); + } + // confined to upper 128-bit half (4 lanes) → forward to SSE2 + else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 4) + { + constexpr auto mhi = 
mask.template upper_half(); + const batch hi(_mm256_extractf128_ps(src, 1)); + store_masked(mem + 4, hi, mhi, Mode {}, sse4_2 {}); + } + else + { + _mm256_maskstore_ps(mem, mask.as_batch(), src); + } + } + + template + XSIMD_INLINE void store_masked(double* mem, batch const& src, batch_bool_constant mask, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return; + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + src.store(mem, Mode {}); + } + // confined to lower 128-bit half (2 lanes) → forward to SSE2 + else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 2) + { + constexpr auto mlo = mask.template lower_half(); + const batch lo(_mm256_castpd256_pd128(src)); + store_masked(mem, lo, mlo, Mode {}, sse4_2 {}); + } + // confined to upper 128-bit half (2 lanes) → forward to SSE2 + else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 2) + { + constexpr auto mhi = mask.template upper_half(); + const batch hi(_mm256_extractf128_pd(src, 1)); + store_masked(mem + 2, hi, mhi, Mode {}, sse4_2 {}); + } + else + { + _mm256_maskstore_pd(mem, mask.as_batch(), src); + } + } + // lt template XSIMD_INLINE batch_bool lt(batch const& self, batch const& other, requires_arch) noexcept diff --git a/include/xsimd/arch/xsimd_avx2.hpp b/include/xsimd/arch/xsimd_avx2.hpp index cf9669edf..cc88f98ce 100644 --- a/include/xsimd/arch/xsimd_avx2.hpp +++ b/include/xsimd/arch/xsimd_avx2.hpp @@ -116,6 +116,160 @@ namespace xsimd } } + // load_masked + template + XSIMD_INLINE batch load_masked(int32_t const* mem, batch_bool_constant mask, convert, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return _mm256_setzero_si256(); + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + return load(mem, Mode {}); + } + // confined to lower 128-bit half (4 lanes) → forward to SSE + else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 4) + { + constexpr auto mlo = mask.template lower_half(); + const auto lo = load_masked(mem, mlo, convert {}, Mode {}, sse4_2 {}); + return _mm256_zextsi128_si256(lo.data); + } + // confined to upper 128-bit half (4 lanes) → forward to SSE + else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 4) + { + constexpr auto mhi = mask.template upper_half(); + const auto hi = load_masked(mem + 4, mhi, convert {}, Mode {}, sse4_2 {}); + return _mm256_insertf128_si256(_mm256_setzero_si256(), hi.data, 1); + } + else + { + return _mm256_maskload_epi32(mem, mask.as_batch()); + } + } + + template + XSIMD_INLINE batch load_masked(uint32_t const* mem, batch_bool_constant, convert, Mode, requires_arch) noexcept + { + const auto r = load_masked(reinterpret_cast(mem), batch_bool_constant {}, convert {}, Mode {}, avx2 {}); + return bitwise_cast(r); + } + + template + XSIMD_INLINE batch load_masked(int64_t const* mem, batch_bool_constant mask, convert, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return _mm256_setzero_si256(); + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + return load(mem, Mode {}); + } + // confined to lower 128-bit half (2 lanes) → forward to SSE + else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 2) + { + constexpr auto mlo = mask.template lower_half(); + const auto lo = load_masked(mem, mlo, convert {}, Mode {}, sse4_2 {}); + return _mm256_zextsi128_si256(lo.data); + } + // confined to upper 128-bit half (2 lanes) → forward to SSE + else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 2) + { + constexpr auto mhi = mask.template upper_half(); + const auto hi = load_masked(mem + 2, mhi, convert {}, Mode {}, sse4_2 {}); + return _mm256_insertf128_si256(_mm256_setzero_si256(), 
hi, 1); + } + else + { + return _mm256_maskload_epi64(reinterpret_cast(mem), mask.as_batch()); + } + } + + template + XSIMD_INLINE batch load_masked(uint64_t const* mem, batch_bool_constant, convert, Mode, requires_arch) noexcept + { + const auto r = load_masked(reinterpret_cast(mem), batch_bool_constant {}, convert {}, Mode {}, avx2 {}); + return bitwise_cast(r); + } + + // store_masked + template + XSIMD_INLINE void store_masked(int32_t* mem, batch const& src, batch_bool_constant mask, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return; + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + src.store(mem, Mode {}); + } + // confined to lower 128-bit half (4 lanes) → forward to SSE + else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 4) + { + constexpr auto mlo = mask.template lower_half(); + const batch lo(_mm256_castsi256_si128(src)); + store_masked(mem, lo, mlo, Mode {}, sse4_2 {}); + } + // confined to upper 128-bit half (4 lanes) → forward to SSE + else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 4) + { + constexpr auto mhi = mask.template upper_half(); + const batch hi(_mm256_extractf128_si256(src, 1)); + store_masked(mem + 4, hi, mhi, Mode {}, sse4_2 {}); + } + else + { + _mm256_maskstore_epi32(reinterpret_cast(mem), mask.as_batch(), src); + } + } + + template + XSIMD_INLINE void store_masked(uint32_t* mem, batch const& src, batch_bool_constant mask, Mode, requires_arch) noexcept + { + const auto s32 = bitwise_cast(src); + store_masked(reinterpret_cast(mem), s32, mask, Mode {}, avx2 {}); + } + + template + XSIMD_INLINE void store_masked(int64_t* mem, batch const& src, batch_bool_constant mask, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return; + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + src.store(mem, Mode {}); + } + // confined to lower 128-bit half (2 lanes) → forward to SSE + else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 2) + { + constexpr auto mlo = mask.template lower_half(); + const batch lo(_mm256_castsi256_si128(src)); + store_masked(mem, lo, mlo, Mode {}, sse4_2 {}); + } + // confined to upper 128-bit half (2 lanes) → forward to SSE + else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 2) + { + constexpr auto mhi = mask.template upper_half(); + const batch hi(_mm256_extractf128_si256(src, 1)); + store_masked(mem + 2, hi, mhi, Mode {}, sse4_2 {}); + } + else + { + _mm256_maskstore_epi64(reinterpret_cast(mem), mask.as_batch(), src); + } + } + + template + XSIMD_INLINE void store_masked(uint64_t* mem, batch const& src, batch_bool_constant, Mode, requires_arch) noexcept + { + const auto s64 = bitwise_cast(src); + store_masked(reinterpret_cast(mem), s64, batch_bool_constant {}, Mode {}, avx2 {}); + } + // bitwise_and template ::value, void>::type> XSIMD_INLINE batch bitwise_and(batch const& self, batch const& other, requires_arch) noexcept diff --git a/include/xsimd/arch/xsimd_avx512dq.hpp b/include/xsimd/arch/xsimd_avx512dq.hpp index 063affa4c..ac328e316 100644 --- a/include/xsimd/arch/xsimd_avx512dq.hpp +++ b/include/xsimd/arch/xsimd_avx512dq.hpp @@ -21,6 +21,31 @@ namespace xsimd { using namespace types; + // load_masked + template + XSIMD_INLINE batch load_masked(int32_t const* mem, batch_bool_constant mask, convert, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 8) + { + constexpr auto mhi = mask.template upper_half(); + const auto hi = load_masked(mem + 8, mhi, convert {}, Mode {}, avx2 {}); + return _mm512_inserti32x8(_mm512_setzero_si512(), hi, 1); + } + return load_masked(mem, mask, convert {}, 
Mode {}, avx512f {}); + } + + template + XSIMD_INLINE batch load_masked(float const* mem, batch_bool_constant mask, convert, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 8) + { + constexpr auto mhi = mask.template upper_half(); + const auto hi = load_masked(mem + 8, mhi, convert {}, Mode {}, avx2 {}); + return _mm512_insertf32x8(_mm512_setzero_ps(), hi, 1); + } + return load_masked(mem, mask, convert {}, Mode {}, avx512f {}); + } + // bitwise_and template XSIMD_INLINE batch bitwise_and(batch const& self, batch const& other, requires_arch) noexcept diff --git a/include/xsimd/arch/xsimd_avx512f.hpp b/include/xsimd/arch/xsimd_avx512f.hpp index 4daf0a02f..adb1fbc3c 100644 --- a/include/xsimd/arch/xsimd_avx512f.hpp +++ b/include/xsimd/arch/xsimd_avx512f.hpp @@ -17,6 +17,7 @@ #include #include "../types/xsimd_avx512f_register.hpp" +#include "../types/xsimd_batch_constant.hpp" namespace xsimd { @@ -41,20 +42,35 @@ namespace xsimd namespace detail { - XSIMD_INLINE void split_avx512(__m512 val, __m256& low, __m256& high) noexcept + XSIMD_INLINE __m256 lower_half(__m512 self) noexcept { - low = _mm512_castps512_ps256(val); - high = _mm512_extractf32x8_ps(val, 1); + return _mm512_castps512_ps256(self); } - XSIMD_INLINE void split_avx512(__m512d val, __m256d& low, __m256d& high) noexcept + XSIMD_INLINE __m256d lower_half(__m512d self) noexcept { - low = _mm512_castpd512_pd256(val); - high = _mm512_extractf64x4_pd(val, 1); + return _mm512_castpd512_pd256(self); } - XSIMD_INLINE void split_avx512(__m512i val, __m256i& low, __m256i& high) noexcept + XSIMD_INLINE __m256i lower_half(__m512i self) noexcept { - low = _mm512_castsi512_si256(val); - high = _mm512_extracti64x4_epi64(val, 1); + return _mm512_castsi512_si256(self); + } + XSIMD_INLINE __m256 upper_half(__m512 self) noexcept + { + return _mm512_extractf32x8_ps(self, 1); + } + XSIMD_INLINE __m256d upper_half(__m512d self) noexcept + { + return _mm512_extractf64x4_pd(self, 1); + } + XSIMD_INLINE __m256i upper_half(__m512i self) noexcept + { + return _mm512_extracti64x4_epi64(self, 1); + } + template + XSIMD_INLINE void split_avx512(Full const& val, Half& low, Half& high) noexcept + { + low = lower_half(val); + high = upper_half(val); } XSIMD_INLINE __m512i merge_avx(__m256i low, __m256i high) noexcept { @@ -68,6 +84,7 @@ namespace xsimd { return _mm512_insertf64x4(_mm512_castpd256_pd512(low), high, 1); } + template __m512i fwd_to_avx(F f, __m512i self) { @@ -229,6 +246,91 @@ namespace xsimd } } + // load_masked + template + XSIMD_INLINE batch load_masked(int32_t const* mem, batch_bool_constant mask, convert, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return _mm512_setzero_si512(); + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + return load(mem, Mode {}); + } + else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 8) // Forward to AVX2 when confined to a 256-bit half (8 lanes) + { + constexpr auto mlo = mask.template lower_half(); + const auto lo = load_masked(mem, mlo, convert {}, Mode {}, avx2 {}); + return _mm512_zextsi256_si512(lo); + } + else + { + XSIMD_IF_CONSTEXPR(std::is_same::value) + { + return _mm512_maskz_load_epi32(mask.mask(), mem); + } + else + { + return _mm512_maskz_loadu_epi32(mask.mask(), mem); + } + } + } + + template + XSIMD_INLINE batch load_masked(uint32_t const* mem, batch_bool_constant, convert, Mode, requires_arch) noexcept + { + const auto r = load_masked(reinterpret_cast(mem), + batch_bool_constant {}, + convert {}, Mode {}, avx512f {}); + return bitwise_cast(r); + } + 
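            // Hedged sketch (not part of the patch): the branch selection used by the masked
            // kernels above, rewritten as a standalone constexpr model. `Mask` stands for any
            // hypothetical type exposing the batch_bool_constant interface this patch adds
            // (size, none(), all(), countl_zero(), countr_zero()); the enum names are illustrative.
            enum class masked_path
            {
                zero_register, // mask.none(): nothing to read, return a zeroed batch
                full_load,     // mask.all(): plain aligned/unaligned load
                lower_half,    // active lanes confined to the lower half: forward to the narrower arch
                upper_half,    // active lanes confined to the upper half: load from mem + size/2, re-insert high
                hardware_mask  // lanes straddle both halves: fall back to the native masked intrinsic
            };

            template <class Mask>
            constexpr masked_path select_masked_path() noexcept
            {
                return Mask::none() ? masked_path::zero_register
                    : Mask::all()                             ? masked_path::full_load
                    : Mask::countl_zero() >= Mask::size / 2   ? masked_path::lower_half
                    : Mask::countr_zero() >= Mask::size / 2   ? masked_path::upper_half
                                                              : masked_path::hardware_mask;
            }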
+ template + XSIMD_INLINE batch load_masked(int64_t const* mem, batch_bool_constant mask, convert, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return _mm512_setzero_si512(); + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + return load(mem, Mode {}); + } + else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 4) // Forward to AVX2 when confined to a 256-bit half (4 lanes) + { + constexpr auto mlo = mask.template lower_half(); + const auto lo = load_masked(mem, mlo, convert {}, Mode {}, avx2 {}); + return _mm512_zextsi256_si512(lo); + } + else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 4) + { + constexpr auto mhi = mask.template upper_half(); + const auto hi = load_masked(mem + 4, mhi, convert {}, Mode {}, avx2 {}); + return _mm512_inserti64x4(_mm512_setzero_si512(), hi, 1); + } + else + { + XSIMD_IF_CONSTEXPR(std::is_same::value) + { + return _mm512_maskz_load_epi64(mask.mask(), mem); + } + else + { + return _mm512_maskz_loadu_epi64(mask.mask(), mem); + } + } + } + + template + XSIMD_INLINE batch load_masked(uint64_t const* mem, batch_bool_constant, convert, Mode, requires_arch) noexcept + { + const auto r = load_masked(reinterpret_cast(mem), + batch_bool_constant {}, + convert {}, Mode {}, avx512f {}); + return bitwise_cast(r); + } + // abs template XSIMD_INLINE batch abs(batch const& self, requires_arch) noexcept @@ -866,11 +968,11 @@ namespace xsimd XSIMD_IF_CONSTEXPR(sizeof(T) == 4) { - return _mm512_mask_sub_epi32(self, mask.data, self, _mm512_set1_epi32(1)); + return _mm512_mask_sub_epi32(self, mask, self, _mm512_set1_epi32(1)); } else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) { - return _mm512_mask_sub_epi64(self, mask.data, self, _mm512_set1_epi64(1)); + return _mm512_mask_sub_epi64(self, mask, self, _mm512_set1_epi64(1)); } else { @@ -1391,6 +1493,238 @@ namespace xsimd return _mm512_loadu_pd(mem); } + // load_masked + template + XSIMD_INLINE batch load_masked(float const* mem, batch_bool_constant mask, convert, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return _mm512_setzero_ps(); + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + return load(mem, Mode {}); + } + else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 8) // Forward to AVX2 when confined to a 256-bit half + { + constexpr auto mlo = mask.template lower_half(); + const auto lo = load_masked(mem, mlo, convert {}, Mode {}, avx2 {}); + return _mm512_zextps256_ps512(lo); + } + else + { + // Last resort: 512-bit masked load + XSIMD_IF_CONSTEXPR(std::is_same::value) + { + return _mm512_maskz_load_ps(static_cast<__mmask16>(mask.mask()), mem); + } + else + { + return _mm512_maskz_loadu_ps(static_cast<__mmask16>(mask.mask()), mem); + } + } + } + + template + XSIMD_INLINE batch load_masked(double const* mem, batch_bool_constant mask, convert, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return _mm512_setzero_pd(); + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + return load(mem, Mode {}); + } + else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 4) // Forward to AVX2 when confined to a 256-bit half + { + constexpr auto mlo = mask.template lower_half(); + const auto lo = load_masked(mem, mlo, convert {}, Mode {}, avx2 {}); + return _mm512_zextpd256_pd512(lo); + } + else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 4) + { + constexpr auto mhi = mask.template upper_half(); + const auto hi = load_masked(mem + 4, mhi, convert {}, Mode {}, avx2 {}); + return _mm512_insertf64x4(_mm512_setzero_pd(), hi, 1); + } + else + { + XSIMD_IF_CONSTEXPR(std::is_same::value) + { + return 
_mm512_maskz_load_pd(mask.mask(), mem); + } + else + { + return _mm512_maskz_loadu_pd(mask.mask(), mem); + } + } + } + + // store_masked + template + XSIMD_INLINE void store_masked(float* mem, batch const& src, batch_bool_constant mask, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return; + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + src.store(mem, Mode {}); + } + else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 8) // Forward to AVX2 when confined to a 256-bit half + { + constexpr auto mlo = mask.template lower_half(); + const auto lo = detail::lower_half(src); + store_masked(mem, lo, mlo, Mode {}, avx2 {}); + } + else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 8) + { + constexpr auto mhi = mask.template upper_half(); + const auto hi = detail::upper_half(src); + store_masked(mem + 8, hi, mhi, Mode {}, avx2 {}); + } + else + { + XSIMD_IF_CONSTEXPR(std::is_same::value) + { + _mm512_mask_store_ps(mem, mask.mask(), src); + } + else + { + _mm512_mask_storeu_ps(mem, mask.mask(), src); + } + } + } + + template + XSIMD_INLINE void store_masked(double* mem, batch const& src, batch_bool_constant mask, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return; + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + src.store(mem, Mode {}); + } + else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 4) // Forward to AVX2 when confined to a 256-bit half + { + constexpr auto mlo = mask.template lower_half(); + const auto lo = detail::lower_half(src); + store_masked(mem, lo, mlo, Mode {}, avx2 {}); + } + else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 4) + { + constexpr auto mhi = mask.template upper_half(); + const auto hi = detail::upper_half(src); + store_masked(mem + 4, hi, mhi, Mode {}, avx2 {}); + } + else + { + XSIMD_IF_CONSTEXPR(std::is_same::value) + { + _mm512_mask_store_pd(mem, mask.mask(), src); + } + else + { + _mm512_mask_storeu_pd(mem, mask.mask(), src); + } + } + } + + // store_masked + template + XSIMD_INLINE void store_masked(int32_t* mem, batch const& src, batch_bool_constant mask, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return; + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + src.store(mem, Mode {}); + } + else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 8) // Forward to AVX2 when confined to a 256-bit half (8 lanes) + { + constexpr auto mlo = mask.template lower_half(); + const auto lo = detail::lower_half(src); + store_masked(mem, lo, mlo, Mode {}, avx2 {}); + } + else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 8) + { + constexpr auto mhi = mask.template upper_half(); + const auto hi = detail::upper_half(src); + store_masked(mem + 8, hi, mhi, Mode {}, avx2 {}); + } + else + { + XSIMD_IF_CONSTEXPR(std::is_same::value) + { + _mm512_mask_store_epi32(mem, mask.mask(), src); + } + else + { + _mm512_mask_storeu_epi32(mem, mask.mask(), src); + } + } + } + + template + XSIMD_INLINE void store_masked(uint32_t* mem, batch const& src, batch_bool_constant, Mode, requires_arch) noexcept + { + auto s32 = bitwise_cast(src); + store_masked(reinterpret_cast(mem), s32, + batch_bool_constant {}, + Mode {}, avx512f {}); + } + + template + XSIMD_INLINE void store_masked(int64_t* mem, batch const& src, batch_bool_constant mask, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return; + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + src.store(mem, Mode {}); + } + else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 4) // Forward to AVX2 when confined to a 256-bit half (4 lanes) + { + constexpr auto mlo = mask.template 
lower_half(); + const auto lo = detail::lower_half(src); + store_masked(mem, lo, mlo, Mode {}, avx2 {}); + } + else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 4) + { + constexpr auto mhi = mask.template upper_half(); + const auto hi = detail::upper_half(src); + store_masked(mem + 4, hi, mhi, Mode {}, avx2 {}); + } + else + { + XSIMD_IF_CONSTEXPR(std::is_same::value) + { + _mm512_mask_store_epi64(mem, mask.mask(), src); + } + else + { + _mm512_mask_storeu_epi64(mem, mask.mask(), src); + } + } + } + + template + XSIMD_INLINE void store_masked(uint64_t* mem, batch const& src, batch_bool_constant, Mode, requires_arch) noexcept + { + auto s64 = bitwise_cast(src); + store_masked(reinterpret_cast(mem), s64, + batch_bool_constant {}, + Mode {}, avx512f {}); + } + // lt template XSIMD_INLINE batch_bool lt(batch const& self, batch const& other, requires_arch) noexcept diff --git a/include/xsimd/arch/xsimd_common_fwd.hpp b/include/xsimd/arch/xsimd_common_fwd.hpp index ab664da22..690a3d4b8 100644 --- a/include/xsimd/arch/xsimd_common_fwd.hpp +++ b/include/xsimd/arch/xsimd_common_fwd.hpp @@ -13,12 +13,33 @@ #ifndef XSIMD_COMMON_FWD_HPP #define XSIMD_COMMON_FWD_HPP -#include "../types/xsimd_batch_constant.hpp" - +#include #include namespace xsimd { + // Minimal forward declarations used in this header + template + class batch; + template + class batch_bool; + template + struct batch_constant; + template + struct batch_bool_constant; + template + struct convert; + template + struct requires_arch; + struct aligned_mode; + struct unaligned_mode; + + namespace types + { + template + struct has_simd_register; + } + namespace kernel { // forward declaration @@ -52,6 +73,31 @@ namespace xsimd XSIMD_INLINE batch rotr(batch const& self, STy other, requires_arch) noexcept; template XSIMD_INLINE batch rotr(batch const& self, requires_arch) noexcept; + template + XSIMD_INLINE batch load(T const* mem, aligned_mode, requires_arch) noexcept; + template + XSIMD_INLINE batch load(T const* mem, unaligned_mode, requires_arch) noexcept; + template + XSIMD_INLINE batch load_masked(T_in const* mem, batch_bool_constant mask, convert, alignment, requires_arch) noexcept; + template + XSIMD_INLINE void store_masked(T_out* mem, batch const& src, batch_bool_constant mask, alignment, requires_arch) noexcept; + template + XSIMD_INLINE batch load_masked(int32_t const* mem, batch_bool_constant mask, convert, Mode, requires_arch) noexcept; + template + XSIMD_INLINE batch load_masked(uint32_t const* mem, batch_bool_constant mask, convert, Mode, requires_arch) noexcept; + template + XSIMD_INLINE typename std::enable_if::value, batch>::type load_masked(int64_t const*, batch_bool_constant, convert, Mode, requires_arch) noexcept; + template + XSIMD_INLINE typename std::enable_if::value, batch>::type load_masked(uint64_t const*, batch_bool_constant, convert, Mode, requires_arch) noexcept; + template + XSIMD_INLINE void store_masked(int32_t* mem, batch const& src, batch_bool_constant mask, Mode, requires_arch) noexcept; + template + XSIMD_INLINE void store_masked(uint32_t* mem, batch const& src, batch_bool_constant mask, Mode, requires_arch) noexcept; + template + XSIMD_INLINE typename std::enable_if::value, void>::type store_masked(int64_t*, batch const&, batch_bool_constant, Mode, requires_arch) noexcept; + template + XSIMD_INLINE typename std::enable_if::value, void>::type store_masked(uint64_t*, batch const&, batch_bool_constant, Mode, requires_arch) noexcept; + // Forward declarations for pack-level helpers namespace detail { diff --git 
a/include/xsimd/arch/xsimd_sse2.hpp b/include/xsimd/arch/xsimd_sse2.hpp index 22f3cdf99..bcef1020c 100644 --- a/include/xsimd/arch/xsimd_sse2.hpp +++ b/include/xsimd/arch/xsimd_sse2.hpp @@ -1042,7 +1042,6 @@ namespace xsimd { return _mm_loadu_pd(mem); } - // load batch_bool template @@ -1063,6 +1062,107 @@ namespace xsimd return { load_unaligned(mem, batch_bool {}, r).data }; } + // load_masked + template + XSIMD_INLINE batch load_masked(float const* mem, batch_bool_constant mask, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return _mm_setzero_ps(); + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + return load(mem, Mode {}); + } + else XSIMD_IF_CONSTEXPR(mask.countr_one() == 2) + { + return _mm_loadl_pi(_mm_setzero_ps(), reinterpret_cast<__m64 const*>(mem)); + } + else XSIMD_IF_CONSTEXPR(mask.countl_one() == 2) + { + return _mm_loadh_pi(_mm_setzero_ps(), reinterpret_cast<__m64 const*>(mem + 2)); + } + else + { + return load_masked(mem, mask, convert {}, Mode {}, common {}); + } + } + template + XSIMD_INLINE batch load_masked(double const* mem, batch_bool_constant mask, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return _mm_setzero_pd(); + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + return load(mem, Mode {}); + } + else XSIMD_IF_CONSTEXPR(mask.countr_one() == 1) + { + return _mm_move_sd(_mm_setzero_pd(), _mm_load_sd(mem)); + } + else XSIMD_IF_CONSTEXPR(mask.countl_one() == 1) + { + return _mm_loadh_pd(_mm_setzero_pd(), mem + 1); + } + else + { + return load_masked(mem, mask, convert {}, Mode {}, common {}); + } + } + + // store_masked + template + XSIMD_INLINE void store_masked(float* mem, batch const& src, batch_bool_constant mask, Mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return; + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + src.store(mem, Mode {}); + } + else XSIMD_IF_CONSTEXPR(mask.countr_one() == 2) + { + _mm_storel_pi(reinterpret_cast<__m64*>(mem), src); + } + else XSIMD_IF_CONSTEXPR(mask.countl_one() == 2) + { + _mm_storeh_pi(reinterpret_cast<__m64*>(mem + 2), src); + } + else + { + store_masked(mem, src, mask, Mode {}, common {}); + } + } + + template + XSIMD_INLINE void store_masked(double* mem, batch const& src, batch_bool_constant mask, Mode mode, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return; + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + src.store(mem, mode); + } + else XSIMD_IF_CONSTEXPR(mask.countr_one() == 1) + { + _mm_store_sd(mem, src); + } + else XSIMD_IF_CONSTEXPR(mask.countl_one() == 1) + { + _mm_storeh_pd(mem + 1, src); + } + else + { + store_masked(mem, src, mask, Mode {}, common {}); + } + } + // load_complex namespace detail { @@ -2092,6 +2192,37 @@ namespace xsimd { return _mm_unpacklo_pd(self, other); } + + // store_masked + template + XSIMD_INLINE void store_masked(float* mem, + batch const& src, + batch_bool_constant mask, + aligned_mode, + requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(mask.none()) + { + return; + } + else XSIMD_IF_CONSTEXPR(mask.all()) + { + _mm_store_ps(mem, src); + } + else XSIMD_IF_CONSTEXPR(mask.countr_one() == 2) + { + _mm_storel_pi(reinterpret_cast<__m64*>(mem), src); + } + else XSIMD_IF_CONSTEXPR(mask.countl_one() == 2) + { + _mm_storeh_pi(reinterpret_cast<__m64*>(mem + 2), src); + } + else + { + store_masked(mem, src, mask, requires_arch {}); + } + } + } } diff --git a/include/xsimd/types/xsimd_api.hpp b/include/xsimd/types/xsimd_api.hpp index afaf2cdf1..de29f59da 100644 --- a/include/xsimd/types/xsimd_api.hpp 
+++ b/include/xsimd/types/xsimd_api.hpp @@ -1416,6 +1416,45 @@ namespace xsimd return load_as(ptr, unaligned_mode {}); } + /** + * @ingroup batch_data_transfer + * + * Creates a batch from the buffer \c ptr using a mask. Elements + * corresponding to \c false in the mask are not accessed in memory and are + * zero-initialized in the resulting batch. + * @param ptr the memory buffer to read + * @param mask selection mask for the elements to load + * @return a new batch instance + */ + template + XSIMD_INLINE batch load(From const* ptr, + batch_bool_constant const& mask, + aligned_mode = {}) noexcept + { + detail::static_check_supported_config(); + return batch::load(ptr, mask, aligned_mode {}); + } + + /** + * @ingroup batch_data_transfer + * + * Creates a batch from the buffer \c ptr using a mask. Elements + * corresponding to \c false in the mask are not accessed in memory and are + * zero-initialized in the resulting batch. + * @param ptr the memory buffer to read. The buffer does not need to be + * aligned. + * @param mask selection mask for the elements to load + * @return a new batch instance + */ + template + XSIMD_INLINE batch load(From const* ptr, + batch_bool_constant const& mask, + unaligned_mode) noexcept + { + detail::static_check_supported_config(); + return batch::load(ptr, mask, unaligned_mode {}); + } + /** * @ingroup batch_data_transfer * @@ -2413,6 +2452,46 @@ namespace xsimd store_as(mem, val, unaligned_mode {}); } + /** + * @ingroup batch_data_transfer + * + * Copy selected elements of batch \c val to the buffer \c mem using + * a mask. Elements corresponding to \c false in the mask are not + * written to memory. + * @param mem the memory buffer to write to + * @param val the batch to copy from + * @param mask selection mask for the elements to store + */ + template + XSIMD_INLINE void store(T* mem, + batch const& val, + batch_bool_constant const& mask, + aligned_mode = {}) noexcept + { + detail::static_check_supported_config(); + val.store(mem, mask, aligned_mode {}); + } + + /** + * @ingroup batch_data_transfer + * + * Copy selected elements of batch \c val to the buffer \c mem using a mask. + * Elements corresponding to \c false in the mask are not written to memory. + * @param mem the memory buffer to write to. The buffer does not need to be + * aligned. 
+ * @param val the batch to copy from + * @param mask selection mask for the elements to store + */ + template + XSIMD_INLINE void store(T* mem, + batch const& val, + batch_bool_constant const& mask, + unaligned_mode) noexcept + { + detail::static_check_supported_config(); + val.store(mem, mask, unaligned_mode {}); + } + /** * @ingroup batch_data_transfer * diff --git a/include/xsimd/types/xsimd_batch.hpp b/include/xsimd/types/xsimd_batch.hpp index b3b704666..b3d4fa56e 100644 --- a/include/xsimd/types/xsimd_batch.hpp +++ b/include/xsimd/types/xsimd_batch.hpp @@ -21,6 +21,8 @@ namespace xsimd { + template + struct batch_bool_constant; template class batch; @@ -143,6 +145,12 @@ namespace xsimd template XSIMD_INLINE void store(U* mem, unaligned_mode) const noexcept; + // Compile-time mask overloads + template + XSIMD_INLINE void store(U* mem, batch_bool_constant mask, aligned_mode) const noexcept; + template + XSIMD_INLINE void store(U* mem, batch_bool_constant mask, unaligned_mode) const noexcept; + template XSIMD_NO_DISCARD static XSIMD_INLINE batch load_aligned(U const* mem) noexcept; template @@ -151,6 +159,11 @@ namespace xsimd XSIMD_NO_DISCARD static XSIMD_INLINE batch load(U const* mem, aligned_mode) noexcept; template XSIMD_NO_DISCARD static XSIMD_INLINE batch load(U const* mem, unaligned_mode) noexcept; + // Compile-time mask overloads + template + XSIMD_NO_DISCARD static XSIMD_INLINE batch load(U const* mem, batch_bool_constant mask, aligned_mode = {}) noexcept; + template + XSIMD_NO_DISCARD static XSIMD_INLINE batch load(U const* mem, batch_bool_constant mask, unaligned_mode) noexcept; template XSIMD_NO_DISCARD static XSIMD_INLINE batch gather(U const* src, batch const& index) noexcept; @@ -403,10 +416,20 @@ namespace xsimd XSIMD_NO_DISCARD static XSIMD_INLINE batch load(U const* mem, aligned_mode) noexcept; template XSIMD_NO_DISCARD static XSIMD_INLINE batch load(U const* mem, unaligned_mode) noexcept; + // Compile-time mask overloads + template + XSIMD_NO_DISCARD static XSIMD_INLINE batch load(U const* mem, batch_bool_constant mask, aligned_mode = {}) noexcept; + template + XSIMD_NO_DISCARD static XSIMD_INLINE batch load(U const* mem, batch_bool_constant mask, unaligned_mode) noexcept; template XSIMD_INLINE void store(U* mem, aligned_mode) const noexcept; template XSIMD_INLINE void store(U* mem, unaligned_mode) const noexcept; + // Compile-time mask overloads + template + XSIMD_INLINE void store(U* mem, batch_bool_constant mask, aligned_mode) const noexcept; + template + XSIMD_INLINE void store(U* mem, batch_bool_constant mask, unaligned_mode) const noexcept; XSIMD_INLINE real_batch real() const noexcept; XSIMD_INLINE real_batch imag() const noexcept; @@ -616,6 +639,8 @@ namespace xsimd return store_unaligned(mem); } + // masked store free functions are provided in xsimd_api.hpp + /** * Loading from aligned memory. May involve a conversion if \c U is different * from \c T. 
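    // Hedged usage sketch for the masked load/store overloads documented above; it is not
    // part of the patch. It assumes an SSE2 target (4 float lanes) and the
    // batch_bool_constant<T, A, Values...> spelling used elsewhere in this patch; the
    // function and variable names are illustrative only.
    #include <xsimd/xsimd.hpp>

    void masked_roundtrip(float const* in, float* out)
    {
        using arch = xsimd::sse2;
        using batch = xsimd::batch<float, arch>;
        // Compile-time mask: keep lanes 0 and 2, skip lanes 1 and 3.
        constexpr xsimd::batch_bool_constant<float, arch, true, false, true, false> mask {};

        // Masked-off lanes are never read from memory and come back zero-initialized.
        auto v = batch::load(in, mask, xsimd::unaligned_mode {});
        // Masked-off lanes of `out` are left untouched in memory.
        v.store(out, mask, xsimd::unaligned_mode {});
    }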
@@ -664,6 +689,46 @@ namespace xsimd return load_unaligned(mem); } + template + template + XSIMD_INLINE batch batch::load(U const* mem, + batch_bool_constant mask, + aligned_mode) noexcept + { + detail::static_check_supported_config(); + return kernel::load_masked(mem, mask, kernel::convert {}, aligned_mode {}, A {}); + } + + template + template + XSIMD_INLINE batch batch::load(U const* mem, + batch_bool_constant mask, + unaligned_mode) noexcept + { + detail::static_check_supported_config(); + return kernel::load_masked(mem, mask, kernel::convert {}, unaligned_mode {}, A {}); + } + + template + template + XSIMD_INLINE void batch::store(U* mem, + batch_bool_constant mask, + aligned_mode) const noexcept + { + detail::static_check_supported_config(); + kernel::store_masked(mem, *this, mask, aligned_mode {}, A {}); + } + + template + template + XSIMD_INLINE void batch::store(U* mem, + batch_bool_constant mask, + unaligned_mode) const noexcept + { + detail::static_check_supported_config(); + kernel::store_masked(mem, *this, mask, unaligned_mode {}, A {}); + } + /** * Create a new batch gathering elements starting at address \c src and * offset by each element in \c index. @@ -1217,6 +1282,21 @@ namespace xsimd return kernel::store_complex_unaligned(dst, *this, A {}); } + // Compile-time mask overloads for complex store + template + template + XSIMD_INLINE void batch, A>::store(U* mem, batch_bool_constant mask, aligned_mode) const noexcept + { + kernel::store_masked(mem, *this, mask, aligned_mode {}, A {}); + } + + template + template + XSIMD_INLINE void batch, A>::store(U* mem, batch_bool_constant mask, unaligned_mode) const noexcept + { + kernel::store_masked(mem, *this, mask, unaligned_mode {}, A {}); + } + template XSIMD_INLINE void batch, A>::store_aligned(T* real_dst, T* imag_dst) const noexcept { @@ -1245,6 +1325,25 @@ namespace xsimd return load_unaligned(mem); } + // Compile-time mask overloads for complex load + template + template + XSIMD_INLINE batch, A> batch, A>::load(U const* mem, + batch_bool_constant mask, + aligned_mode) noexcept + { + return kernel::load_masked(mem, mask, kernel::convert {}, aligned_mode {}, A {}); + } + + template + template + XSIMD_INLINE batch, A> batch, A>::load(U const* mem, + batch_bool_constant mask, + unaligned_mode) noexcept + { + return kernel::load_masked(mem, mask, kernel::convert {}, unaligned_mode {}, A {}); + } + template template XSIMD_INLINE void batch, A>::store(U* mem, aligned_mode) const noexcept diff --git a/include/xsimd/types/xsimd_batch_constant.hpp b/include/xsimd/types/xsimd_batch_constant.hpp index 35cc5d3af..20423715b 100644 --- a/include/xsimd/types/xsimd_batch_constant.hpp +++ b/include/xsimd/types/xsimd_batch_constant.hpp @@ -12,6 +12,8 @@ #ifndef XSIMD_BATCH_CONSTANT_HPP #define XSIMD_BATCH_CONSTANT_HPP +#include + #include "./xsimd_batch.hpp" #include "./xsimd_utils.hpp" @@ -39,6 +41,10 @@ namespace xsimd */ constexpr batch_type as_batch_bool() const noexcept { return { Values... }; } + /** + * @brief Generate a batch of @p integers from this @p batch_bool_constant + */ + constexpr batch, A> as_batch() const noexcept { return { -as_integer_t(Values)... }; } // the minus is important! 
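        // Hedged illustration of why the minus sign above matters (not part of the patch):
        // AVX mask intrinsics such as _mm256_maskload_ps select lanes based on the most
        // significant bit of each element, so a `true` lane must become an all-ones integer
        // (-1), not 1.
        //
        //   Values...        :  true        false       true        false
        //   -int(Values)...  :  -1          0           -1          0
        //   lane bit pattern :  0xFFFFFFFF  0x00000000  0xFFFFFFFF  0x00000000
        //
        // A scalar model of the same conversion:
        #include <cstdint>
        constexpr std::int32_t lane_mask(bool v) noexcept { return -static_cast<std::int32_t>(v); } // true -> all bits set
        static_assert(lane_mask(true) == -1 && lane_mask(false) == 0, "all-ones vs all-zeros lanes");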
/** * @brief Generate a batch of @p batch_type from this @p batch_bool_constant */ @@ -54,6 +60,41 @@ namespace xsimd return mask_helper(0, static_cast(Values)...); } + static constexpr bool none() noexcept + { + return truncated_mask() == 0u; + } + + static constexpr bool any() noexcept + { + return !none(); + } + + static constexpr bool all() noexcept + { + return truncated_mask() == low_mask(size); + } + + static constexpr std::size_t countr_zero() noexcept + { + return countr_zero_impl(truncated_mask(), size); + } + + static constexpr std::size_t countl_zero() noexcept + { + return countl_zero_impl(truncated_mask(), size); + } + + static constexpr std::size_t countr_one() noexcept + { + return countr_one_impl(truncated_mask(), size); + } + + static constexpr std::size_t countl_one() noexcept + { + return countl_one_impl(truncated_mask(), size); + } + private: static constexpr int mask_helper(int acc) noexcept { return acc; } @@ -91,6 +132,13 @@ namespace xsimd return apply...>, std::tuple...>>(detail::make_index_sequence()); } + template + static constexpr auto splice_impl(detail::index_sequence) noexcept + -> batch_bool_constant...>>::type::value...> + { + return {}; + } + public: #define MAKE_BINARY_OP(OP, NAME) \ template \ @@ -117,8 +165,84 @@ namespace xsimd { return {}; } - }; + // splice: build a sub-constant [Begin, End) targeting another arch A2 + // Useful to forward compile-time masks to narrower lanes (e.g., AVX512 -> AVX halves). + template = Begin ? (End - Begin) : 0)> + static constexpr auto splice() noexcept + -> decltype(splice_impl(detail::make_index_sequence())) + { + static_assert(Begin <= End, "splice: Begin must be <= End"); + static_assert(End <= size, "splice: End must be <= size"); + static_assert(N == batch_bool::size, "splice: target arch size must match submask length"); + return splice_impl(detail::make_index_sequence()); + } + + // Convenience helpers for half splits (require even size and appropriate target arch) + template + static constexpr auto lower_half() noexcept + -> decltype(splice()) + { + static_assert(size % 2 == 0, "lower_half requires even size"); + return splice(); + } + + template + static constexpr auto upper_half() noexcept + -> decltype(splice()) + { + static_assert(size % 2 == 0, "upper_half requires even size"); + return splice(); + } + + private: + // Build a 64-bit mask from Values... (LSB = index 0) + template + struct build_bits_helper; + + template + struct build_bits_helper + { + static constexpr uint64_t value = 0u; + }; + + template + struct build_bits_helper + { + static constexpr uint64_t value = (Current ? (uint64_t(1) << I) : 0u) + | build_bits_helper::value; + }; + + static constexpr uint64_t bits() noexcept + { + return build_bits_helper<0, Values...>::value; + } + static constexpr uint64_t low_mask(std::size_t k) noexcept + { + return (k >= 64u) ? ~uint64_t(0) : ((uint64_t(1) << k) - 1u); + } + static constexpr uint64_t truncated_mask() noexcept + { + return bits() & low_mask(size); + } + static constexpr std::size_t countr_zero_impl(uint64_t v, std::size_t n) noexcept + { + return (n == 0 || (v & 1u) != 0u) ? 0u : (1u + countr_zero_impl(v >> 1, n - 1)); + } + static constexpr std::size_t countr_one_impl(uint64_t v, std::size_t n) noexcept + { + return (n == 0 || (v & 1u) == 0u) ? 0u : (1u + countr_one_impl(v >> 1, n - 1)); + } + static constexpr std::size_t countl_zero_impl(uint64_t v, std::size_t n) noexcept + { + return (n == 0) ? 0u : ((((v >> (n - 1)) & 1u) != 0u) ? 
0u : (1u + countl_zero_impl(v, n - 1))); + } + static constexpr std::size_t countl_one_impl(uint64_t v, std::size_t n) noexcept + { + return (n == 0) ? 0u : ((((v >> (n - 1)) & 1u) == 0u) ? 0u : (1u + countl_one_impl(v, n - 1))); + } + }; /** * @brief batch of integral constants * diff --git a/test/test_batch_bool.cpp b/test/test_batch_bool.cpp index b262e65a6..f9321daf7 100644 --- a/test/test_batch_bool.cpp +++ b/test/test_batch_bool.cpp @@ -12,7 +12,9 @@ #include "xsimd/xsimd.hpp" #ifndef XSIMD_NO_SUPPORTED_ARCHITECTURE +#include #include +#include #include #include "test_utils.hpp" @@ -20,6 +22,83 @@ namespace xsimd { + namespace test_detail + { + template + struct ct_mask_arch + { + static constexpr bool supported() noexcept { return true; } + static constexpr bool available() noexcept { return true; } + static constexpr std::size_t alignment() noexcept { return 0; } + static constexpr bool requires_alignment() noexcept { return false; } + static constexpr char const* name() noexcept { return "ct_mask_arch"; } + }; + + template + struct ct_mask_register + { + std::array data {}; + }; + + struct mask_all_false + { + static constexpr bool get(std::size_t, std::size_t) { return false; } + }; + + struct mask_all_true + { + static constexpr bool get(std::size_t, std::size_t) { return true; } + }; + + struct mask_prefix1 + { + static constexpr bool get(std::size_t i, std::size_t) { return i < 1; } + }; + + struct mask_suffix1 + { + static constexpr bool get(std::size_t i, std::size_t n) { return i >= (n - 1); } + }; + + struct mask_ends + { + static constexpr bool get(std::size_t i, std::size_t n) + { + return (i < 1) || (i >= (n - 1)); + } + }; + + struct mask_interleaved + { + static constexpr bool get(std::size_t i, std::size_t) { return (i % 2) == 0; } + }; + + template + struct alternating_numeric + { + static constexpr T get(std::size_t i, std::size_t) + { + return (i % 2) ? T(2) : T(1); + } + }; + } + + namespace types + { + template + struct simd_register> + { + using register_type = test_detail::ct_mask_register; + register_type data; + constexpr operator register_type() const noexcept { return data; } + }; + + template + struct has_simd_register> : std::true_type + { + }; + } + int popcount(int v) { // from https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetKernighan @@ -154,16 +233,188 @@ namespace xsimd } -template +template struct batch_bool_test { - using batch_type = B; - using value_type = typename B::value_type; - static constexpr size_t size = B::size; - using batch_bool_type = typename B::batch_bool_type; + using batch_type = T; + using value_type = typename T::value_type; + static constexpr size_t size = T::size; + using batch_bool_type = typename T::batch_bool_type; using array_type = std::array; using bool_array_type = std::array; + // Compile-time check helpers for batch_bool_constant masks + template + struct xsimd_ct_mask_checker; + + // Small masks: safe to compare numeric masks at compile time + template + struct xsimd_ct_mask_checker::type> + { + static constexpr std::size_t sum_indices(uint64_t bits, std::size_t index, std::size_t remaining) + { + return remaining == 0 + ? 0u + : ((bits & 1u ? index : 0u) + sum_indices(bits >> 1, index + 1, remaining - 1)); + } + + static constexpr uint32_t low_mask_bits(std::size_t width) + { + return width == 0 ? 
0u : (static_cast(1u << width) - 1u); + } + + template + struct splice_checker + { + static void run() + { + } + }; + + template + struct splice_checker + { + static void run() + { + constexpr std::size_t begin = 1; + constexpr std::size_t end = (Mask::size > 3 ? 3 : Mask::size); + constexpr std::size_t length = (end > begin) ? (end - begin) : 0; + using slice_arch = xsimd::test_detail::ct_mask_arch; + constexpr auto slice = Mask::template splice(); + constexpr uint32_t src_mask = static_cast(Mask::mask()); + constexpr uint32_t expected = (src_mask >> begin) & low_mask_bits(length); + static_assert(static_cast(slice.mask()) == expected, "splice mask expected"); + constexpr uint32_t slice_bits = static_cast(slice.mask()); + constexpr uint32_t shifted_source = src_mask >> begin; + static_assert((length == 0) || ((slice_bits & 1u) == (shifted_source & 1u)), "slice first bit matches"); + static_assert((length <= 1) || (((slice_bits >> (length - 1)) & 1u) == ((shifted_source >> (length - 1)) & 1u)), + "slice last bit matches"); + } + }; + + template + struct half_checker + { + static void run() + { + } + }; + + template + struct half_checker + { + static void run() + { + constexpr std::size_t total = Mask::size; + constexpr std::size_t mid = total / 2; + using lower_arch = xsimd::test_detail::ct_mask_arch; + using upper_arch = xsimd::test_detail::ct_mask_arch; + constexpr auto lower = Mask::template lower_half(); + constexpr auto upper = Mask::template upper_half(); + constexpr uint32_t source_mask = static_cast(Mask::mask()); + static_assert(static_cast(lower.mask()) == (source_mask & low_mask_bits(mid)), + "lower_half mask matches"); + static_assert(static_cast(upper.mask()) == ((source_mask >> mid) & low_mask_bits(total - mid)), + "upper_half mask matches"); + constexpr auto lower_splice = Mask::template splice(); + constexpr auto upper_splice = Mask::template splice(); + static_assert(lower.mask() == lower_splice.mask(), "lower_half equals splice"); + static_assert(upper.mask() == upper_splice.mask(), "upper_half equals splice"); + constexpr uint32_t lower_bits = static_cast(lower.mask()); + constexpr uint32_t upper_bits = static_cast(upper.mask()); + constexpr std::size_t upper_size = decltype(upper)::size; + static_assert((mid == 0) || ((lower_bits & 1u) == (source_mask & 1u)), "lower first element"); + static_assert((mid <= 1) || (((lower_bits >> (mid - 1)) & 1u) == ((source_mask >> (mid - 1)) & 1u)), + "lower last element"); + static_assert((upper_size == 0) || ((upper_bits & 1u) == ((source_mask >> mid) & 1u)), + "upper first element"); + static_assert((upper_size <= 1) || (((upper_bits >> (upper_size - 1)) & 1u) == ((source_mask >> (total - 1)) & 1u)), + "upper last element"); + } + }; + + static void run() + { + using value_type = typename B::value_type; + using arch_type = typename B::arch_type; + constexpr auto m_zero = xsimd::make_batch_bool_constant(); + constexpr auto m_one = xsimd::make_batch_bool_constant(); + constexpr auto m_prefix = xsimd::make_batch_bool_constant(); + constexpr auto m_suffix = xsimd::make_batch_bool_constant(); + constexpr auto m_ends = xsimd::make_batch_bool_constant(); + constexpr auto m_interleaved = xsimd::make_batch_bool_constant(); + + static_assert((m_zero | m_one).mask() == m_one.mask(), "0|1 == 1"); + static_assert((m_zero & m_one).mask() == m_zero.mask(), "0&1 == 0"); + static_assert((m_zero ^ m_zero).mask() == m_zero.mask(), "0^0 == 0"); + static_assert((m_one ^ m_one).mask() == m_zero.mask(), "1^1 == 0"); + + static_assert((!m_zero).mask() == 
m_one.mask(), "!0 == 1"); + static_assert((~m_zero).mask() == m_one.mask(), "~0 == 1"); + static_assert((!m_one).mask() == m_zero.mask(), "!1 == 0"); + static_assert((~m_one).mask() == m_zero.mask(), "~1 == 0"); + + static_assert(((m_prefix && m_suffix).mask()) == (m_prefix & m_suffix).mask(), "&& consistent"); + static_assert(((m_prefix || m_suffix).mask()) == (m_prefix | m_suffix).mask(), "|| consistent"); + + static_assert((m_prefix | m_suffix).mask() == m_ends.mask(), "prefix|suffix == ends"); + static_assert(B::size == 1 || (m_prefix & m_suffix).mask() == m_zero.mask(), "prefix&suffix == 0 when size>1"); + + static_assert(m_zero.none(), "zero mask none"); + static_assert(!m_zero.any(), "zero mask any"); + static_assert(!m_zero.all(), "zero mask all"); + static_assert(m_zero.countr_zero() == B::size, "zero mask trailing zeros"); + static_assert(m_zero.countl_zero() == B::size, "zero mask leading zeros"); + + static_assert(m_one.all(), "all mask all"); + static_assert(m_one.any(), "all mask any"); + static_assert(!m_one.none(), "all mask none"); + static_assert(m_one.countr_zero() == 0, "all mask trailing zeros"); + static_assert(m_one.countl_zero() == 0, "all mask leading zeros"); + + constexpr auto prefix_bits = static_cast(m_prefix.mask()); + constexpr auto suffix_bits = static_cast(m_suffix.mask()); + constexpr auto ends_bits_mask = static_cast(m_ends.mask()); + + static_assert((B::size == 0) || ((prefix_bits & 1u) != 0u), "prefix first element set"); + static_assert((B::size <= 1) || ((prefix_bits & (1u << 1)) == 0u), "prefix second element cleared"); + + static_assert((B::size == 0) || (((suffix_bits >> (B::size - 1)) & 1u) != 0u), "suffix last element set"); + static_assert((B::size <= 1) || ((suffix_bits & 1u) == 0u), "suffix first element cleared"); + + static_assert((B::size == 0) || ((ends_bits_mask & 1u) != 0u), "ends first element set"); + static_assert((B::size == 0) || (((ends_bits_mask >> (B::size - 1)) & 1u) != 0u), "ends last element set"); + static_assert((B::size <= 2) || (((ends_bits_mask >> 1) & 1u) == 0u), "ends interior element cleared"); + + static_assert(std::is_same::value, + "as_batch_bool type"); + static_assert(std::is_same(m_prefix)), typename B::batch_bool_type>::value, + "conversion operator type"); + + // splice API is validated indirectly via arch-specific masked implementations. + + constexpr std::size_t prefix_zero = m_prefix.countr_zero(); + constexpr std::size_t prefix_one = m_prefix.countr_one(); + static_assert(prefix_zero == 0, "prefix mask zero leading zeros from LSB"); + static_assert((B::size == 0 ? prefix_one == 0 : prefix_one == 1), "prefix mask trailing ones count"); + + constexpr std::size_t suffix_zero = m_suffix.countl_zero(); + constexpr std::size_t suffix_one = m_suffix.countl_one(); + static_assert(suffix_zero == 0, "suffix mask leading zeros count"); + static_assert((B::size == 0 ? 
suffix_one == 0 : suffix_one == 1), "suffix mask trailing ones count"); + + splice_checker 1)>::run(); + half_checker 0 && (B::size % 2 == 0))>::run(); + } + }; + + // Large masks: avoid calling mask() in constant expressions + template + struct xsimd_ct_mask_checker 31)>::type> + { + static void run() { } + }; + array_type lhs; array_type rhs; bool_array_type all_true; @@ -521,6 +772,11 @@ struct batch_bool_test } } + void test_mask_compile_time() const + { + xsimd_ct_mask_checker::run(); + } + private: batch_type batch_lhs() const { @@ -554,5 +810,7 @@ TEST_CASE_TEMPLATE("[xsimd batch bool]", B, BATCH_TYPES) SUBCASE("count") { Test.test_count(); } SUBCASE("eq neq") { Test.test_comparison(); } + + SUBCASE("mask utils (compile-time)") { Test.test_mask_compile_time(); } } #endif diff --git a/test/test_load_store.cpp b/test/test_load_store.cpp index 449c41e85..9fa7dbff8 100644 --- a/test/test_load_store.cpp +++ b/test/test_load_store.cpp @@ -12,7 +12,10 @@ #include "xsimd/xsimd.hpp" #ifndef XSIMD_NO_SUPPORTED_ARCHITECTURE +#include +#include #include +#include #include "test_utils.hpp" @@ -22,6 +25,7 @@ struct load_store_test using batch_type = B; using value_type = typename B::value_type; using index_type = typename xsimd::as_integer_t; + using batch_bool_type = typename batch_type::batch_bool_type; template using allocator = xsimd::default_allocator; static constexpr size_t size = B::size; @@ -43,6 +47,65 @@ struct load_store_test using double_vector_type = std::vector>; #endif + struct mask_none + { + static constexpr bool get(std::size_t, std::size_t) noexcept { return false; } + }; + + struct mask_first + { + static constexpr bool get(std::size_t index, std::size_t) noexcept { return index == 0; } + }; + + struct mask_first_half + { + static constexpr bool get(std::size_t index, std::size_t size) noexcept { return index < (size / 2); } + }; + + struct mask_last_half + { + static constexpr bool get(std::size_t index, std::size_t size) noexcept { return index >= (size / 2); } + }; + + struct mask_first_n + { + static constexpr bool get(std::size_t index, std::size_t size) noexcept + { + return index < (size > 2 ? size / 3 : std::size_t(1)); + } + }; + + struct mask_last_n + { + static constexpr bool get(std::size_t index, std::size_t size) noexcept + { + return index >= size - (size > 2 ? size / 3 : std::size_t(1)); + } + }; + + struct mask_even + { + static constexpr bool get(std::size_t index, std::size_t) noexcept { return (index % 2) == 0; } + }; + + struct mask_odd + { + static constexpr bool get(std::size_t index, std::size_t) noexcept { return (index % 2) == 1; } + }; + + struct mask_pseudo_random + { + static constexpr bool get(std::size_t index, std::size_t size) noexcept + { + return ((index * 7) + 3) % size < (size > 2 ? 
diff --git a/test/test_load_store.cpp b/test/test_load_store.cpp
index 449c41e85..9fa7dbff8 100644
--- a/test/test_load_store.cpp
+++ b/test/test_load_store.cpp
@@ -12,7 +12,10 @@
 #include "xsimd/xsimd.hpp"
 #ifndef XSIMD_NO_SUPPORTED_ARCHITECTURE
+#include
+#include
 #include
+#include
 
 #include "test_utils.hpp"
 
@@ -22,6 +25,7 @@ struct load_store_test
     using batch_type = B;
     using value_type = typename B::value_type;
     using index_type = typename xsimd::as_integer_t;
+    using batch_bool_type = typename batch_type::batch_bool_type;
     template
     using allocator = xsimd::default_allocator;
     static constexpr size_t size = B::size;
@@ -43,6 +47,65 @@ struct load_store_test
     using double_vector_type = std::vector>;
 #endif
 
+    struct mask_none
+    {
+        static constexpr bool get(std::size_t, std::size_t) noexcept { return false; }
+    };
+
+    struct mask_first
+    {
+        static constexpr bool get(std::size_t index, std::size_t) noexcept { return index == 0; }
+    };
+
+    struct mask_first_half
+    {
+        static constexpr bool get(std::size_t index, std::size_t size) noexcept { return index < (size / 2); }
+    };
+
+    struct mask_last_half
+    {
+        static constexpr bool get(std::size_t index, std::size_t size) noexcept { return index >= (size / 2); }
+    };
+
+    struct mask_first_n
+    {
+        static constexpr bool get(std::size_t index, std::size_t size) noexcept
+        {
+            return index < (size > 2 ? size / 3 : std::size_t(1));
+        }
+    };
+
+    struct mask_last_n
+    {
+        static constexpr bool get(std::size_t index, std::size_t size) noexcept
+        {
+            return index >= size - (size > 2 ? size / 3 : std::size_t(1));
+        }
+    };
+
+    struct mask_even
+    {
+        static constexpr bool get(std::size_t index, std::size_t) noexcept { return (index % 2) == 0; }
+    };
+
+    struct mask_odd
+    {
+        static constexpr bool get(std::size_t index, std::size_t) noexcept { return (index % 2) == 1; }
+    };
+
+    struct mask_pseudo_random
+    {
+        static constexpr bool get(std::size_t index, std::size_t size) noexcept
+        {
+            return ((index * 7) + 3) % size < (size > 2 ? size / 3 : std::size_t(1));
+        }
+    };
+
+    struct mask_all
+    {
+        static constexpr bool get(std::size_t, std::size_t) noexcept { return true; }
+    };
+
     int8_vector_type i8_vec;
     uint8_vector_type ui8_vec;
     int16_vector_type i16_vec;
@@ -162,6 +225,73 @@ struct load_store_test
 #endif
     }
 
+    void test_masked()
+    {
+        using arch = typename B::arch_type;
+        using test_batch_type = xsimd::batch;
+        constexpr std::size_t test_size = test_batch_type::size;
+        using int_allocator_type = xsimd::default_allocator;
+
+        std::vector source(test_size);
+        for (std::size_t i = 0; i < test_size; ++i)
+        {
+            source[i] = static_cast(i * 17 - 9);
+        }
+
+        struct cross_type_mask
+        {
+            static constexpr bool get(std::size_t index, std::size_t size) noexcept
+            {
+                return ((index & std::size_t(1)) != 0) || (size == std::size_t(1)) || ((index == size - std::size_t(1)) && ((size % std::size_t(2)) == 0));
+            }
+        };
+        auto mask = xsimd::make_batch_bool_constant();
+
+        std::array expected_load;
+        expected_load.fill(0.f);
+        for (std::size_t i = 0; i < test_size; ++i)
+        {
+            if (cross_type_mask::get(i, test_size))
+            {
+                expected_load[i] = static_cast(source[i]);
+            }
+        }
+
+        auto loaded_aligned = test_batch_type::load(source.data(), mask, xsimd::aligned_mode());
+        INFO("cross-type masked load aligned");
+        CHECK_BATCH_EQ(loaded_aligned, expected_load);
+
+        auto loaded_unaligned = test_batch_type::load(source.data(), mask, xsimd::unaligned_mode());
+        INFO("cross-type masked load unaligned");
+        CHECK_BATCH_EQ(loaded_unaligned, expected_load);
+
+        std::array values;
+        for (std::size_t i = 0; i < test_size; ++i)
+        {
+            values[i] = static_cast(static_cast(i) * 2 - 7) / 3.f;
+        }
+        auto value_batch = test_batch_type::load_unaligned(values.data());
+
+        std::vector destination(test_size, -19);
+        std::vector expected_store(test_size, -19);
+        for (std::size_t i = 0; i < test_size; ++i)
+        {
+            if (cross_type_mask::get(i, test_size))
+            {
+                expected_store[i] = static_cast(values[i]);
+            }
+        }
+
+        value_batch.store(destination.data(), mask, xsimd::aligned_mode());
+        INFO("cross-type masked store aligned");
+        CHECK_VECTOR_EQ(destination, expected_store);
+
+        std::fill(destination.begin(), destination.end(), -19);
+        value_batch.store(destination.data(), mask, xsimd::unaligned_mode());
+        INFO("cross-type masked store unaligned");
+        CHECK_VECTOR_EQ(destination, expected_store);
+    }
+
 private:
 #ifdef XSIMD_WITH_SSE2
     struct test_load_as_return_type
@@ -193,6 +323,122 @@
         b = xsimd::load_as(v.data(), xsimd::aligned_mode());
         INFO(name, " aligned (load_as)");
         CHECK_BATCH_EQ(b, expected);
+
+        run_mask_tests(v, name, b, expected, std::is_same {});
+    }
+
+    template
+    void run_mask_tests(const V& v, const std::string& name, batch_type& b, const array_type& expected, std::true_type)
+    {
+        run_load_mask_pattern(v, name, b, expected, " masked none");
+        run_load_mask_pattern(v, name, b, expected, " masked first element");
+        run_load_mask_pattern(v, name, b, expected, " masked first half");
+        run_load_mask_pattern(v, name, b, expected, " masked last half");
+        run_load_mask_pattern(v, name, b, expected, " masked first N");
+        run_load_mask_pattern(v, name, b, expected, " masked last N");
+        run_load_mask_pattern(v, name, b, expected, " masked even elements");
+        run_load_mask_pattern(v, name, b, expected, " masked odd elements");
+        run_load_mask_pattern(v, name, b, expected, " masked pseudo random");
+        run_load_mask_pattern(v, name, b, expected, " masked all elements");
+    }
+
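To make the pattern dispatch above concrete, a single masked-load run reduces to roughly the following standalone sketch. The buffer contents and the `make_batch_bool_constant` parameter order are assumptions; the zeroing of inactive lanes matches the expectation arrays built by these tests.

#include <cstddef>
#include <vector>
#include "xsimd/xsimd.hpp"

// Hypothetical even-lane generator, mirroring mask_even above.
struct even_lanes
{
    static constexpr bool get(std::size_t index, std::size_t) noexcept { return index % 2 == 0; }
};

void masked_load_sketch()
{
    using batch_type = xsimd::batch<float>;

    // default_allocator aligns the buffer, so aligned_mode is valid here.
    std::vector<float, xsimd::default_allocator<float>> data(batch_type::size, 1.5f);

    // Assumed parameter order: value type, generator, architecture.
    constexpr auto mask = xsimd::make_batch_bool_constant<float, even_lanes, xsimd::default_arch>();

    // Active lanes are read from memory, inactive lanes come back as zero.
    auto even_only = xsimd::load(data.data(), mask, xsimd::aligned_mode());
    auto even_only_u = xsimd::load(data.data(), mask, xsimd::unaligned_mode());
    (void)even_only;
    (void)even_only_u;
}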
+    template
+    void run_mask_tests(const V&, const std::string&, batch_type&, const array_type&, std::false_type)
+    {
+    }
+
+    template
+    void run_load_mask_pattern(const V& v, const std::string& name, batch_type& b, const array_type& expected, const std::string& label)
+    {
+        constexpr auto mask = xsimd::make_batch_bool_constant();
+        array_type expected_masked { 0 };
+
+        for (std::size_t i = 0; i < size; ++i)
+        {
+            const bool active = Generator::get(i, size);
+            expected_masked[i] = active ? expected[i] : value_type();
+        }
+
+        b = xsimd::load(v.data(), mask, xsimd::aligned_mode());
+        INFO(name, label + " aligned");
+        CHECK_BATCH_EQ(b, expected_masked);
+        b = xsimd::load(v.data(), mask, xsimd::unaligned_mode());
+        INFO(name, label + " unaligned");
+        CHECK_BATCH_EQ(b, expected_masked);
+    }
+
+    template
+    void run_store_mask_pattern(const V& v, const std::string& name, batch_type& b, V& res, V& expected_masked, const std::string& label)
+    {
+        auto mask = xsimd::make_batch_bool_constant();
+        for (std::size_t i = 0; i < size; ++i)
+        {
+            expected_masked[i] = Generator::get(i, size) ? v[i] : value_type();
+        }
+        std::fill(res.begin(), res.end(), value_type());
+        b.store(res.data(), mask, xsimd::aligned_mode());
+        INFO(name, label + " aligned");
+        CHECK_VECTOR_EQ(res, expected_masked);
+        std::fill(res.begin(), res.end(), value_type());
+        b.store(res.data(), mask, xsimd::unaligned_mode());
+        INFO(name, label + " unaligned");
+        CHECK_VECTOR_EQ(res, expected_masked);
+    }
+
+    template
+    void run_store_mask_tests(const V& v, const std::string& name, batch_type& b, V& res, V& expected_masked, std::true_type)
+    {
+        run_store_mask_pattern(v, name, b, res, expected_masked, " masked first element");
+        run_store_mask_pattern(v, name, b, res, expected_masked, " masked first half");
+        run_store_mask_pattern(v, name, b, res, expected_masked, " masked last half");
+        run_store_mask_pattern(v, name, b, res, expected_masked, " masked first N");
+        run_store_mask_pattern(v, name, b, res, expected_masked, " masked last N");
+        run_store_mask_pattern(v, name, b, res, expected_masked, " masked even elements");
+        run_store_mask_pattern(v, name, b, res, expected_masked, " masked odd elements");
+        run_store_mask_pattern(v, name, b, res, expected_masked, " masked pseudo random");
+        run_store_mask_pattern(v, name, b, res, expected_masked, " masked all elements");
+    }
+
+    template
+    void run_store_mask_tests(const V&, const std::string&, batch_type&, V&, V&, std::false_type)
+    {
+    }
+
+    template
+    void run_store_mask_section(const V& v,
+                                const std::string& name,
+                                batch_type& b,
+                                V& res,
+                                V& expected_masked,
+                                std::true_type)
+    {
+        static constexpr auto sentinel = static_cast(37);
+        V sentinel_expected(size, sentinel);
+
+        auto zero_mask = xsimd::make_batch_bool_constant();
+        std::fill(res.begin(), res.end(), sentinel);
+        b.store(res.data(), zero_mask, xsimd::aligned_mode());
+        INFO(name, " masked none aligned store");
+        CHECK_VECTOR_EQ(res, sentinel_expected);
+
+        V scratch(res.size() + size);
+        std::fill(scratch.begin(), scratch.end(), sentinel);
+        auto* scratch_ptr = scratch.data() + 1;
+        b.store(scratch_ptr, zero_mask, xsimd::unaligned_mode());
+        INFO(name, " masked none unaligned store");
+
+        V scratch_slice(res.size());
+        std::copy(scratch_ptr, scratch_ptr + scratch_slice.size(), scratch_slice.begin());
+        CHECK_VECTOR_EQ(scratch_slice, sentinel_expected);
+        CHECK(std::all_of(scratch.begin(), scratch.end(), [](const value_type v)
+                          { return v == sentinel; }));
+
+        run_store_mask_tests(v, name, b, res, expected_masked, std::true_type {});
+    }
+
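For contrast with the masked loads, the store-side helpers above verify the key contract that a masked store must not touch unselected memory. A standalone sketch of that contract follows; the sentinel value, types, and the `make_batch_bool_constant` parameter order are arbitrary assumptions.

#include <cstddef>
#include <vector>
#include "xsimd/xsimd.hpp"

// Hypothetical generator selecting the first half of the batch.
struct lower_half_lanes
{
    static constexpr bool get(std::size_t index, std::size_t size) noexcept { return index < size / 2; }
};

void masked_store_sketch()
{
    using batch_type = xsimd::batch<float>;
    constexpr std::size_t n = batch_type::size;

    batch_type values(42.0f);
    std::vector<float, xsimd::default_allocator<float>> out(n, -1.0f); // -1.0f acts as a sentinel

    // Assumed parameter order: value type, generator, architecture.
    constexpr auto mask = xsimd::make_batch_bool_constant<float, lower_half_lanes, xsimd::default_arch>();

    // Lanes selected by the mask receive 42.0f; the remaining sentinels must survive.
    values.store(out.data(), mask, xsimd::aligned_mode());

    // out == { 42, ..., 42, -1, ..., -1 } for any batch with more than one lane.
}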
+    template
+    void run_store_mask_section(const V&, const std::string&, batch_type&, V&, V&, std::false_type)
+    {
+    }
 
     struct test_load_char
@@ -227,6 +473,10 @@
         xsimd::store_as(res.data(), b, xsimd::aligned_mode());
         INFO(name, " aligned (store_as)");
         CHECK_VECTOR_EQ(res, v);
+
+        V expected_masked(size);
+
+        run_store_mask_section(v, name, b, res, expected_masked, std::is_same {});
     }
 
     template
@@ -290,6 +540,9 @@ struct load_store_test
     }
 };
 
+template
+constexpr size_t load_store_test::size;
+
 TEST_CASE_TEMPLATE("[load store]", B, BATCH_TYPES)
 {
     load_store_test Test;
@@ -300,5 +553,7 @@ TEST_CASE_TEMPLATE("[load store]", B, BATCH_TYPES)
     SUBCASE("gather") { Test.test_gather(); }
 
     SUBCASE("scatter") { Test.test_scatter(); }
+
+    SUBCASE("masked") { Test.test_masked(); }
 }
 #endif
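Finally, the cross-type coverage in test_masked corresponds to usage along these lines: loading selected int32_t elements into a float batch and writing selected lanes back as int32_t, with the same mask applied on both sides. This is a sketch only; the lane pattern, sentinel value, and the `make_batch_bool_constant` parameter order are assumptions.

#include <cstddef>
#include <cstdint>
#include <vector>
#include "xsimd/xsimd.hpp"

// Hypothetical pattern: every other lane.
struct odd_lanes
{
    static constexpr bool get(std::size_t index, std::size_t) noexcept { return index % 2 == 1; }
};

void cross_type_masked_sketch()
{
    using float_batch = xsimd::batch<float>;
    constexpr std::size_t n = float_batch::size;

    std::vector<std::int32_t, xsimd::default_allocator<std::int32_t>> ints(n);
    for (std::size_t i = 0; i < n; ++i)
        ints[i] = static_cast<std::int32_t>(i) * 3 - 5;

    // Assumed parameter order: value type, generator, architecture.
    constexpr auto mask = xsimd::make_batch_bool_constant<float, odd_lanes, xsimd::default_arch>();

    // Masked, converting load: active lanes are converted int32_t -> float, inactive lanes are zero.
    auto f = float_batch::load(ints.data(), mask, xsimd::aligned_mode());

    // Masked, converting store: only active lanes are written back, converted float -> int32_t;
    // the -19 sentinels in the other slots are left untouched.
    std::vector<std::int32_t, xsimd::default_allocator<std::int32_t>> out(n, -19);
    f.store(out.data(), mask, xsimd::unaligned_mode());
}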