Skip to content

Commit f814bcd

Browse files
committed
WIP
1 parent 2c3875e commit f814bcd

File tree

2 files changed

+49
-34
lines changed

2 files changed

+49
-34
lines changed

include/xsimd/arch/xsimd_avx512dq.hpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,39 @@ namespace xsimd
2121
{
2222
using namespace types;
2323

24+
// load_masked
25+
template <class A, bool... Values, class Mode>
26+
XSIMD_INLINE batch<int32_t, A> load_masked(int32_t const* mem,
27+
batch_bool_constant<int32_t, A, Values...> mask,
28+
convert<int32_t>,
29+
Mode,
30+
requires_arch<avx512dq>) noexcept
31+
{
32+
XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 8)
33+
{
34+
constexpr auto mhi = mask.template upper_half<avx2>();
35+
const auto hi = load_masked<avx2>(mem + 8, mhi, convert<int32_t> {}, Mode {}, avx2 {});
36+
return _mm512_inserti32x8(_mm512_setzero_si512(), hi, 1);
37+
}
38+
return load_masked<A>(mem, mask, convert<int32_t>, Mode, avx512f);
39+
}
40+
41+
template <class A, bool... Values, class Mode>
42+
XSIMD_INLINE batch<float, A> load_masked(float const* mem,
43+
batch_bool_constant<float, A, Values...> mask,
44+
convert<float>,
45+
Mode,
46+
requires_arch<avx512dq>) noexcept
47+
{
48+
XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 8)
49+
{
50+
constexpr auto mhi = mask.template upper_half<avx2>();
51+
const auto hi = load_masked<avx2>(mem + 8, mhi, convert<float> {}, Mode {}, avx2 {});
52+
return _mm512_insertf32x8(_mm512_setzero_ps(), hi, 1);
53+
}
54+
return load_masked<A>(mem, mask, convert<float>, Mode, avx512f);
55+
}
56+
2457
// bitwise_and
2558
template <class A>
2659
XSIMD_INLINE batch<float, A> bitwise_and(batch<float, A> const& self, batch<float, A> const& other, requires_arch<avx512dq>) noexcept

include/xsimd/arch/xsimd_avx512f.hpp

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ namespace xsimd
246246
}
247247
}
248248

