Skip to content

Commit fa3d5f9

Browse files
More efficient batch_bool::mask() for aarch64
As a complement to #1236
1 parent 79ffbae commit fa3d5f9

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

include/xsimd/arch/xsimd_neon64.hpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,50 @@ namespace xsimd
608608
return vmaxq_f64(lhs, rhs);
609609
}
610610

611+
/********
612+
* mask *
613+
********/
614+
615+
template <class A, class T, detail::enable_sized_t<T, 1> = 0>
616+
XSIMD_INLINE uint64_t mask(batch_bool<T, A> const& self, requires_arch<neon64>) noexcept
617+
{
618+
// From https://github.com/DLTcollab/sse2neon/blob/master/sse2neon.h
619+
// Extract most significant bit
620+
uint8x16_t msbs = vshrq_n_u8(self, 7);
621+
// Position it appropriately
622+
static constexpr int8_t shift_table[16] = { 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7 };
623+
int8x16_t shifts = vld1q_s8(shift_table);
624+
uint8x16_t positioned = vshlq_u8(msbs, shifts);
625+
// Horizontal reduction
626+
return vaddv_u8(vget_low_u8(positioned)) | (vaddv_u8(vget_high_u8(positioned)) << 8);
627+
}
628+
629+
template <class A, class T, detail::enable_sized_t<T, 2> = 0>
630+
XSIMD_INLINE uint64_t mask(batch_bool<T, A> const& self, requires_arch<neon64>) noexcept
631+
{
632+
// Extract most significant bit
633+
uint16x8_t msbs = vshrq_n_u16(self, 15);
634+
// Position it appropriately
635+
static constexpr int16_t shift_table[8] = { 0, 1, 2, 3, 4, 5, 6, 7 };
636+
int16x8_t shifts = vld1q_s16(shift_table);
637+
uint16x8_t positioned = vshlq_u16(msbs, shifts);
638+
// Horizontal reduction
639+
return vaddvq_u16(positioned);
640+
}
641+
642+
template <class A, class T, detail::enable_sized_t<T, 4> = 0>
643+
XSIMD_INLINE uint64_t mask(batch_bool<T, A> const& self, requires_arch<neon64>) noexcept
644+
{
645+
// Extract most significant bit
646+
uint32x4_t msbs = vshrq_n_u32(self, 31);
647+
// Position it appropriately
648+
static constexpr int32_t shift_table[4] = { 0, 1, 2, 3 };
649+
int32x4_t shifts = vld1q_s32(shift_table);
650+
uint32x4_t positioned = vshlq_u32(msbs, shifts);
651+
// Horizontal reduction
652+
return vaddvq_u32(positioned);
653+
}
654+
611655
/*******
612656
* abs *
613657
*******/

0 commit comments

Comments
 (0)