Skip to content

Commit 6ccdfec

Browse files
WIP
1 parent 3ce77f1 commit 6ccdfec

File tree

1 file changed

+38
-43
lines changed

1 file changed

+38
-43
lines changed

include/xsimd/arch/xsimd_neon.hpp

Lines changed: 38 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2916,6 +2916,24 @@ namespace xsimd
29162916
return vreinterpretq_s64_u64(swizzle(vreinterpretq_u64_s64(self), mask, A {}));
29172917
}
29182918

2919+
namespace detail
2920+
{
// Builds the eight byte indices for a vtbl1_u8 lookup that gathers two
// 32-bit lanes out of one 64-bit half of a 128-bit register.
//
// Va selects the lane copied to result bytes 0-3 and Vb the lane copied
// to result bytes 4-7. Each selector is reduced modulo 2, so selectors
// 0 and 2 both address the first 32-bit lane of the half, and 1 and 3 the
// second. This lets the same mask be applied to either the low or the
// high half of the source register.
// NOTE(review): selectors >= 4 are wrapped by the modulo rather than
// rejected -- confirm callers never rely on out-of-range behaviour.
2921+
template <uint32_t Va, uint32_t Vb>
2922+
uint8x8_t make_mask()
2923+
{
2924+
return {
// Result bytes 0-3: the four consecutive bytes of lane (Va % 2).
2925+
static_cast<uint8_t>((Va % 2) * 4 + 0),
2926+
static_cast<uint8_t>((Va % 2) * 4 + 1),
2927+
static_cast<uint8_t>((Va % 2) * 4 + 2),
2928+
static_cast<uint8_t>((Va % 2) * 4 + 3),
// Result bytes 4-7: the four consecutive bytes of lane (Vb % 2).
2929+
static_cast<uint8_t>((Vb % 2) * 4 + 0),
2930+
static_cast<uint8_t>((Vb % 2) * 4 + 1),
2931+
static_cast<uint8_t>((Vb % 2) * 4 + 2),
2932+
static_cast<uint8_t>((Vb % 2) * 4 + 3),
2933+
};
2934+
}
2935+
}
2936+
29192937
template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3>
29202938
XSIMD_INLINE batch<uint32_t, A> swizzle(batch<uint32_t, A> const& self,
29212939
batch_constant<uint32_t, A, V0, V1, V2, V3> mask,
@@ -2956,60 +2974,37 @@ namespace xsimd
29562974
XSIMD_IF_CONSTEXPR(V0 < 2 && V1 < 2 && V2 < 2 && V3 < 2)
29572975
{
29582976
uint8x8_t low = vreinterpret_u8_u64(vget_low_u64(vreinterpretq_u64_u32(self)));
2959-
uint8x8_t mask_lo = {
2960-
static_cast<uint8_t>(V0 * 4 + 0),
2961-
static_cast<uint8_t>(V0 * 4 + 1),
2962-
static_cast<uint8_t>(V0 * 4 + 2),
2963-
static_cast<uint8_t>(V0 * 4 + 3),
2964-
static_cast<uint8_t>(V1 * 4 + 0),
2965-
static_cast<uint8_t>(V1 * 4 + 1),
2966-
static_cast<uint8_t>(V1 * 4 + 2),
2967-
static_cast<uint8_t>(V1 * 4 + 3),
2968-
};
2969-
uint8x8_t mask_hi = {
2970-
static_cast<uint8_t>(V2 * 4 + 0),
2971-
static_cast<uint8_t>(V2 * 4 + 1),
2972-
static_cast<uint8_t>(V2 * 4 + 2),
2973-
static_cast<uint8_t>(V2 * 4 + 3),
2974-
static_cast<uint8_t>(V3 * 4 + 0),
2975-
static_cast<uint8_t>(V3 * 4 + 1),
2976-
static_cast<uint8_t>(V3 * 4 + 2),
2977-
static_cast<uint8_t>(V3 * 4 + 3),
2978-
};
2977+
uint8x8_t mask_lo = make_mask<V0, V1>();
2978+
uint8x8_t mask_hi = make_mask<V2, V3>();
29792979
uint8x8_t lo = vtbl1_u8(low, mask_lo);
29802980
uint8x8_t hi = vtbl1_u8(low, mask_hi);
29812981
return vreinterpretq_u32_u8(vcombine_u8(lo, hi));
29822982
}
29832983
XSIMD_IF_CONSTEXPR(V0 >= 2 && V1 >= 2 && V2 >= 2 && V3 >= 2)
29842984
{
29852985
uint8x8_t high = vreinterpret_u8_u64(vget_high_u64(vreinterpretq_u64_u32(self)));
2986-
uint8x8_t mask_lo = {
2987-
static_cast<uint8_t>((V0 - 2) * 4 + 0),
2988-
static_cast<uint8_t>((V0 - 2) * 4 + 1),
2989-
static_cast<uint8_t>((V0 - 2) * 4 + 2),
2990-
static_cast<uint8_t>((V0 - 2) * 4 + 3),
2991-
static_cast<uint8_t>((V1 - 2) * 4 + 0),
2992-
static_cast<uint8_t>((V1 - 2) * 4 + 1),
2993-
static_cast<uint8_t>((V1 - 2) * 4 + 2),
2994-
static_cast<uint8_t>((V1 - 2) * 4 + 3),
2995-
};
2996-
uint8x8_t mask_hi = {
2997-
static_cast<uint8_t>((V2 - 2) * 4 + 0),
2998-
static_cast<uint8_t>((V2 - 2) * 4 + 1),
2999-
static_cast<uint8_t>((V2 - 2) * 4 + 2),
3000-
static_cast<uint8_t>((V2 - 2) * 4 + 3),
3001-
static_cast<uint8_t>((V3 - 2) * 4 + 0),
3002-
static_cast<uint8_t>((V3 - 2) * 4 + 1),
3003-
static_cast<uint8_t>((V3 - 2) * 4 + 2),
3004-
static_cast<uint8_t>((V3 - 2) * 4 + 3),
3005-
};
2986+
uint8x8_t mask_lo = make_mask<V0, V1>();
2987+
uint8x8_t mask_hi = make_mask<V2, V3>();
30062988
uint8x8_t lo = vtbl1_u8(high, mask_lo);
30072989
uint8x8_t hi = vtbl1_u8(high, mask_hi);
30082990
return vreinterpretq_u32_u8(vcombine_u8(lo, hi));
30092991
}
3010-
std::array<uint32_t, 4> data;
3011-
self.store_aligned(data.data());
3012-
return set(batch<uint32_t, A>(), A(), data[V0], data[V1], data[V2], data[V3]);
2992+
2993+
uint8x8_t mask_lo = make_mask<V0, V1>();
2994+
uint8x8_t mask_hi = make_mask<V2, V3>();
2995+
2996+
uint8x8_t low = vreinterpret_u8_u64(vget_low_u64(vreinterpretq_u64_u32(self)));
2997+
uint8x8_t lol = vtbl1_u8(low, mask_lo);
2998+
uint8x8_t loh = vtbl1_u8(low, mask_hi);
2999+
uint32x4_t true_br = vreinterpretq_u32_u8(vcombine_u8(lol, loh));
3000+
3001+
uint8x8_t high = vreinterpret_u8_u64(vget_high_u64(vreinterpretq_u64_u32(self)));
3002+
uint8x8_t hil = vtbl1_u8(high, mask_lo);
3003+
uint8x8_t hih = vtbl1_u8(high, mask_hi);
3004+
uint32x4_t false_br = vreinterpretq_u32_u8(vcombine_u8(hil, hih));
3005+
3006+
batch_bool_constant<uint32_t, A, (V0 < 2), (V1 < 2), (V2 < 2), (V3 < 2)> blend_mask;
3007+
return select(blend_mask, batch<uint32_t, A>(true_br), batch<uint32_t, A>(false_br), A {});
30133008
}
30143009

30153010
template <class A, uint32_t V0, uint32_t V1>

0 commit comments

Comments
 (0)