249-
// load_masked (integers)
249+
// load_masked
250250
template <class A, bool... Values, class Mode>
251251
XSIMD_INLINE batch<int32_t, A> load_masked(int32_t const* mem,
252252
batch_bool_constant<int32_t, A, Values...> mask,
@@ -268,22 +268,15 @@ namespace xsimd
268268
const auto lo = load_masked<avx2>(mem, mlo, convert<int32_t> {}, Mode {}, avx2 {});
269269
return _mm512_zextsi256_si512(lo);
270270
}
271-
else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 8)
272-
{
273-
constexpr auto mhi = mask.template upper_half<avx2>();
274-
const auto hi = load_masked<avx2>(mem + 8, mhi, convert<int32_t> {}, Mode {}, avx2 {});
275-
return _mm512_inserti32x8(_mm512_setzero_si512(), hi, 1);
276-
}
277271
else
278272
{
279-
constexpr auto k = static_cast<__mmask16>(batch_bool_constant<int32_t, A, Values...>::mask());
280273
XSIMD_IF_CONSTEXPR(std::is_same<Mode, aligned_mode>::value)
281274
{
282-
return _mm512_maskz_load_epi32(k, mem);
275+
return _mm512_maskz_load_epi32(mask.mask(), mem);
283276
}
284277
else
285278
{
286-
return _mm512_maskz_loadu_epi32(k, mem);
279+
return _mm512_maskz_loadu_epi32(mask.mask(), mem);
287280
}
288281
}
289282
}
@@ -330,14 +323,13 @@ namespace xsimd
330323
}
331324
else
332325
{
333-
constexpr auto k = static_cast<__mmask8>(batch_bool_constant<int64_t, A, Values...>::mask());
334326
XSIMD_IF_CONSTEXPR(std::is_same<Mode, aligned_mode>::value)
335327
{
336-
return _mm512_maskz_load_epi64(k, mem);
328+
return _mm512_maskz_load_epi64(mask.mask(), mem);
337329
}
338330
else
339331
{
340-
return _mm512_maskz_loadu_epi64(k, mem);
332+
return _mm512_maskz_loadu_epi64(mask.mask(), mem);
341333
}
342334
}
343335
}
@@ -1539,12 +1531,6 @@ namespace xsimd
15391531
const auto lo = load_masked<avx2>(mem, mlo, convert<float> {}, Mode {}, avx2 {});
15401532
return _mm512_zextps256_ps512(lo);
15411533
}
1542-
else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 8)
1543-
{
1544-
constexpr auto mhi = mask.template upper_half<avx2>();
1545-
const auto hi = load_masked<avx2>(mem + 8, mhi, convert<float> {}, Mode {}, avx2 {});
1546-
return _mm512_insertf32x8(_mm512_setzero_ps(), hi, 1);
1547-
}
15481534
else
15491535
{
15501536
// Last resort: 512-bit masked load
@@ -1590,11 +1576,11 @@ namespace xsimd
15901576
{
15911577
XSIMD_IF_CONSTEXPR(std::is_same<Mode, aligned_mode>::value)
15921578
{
1593-
return _mm512_maskz_load_pd(static_cast<__mmask8>(batch_bool_constant<double, A, Values...>::mask()), mem);
1579+
return _mm512_maskz_load_pd(mask.mask(), mem);
15941580
}
15951581
else
15961582
{
1597-
return _mm512_maskz_loadu_pd(static_cast<__mmask8>(batch_bool_constant<double, A, Values...>::mask()), mem);
1583+
return _mm512_maskz_loadu_pd(mask.mask(), mem);
15981584
}
15991585
}
16001586
}
@@ -1629,14 +1615,13 @@ namespace xsimd
16291615
}
16301616
else
16311617
{
1632-
auto k = static_cast<__mmask16>(batch_bool_constant<float, A, Values...>::mask());
16331618
XSIMD_IF_CONSTEXPR(std::is_same<Mode, aligned_mode>::value)
16341619
{
1635-
_mm512_mask_store_ps(mem, k, src);
1620+
_mm512_mask_store_ps(mem, mask.mask(), src);
16361621
}
16371622
else
16381623
{
1639-
_mm512_mask_storeu_ps(mem, k, src);
1624+
_mm512_mask_storeu_ps(mem, mask.mask(), src);
16401625
}
16411626
}
16421627
}
@@ -1670,19 +1655,18 @@ namespace xsimd
16701655
}
16711656
else
16721657
{
1673-
constexpr auto k = static_cast<__mmask8>(batch_bool_constant<double, A, Values...>::mask());
16741658
XSIMD_IF_CONSTEXPR(std::is_same<Mode, aligned_mode>::value)
16751659
{
1676-
_mm512_mask_store_pd(mem, k, src);
1660+
_mm512_mask_store_pd(mem, mask.mask(), src);
16771661
}
16781662
else
16791663
{
1680-
_mm512_mask_storeu_pd(mem, k, src);
1664+
_mm512_mask_storeu_pd(mem, mask.mask(), src);
16811665
}
16821666
}
16831667
}
16841668

1685-
// store_masked (integers)
1669+
// store_masked
16861670
template <class A, bool... Values, class Mode>
16871671
XSIMD_INLINE void store_masked(int32_t* mem,
16881672
batch<int32_t, A> const& src,
@@ -1712,14 +1696,13 @@ namespace xsimd
17121696
}
17131697
else
17141698
{
1715-
auto k = static_cast<__mmask16>(batch_bool_constant<int32_t, A, Values...>::mask());
17161699
XSIMD_IF_CONSTEXPR(std::is_same<Mode, aligned_mode>::value)
17171700
{
1718-
_mm512_mask_store_epi32(mem, k, src);
1701+
_mm512_mask_store_epi32(mem, mask.mask(), src);
17191702
}
17201703
else
17211704
{
1722-
_mm512_mask_storeu_epi32(mem, k, src);
1705+
_mm512_mask_storeu_epi32(mem, mask.mask(), src);
17231706
}
17241707
}
17251708
}
@@ -1766,14 +1749,13 @@ namespace xsimd
17661749
}
17671750
else
17681751
{
1769-
constexpr auto k = static_cast<__mmask8>(batch_bool_constant<int64_t, A, Values...>::mask());
17701752
XSIMD_IF_CONSTEXPR(std::is_same<Mode, aligned_mode>::value)
17711753
{
1772-
_mm512_mask_store_epi64(mem, k, src);
1754+
_mm512_mask_store_epi64(mem, mask.mask(), src);
17731755
}
17741756
else
17751757
{
1776-
_mm512_mask_storeu_epi64(mem, k, src);
1758+
_mm512_mask_storeu_epi64(mem, mask.mask(), src);
17771759
}
17781760
}
17791761
}

0 commit comments

Comments
 (0)