Skip to content

Commit 06e56c3

Browse files
committed
1. Adds a new masked API with compile-time masks (store_masked and load_masked)
2. General use-case optimization
3. New tests
4. x86 kernels
1 parent ac8a93f commit 06e56c3

File tree

13 files changed

+2220
-20
lines changed

13 files changed

+2220
-20
lines changed

docs/source/api/data_transfer.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Data transfer
1010
From memory:
1111

1212
+---------------------------------------+----------------------------------------------------+
13-
| :cpp:func:`load` | load values from memory |
13+
| :cpp:func:`load` | load values from memory (optionally masked) |
1414
+---------------------------------------+----------------------------------------------------+
1515
| :cpp:func:`load_aligned` | load values from aligned memory |
1616
+---------------------------------------+----------------------------------------------------+
@@ -30,7 +30,7 @@ From a scalar:
3030
To memory:
3131

3232
+---------------------------------------+----------------------------------------------------+
33-
| :cpp:func:`store` | store values to memory |
33+
| :cpp:func:`store` | store values to memory (optionally masked) |
3434
+---------------------------------------+----------------------------------------------------+
3535
| :cpp:func:`store_aligned` | store values to aligned memory |
3636
+---------------------------------------+----------------------------------------------------+

include/xsimd/arch/common/xsimd_common_memory.hpp

Lines changed: 159 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
#ifndef XSIMD_COMMON_MEMORY_HPP
1313
#define XSIMD_COMMON_MEMORY_HPP
1414

15+
#include "../../types/xsimd_batch_constant.hpp"
16+
#include "./xsimd_common_details.hpp"
1517
#include <algorithm>
18+
#include <array>
1619
#include <complex>
1720
#include <stdexcept>
1821

19-
#include "../../types/xsimd_batch_constant.hpp"
20-
#include "./xsimd_common_details.hpp"
21-
2222
namespace xsimd
2323
{
2424
template <typename T, class A, T... Values>
@@ -348,6 +348,162 @@ namespace xsimd
348348
return detail::load_unaligned<A>(mem, cvt, common {}, detail::conversion_type<A, T_in, T_out> {});
349349
}
350350

