@@ -13,6 +13,7 @@
 #define XSIMD_NEON_HPP
 
 #include <algorithm>
+#include <array>
 #include <complex>
 #include <tuple>
 #include <type_traits>
@@ -2914,6 +2915,114 @@ namespace xsimd
         {
             return vreinterpretq_s64_u64(swizzle(vreinterpretq_u64_s64(self), mask, A {}));
         }
+
+        namespace detail
+        {
+            template <uint32_t Va, uint32_t Vb>
+            XSIMD_INLINE uint8x8_t make_mask() // vtbl1_u8 indices picking 32-bit element Va % 2, then Vb % 2, from an 8-byte source
+            {
+                uint8x8_t res = {
+                    static_cast<uint8_t>((Va % 2) * 4 + 0),
+                    static_cast<uint8_t>((Va % 2) * 4 + 1),
+                    static_cast<uint8_t>((Va % 2) * 4 + 2),
+                    static_cast<uint8_t>((Va % 2) * 4 + 3),
+                    static_cast<uint8_t>((Vb % 2) * 4 + 0),
+                    static_cast<uint8_t>((Vb % 2) * 4 + 1),
+                    static_cast<uint8_t>((Vb % 2) * 4 + 2),
+                    static_cast<uint8_t>((Vb % 2) * 4 + 3),
+                };
+                return res;
+            }
+        }
+
+        template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3>
+        XSIMD_INLINE batch<uint32_t, A> swizzle(batch<uint32_t, A> const& self,
+                                                batch_constant<uint32_t, A, V0, V1, V2, V3> mask,
+                                                requires_arch<neon>) noexcept
+        {
+            constexpr bool is_identity = detail::is_identity(mask);
+            constexpr bool is_dup_lo = detail::is_dup_lo(mask);
+            constexpr bool is_dup_hi = detail::is_dup_hi(mask);
+
+            XSIMD_IF_CONSTEXPR(is_identity)
+            {
+                return self; // (0, 1, 2, 3): nothing to shuffle
+            }
+            XSIMD_IF_CONSTEXPR(is_dup_lo) // fast paths when the mask repeats a pattern of the two low lanes
+            {
+                XSIMD_IF_CONSTEXPR(V0 == 0 && V1 == 1) // duplicate the low 64-bit half as-is
+                {
+                    return vreinterpretq_u32_u64(vdupq_lane_u64(vget_low_u64(vreinterpretq_u64_u32(self)), 0));
+                }
+                XSIMD_IF_CONSTEXPR(V0 == 1 && V1 == 0) // swap the two low lanes, then duplicate them
+                {
+                    return vreinterpretq_u32_u64(vdupq_lane_u64(vreinterpret_u64_u32(vrev64_u32(vget_low_u32(self))), 0));
+                }
+                return vdupq_n_u32(vgetq_lane_u32(self, V0)); // broadcast a single lane
+            }
+            XSIMD_IF_CONSTEXPR(is_dup_hi) // fast paths when the mask repeats a pattern of the two high lanes
+            {
+                XSIMD_IF_CONSTEXPR(V0 == 2 && V1 == 3) // duplicate the high 64-bit half as-is
+                {
+                    return vreinterpretq_u32_u64(vdupq_lane_u64(vget_high_u64(vreinterpretq_u64_u32(self)), 0));
+                }
+                XSIMD_IF_CONSTEXPR(V0 == 3 && V1 == 2) // swap the two high lanes, then duplicate them
+                {
+                    return vreinterpretq_u32_u64(vdupq_lane_u64(vreinterpret_u64_u32(vrev64_u32(vget_high_u32(self))), 0));
+                }
+                return vdupq_n_u32(vgetq_lane_u32(self, V0)); // broadcast a single lane
+            }
+            XSIMD_IF_CONSTEXPR(V0 < 2 && V1 < 2 && V2 < 2 && V3 < 2) // every index reads the low half
+            {
+                uint8x8_t low = vreinterpret_u8_u64(vget_low_u64(vreinterpretq_u64_u32(self)));
+                uint8x8_t mask_lo = detail::make_mask<V0, V1>();
+                uint8x8_t mask_hi = detail::make_mask<V2, V3>();
+                uint8x8_t lo = vtbl1_u8(low, mask_lo);
+                uint8x8_t hi = vtbl1_u8(low, mask_hi);
+                return vreinterpretq_u32_u8(vcombine_u8(lo, hi));
+            }
+            XSIMD_IF_CONSTEXPR(V0 >= 2 && V1 >= 2 && V2 >= 2 && V3 >= 2) // every index reads the high half
+            {
+                uint8x8_t high = vreinterpret_u8_u64(vget_high_u64(vreinterpretq_u64_u32(self)));
+                uint8x8_t mask_lo = detail::make_mask<V0, V1>();
+                uint8x8_t mask_hi = detail::make_mask<V2, V3>();
+                uint8x8_t lo = vtbl1_u8(high, mask_lo);
+                uint8x8_t hi = vtbl1_u8(high, mask_hi);
+                return vreinterpretq_u32_u8(vcombine_u8(lo, hi));
+            }
+
+            uint8x8_t mask_lo = detail::make_mask<V0, V1>(); // general case: shuffle both halves, then blend per lane
+            uint8x8_t mask_hi = detail::make_mask<V2, V3>();
+
+            uint8x8_t low = vreinterpret_u8_u64(vget_low_u64(vreinterpretq_u64_u32(self)));
+            uint8x8_t lol = vtbl1_u8(low, mask_lo);
+            uint8x8_t loh = vtbl1_u8(low, mask_hi);
+            uint32x4_t true_br = vreinterpretq_u32_u8(vcombine_u8(lol, loh)); // result if every lane read the low half
+
+            uint8x8_t high = vreinterpret_u8_u64(vget_high_u64(vreinterpretq_u64_u32(self)));
+            uint8x8_t hil = vtbl1_u8(high, mask_lo);
+            uint8x8_t hih = vtbl1_u8(high, mask_hi);
+            uint32x4_t false_br = vreinterpretq_u32_u8(vcombine_u8(hil, hih)); // result if every lane read the high half
+
+            batch_bool_constant<uint32_t, A, (V0 < 2), (V1 < 2), (V2 < 2), (V3 < 2)> blend_mask; // true where the index targets the low half
+            return select(blend_mask, batch<uint32_t, A>(true_br), batch<uint32_t, A>(false_br), A {});
+        }
+
+        template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3>
+        XSIMD_INLINE batch<int32_t, A> swizzle(batch<int32_t, A> const& self,
+                                               batch_constant<uint32_t, A, V0, V1, V2, V3> mask,
+                                               requires_arch<neon>) noexcept
+        {
+            return vreinterpretq_s32_u32(swizzle(batch<uint32_t, A>(vreinterpretq_u32_s32(self)), mask, A {}));
+        }
+
+        template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3>
+        XSIMD_INLINE batch<float, A> swizzle(batch<float, A> const& self,
+                                             batch_constant<uint32_t, A, V0, V1, V2, V3> mask,
+                                             requires_arch<neon>) noexcept
+        {
+            return vreinterpretq_f32_u32(swizzle(batch<uint32_t, A>(vreinterpretq_u32_f32(self)), mask, A {}));
+        }
     }
 
 }
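
For reference, a minimal usage sketch of the overloads added above (not part of the patch). It assumes a 32-bit NEON build where the neon arch tag is selected by default, the public xsimd::swizzle entry point, and the batch_constant<T, A, Values...> spelling used in this diff; the lane values are arbitrary.

#include <xsimd/xsimd.hpp>
#include <cstdint>

int main()
{
    using A = xsimd::neon;
    xsimd::batch<uint32_t, A> v { 10u, 11u, 12u, 13u };
    // (3, 2, 1, 0) is neither the identity nor a duplication of one half, so
    // the vtbl-based general path of the new uint32_t overload is exercised.
    auto r = xsimd::swizzle(v, xsimd::batch_constant<uint32_t, A, 3, 2, 1, 0> {});
    return r.get(0) == 13u ? 0 : 1; // lanes come out as {13, 12, 11, 10}
}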