351+
template <class A, class T>
352+
XSIMD_INLINE batch<T, A> load(T const* mem, aligned_mode, requires_arch<A>) noexcept
353+
{
354+
return load_aligned<A>(mem, convert<T> {}, A {});
355+
}
356+
357+
template <class A, class T>
358+
XSIMD_INLINE batch<T, A> load(T const* mem, unaligned_mode, requires_arch<A>) noexcept
359+
{
360+
return load_unaligned<A>(mem, convert<T> {}, A {});
361+
}
362+
363+
template <class A, class T_in, class T_out, bool... Values, class alignment>
364+
XSIMD_INLINE batch<T_out, A>
365+
load_masked(T_in const* mem, batch_bool_constant<T_out, A, Values...>, convert<T_out>, alignment, requires_arch<common>) noexcept
366+
{
367+
constexpr std::size_t size = batch<T_out, A>::size;
368+
alignas(A::alignment()) std::array<T_out, size> buffer {};
369+
constexpr std::array<bool, size> mask { Values... };
370+
371+
for (std::size_t i = 0; i < size; ++i)
372+
buffer[i] = mask[i] ? static_cast<T_out>(mem[i]) : T_out(0);
373+
374+
return batch<T_out, A>::load(buffer.data(), aligned_mode {});
375+
}
376+
377+
template <class A, class T_in, class T_out, bool... Values, class alignment>
378+
XSIMD_INLINE void
379+
store_masked(T_out* mem, batch<T_in, A> const& src, batch_bool_constant<T_in, A, Values...>, alignment, requires_arch<common>) noexcept
380+
{
381+
constexpr std::size_t size = batch<T_in, A>::size;
382+
constexpr std::array<bool, size> mask { Values... };
383+
384+
for (std::size_t i = 0; i < size; ++i)
385+
if (mask[i])
386+
{
387+
mem[i] = static_cast<T_out>(src.get(i));
388+
}
389+
}
390+
391+
template <class A, bool... Values, class Mode>
392+
XSIMD_INLINE batch<int32_t, A> load_masked(int32_t const* mem,
393+
batch_bool_constant<int32_t, A, Values...>,
394+
convert<int32_t>,
395+
Mode,
396+
requires_arch<A>) noexcept
397+
{
398+
const auto f = load_masked<A>(reinterpret_cast<const float*>(mem),
399+
batch_bool_constant<float, A, Values...> {},
400+
convert<float> {},
401+
Mode {},
402+
A {});
403+
return bitwise_cast<int32_t>(f);
404+
}
405+
406+
template <class A, bool... Values, class Mode>
407+
XSIMD_INLINE batch<uint32_t, A> load_masked(uint32_t const* mem,
408+
batch_bool_constant<uint32_t, A, Values...>,
409+
convert<uint32_t>,
410+
Mode,
411+
requires_arch<A>) noexcept
412+
{
413+
const auto f = load_masked<A>(reinterpret_cast<const float*>(mem),
414+
batch_bool_constant<float, A, Values...> {},
415+
convert<float> {},
416+
Mode {},
417+
A {});
418+
return bitwise_cast<uint32_t>(f);
419+
}
420+
421+
template <class A, bool... Values, class Mode>
422+
XSIMD_INLINE batch<int64_t, A> load_masked(int64_t const* mem,
423+
batch_bool_constant<int64_t, A, Values...>,
424+
convert<int64_t>,
425+
Mode,
426+
requires_arch<A>) noexcept
427+
{
428+
const auto d = load_masked<A>(reinterpret_cast<const double*>(mem),
429+
batch_bool_constant<double, A, Values...> {},
430+
convert<double> {},
431+
Mode {},
432+
A {});
433+
return bitwise_cast<int64_t>(d);
434+
}
435+
436+
template <class A, bool... Values, class Mode>
437+
XSIMD_INLINE batch<uint64_t, A> load_masked(uint64_t const* mem,
438+
batch_bool_constant<uint64_t, A, Values...>,
439+
convert<uint64_t>,
440+
Mode,
441+
requires_arch<A>) noexcept
442+
{
443+
const auto d = load_masked<A>(reinterpret_cast<const double*>(mem),
444+
batch_bool_constant<double, A, Values...> {},
445+
convert<double> {},
446+
Mode {},
447+
A {});
448+
return bitwise_cast<uint64_t>(d);
449+
}
450+
451+
template <class A, bool... Values, class Mode>
452+
XSIMD_INLINE void store_masked(int32_t* mem,
453+
batch<int32_t, A> const& src,
454+
batch_bool_constant<int32_t, A, Values...>,
455+
Mode,
456+
requires_arch<A>) noexcept
457+
{
458+
store_masked<A>(reinterpret_cast<float*>(mem),
459+
bitwise_cast<float>(src),
460+
batch_bool_constant<float, A, Values...> {},
461+
Mode {},
462+
A {});
463+
}
464+
465+
template <class A, bool... Values, class Mode>
466+
XSIMD_INLINE void store_masked(uint32_t* mem,
467+
batch<uint32_t, A> const& src,
468+
batch_bool_constant<uint32_t, A, Values...>,
469+
Mode,
470+
requires_arch<A>) noexcept
471+
{
472+
store_masked<A>(reinterpret_cast<float*>(mem),
473+
bitwise_cast<float>(src),
474+
batch_bool_constant<float, A, Values...> {},
475+
Mode {},
476+
A {});
477+
}
478+
479+
template <class A, bool... Values, class Mode>
480+
XSIMD_INLINE void store_masked(int64_t* mem,
481+
batch<int64_t, A> const& src,
482+
batch_bool_constant<int64_t, A, Values...>,
483+
Mode,
484+
requires_arch<A>) noexcept
485+
{
486+
store_masked<A>(reinterpret_cast<double*>(mem),
487+
bitwise_cast<double>(src),
488+
batch_bool_constant<double, A, Values...> {},
489+
Mode {},
490+
A {});
491+
}
492+
493+
template <class A, bool... Values, class Mode>
494+
XSIMD_INLINE void store_masked(uint64_t* mem,
495+
batch<uint64_t, A> const& src,
496+
batch_bool_constant<uint64_t, A, Values...>,
497+
Mode,
498+
requires_arch<A>) noexcept
499+
{
500+
store_masked<A>(reinterpret_cast<double*>(mem),
501+
bitwise_cast<double>(src),
502+
batch_bool_constant<double, A, Values...> {},
503+
Mode {},
504+
A {});
505+
}
506+
351507
// rotate_right
352508
template <size_t N, class A, class T>
353509
XSIMD_INLINE batch<T, A> rotate_right(batch<T, A> const& self, requires_arch<common>) noexcept

0 commit comments

Comments
 (0)