Commit c237a2b

improved avx/avx2 swizzles

1 parent d799829

2 files changed: +107 -1 lines changed

include/xsimd/arch/xsimd_avx.hpp

Lines changed: 39 additions & 0 deletions
@@ -1480,6 +1480,35 @@ namespace xsimd
         template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3, uint32_t V4, uint32_t V5, uint32_t V6, uint32_t V7>
         XSIMD_INLINE batch<float, A> swizzle(batch<float, A> const& self, batch_constant<uint32_t, A, V0, V1, V2, V3, V4, V5, V6, V7>, requires_arch<avx>) noexcept
         {
+            // 1) identity?
+            constexpr bool is_identity = (V0 == 0 && V1 == 1 && V2 == 2 && V3 == 3 && V4 == 4 && V5 == 5 && V6 == 6 && V7 == 7);
+            // 2) duplicate-low half?
+            constexpr bool is_dup_lo = ((V0 < 4 && V1 < 4 && V2 < 4 && V3 < 4) && V4 == V0 && V5 == V1 && V6 == V2 && V7 == V3);
+            // 3) duplicate-high half?
+            constexpr bool is_dup_hi = (V0 >= 4 && V0 <= 7 && V1 >= 4 && V1 <= 7 && V2 >= 4 && V2 <= 7 && V3 >= 4 && V3 <= 7 && V4 == V0 && V5 == V1 && V6 == V2 && V7 == V3);
+
+            XSIMD_IF_CONSTEXPR(is_identity) { return self; }
+            XSIMD_IF_CONSTEXPR(is_dup_lo)
+            {
+                __m128 lo = _mm256_castps256_ps128(self);
+                // if lo is not the identity, permute it before duplicating
+                XSIMD_IF_CONSTEXPR(V0 != 0 || V1 != 1 || V2 != 2 || V3 != 3)
+                {
+                    constexpr int imm = ((V3 & 3) << 6) | ((V2 & 3) << 4) | ((V1 & 3) << 2) | ((V0 & 3) << 0);
+                    lo = _mm_permute_ps(lo, imm);
+                }
+                return _mm256_set_m128(lo, lo);
+            }
+            XSIMD_IF_CONSTEXPR(is_dup_hi)
+            {
+                __m128 hi = _mm256_extractf128_ps(self, 1);
+                XSIMD_IF_CONSTEXPR(V0 != 4 || V1 != 5 || V2 != 6 || V3 != 7)
+                {
+                    constexpr int imm = ((V3 & 3) << 6) | ((V2 & 3) << 4) | ((V1 & 3) << 2) | ((V0 & 3) << 0);
+                    hi = _mm_permute_ps(hi, imm);
+                }
+                return _mm256_set_m128(hi, hi);
+            }
             // duplicate low and high part of input
             __m256 hi = _mm256_castps128_ps256(_mm256_extractf128_ps(self, 1));
             __m256 hi_hi = _mm256_insertf128_ps(self, _mm256_castps256_ps128(hi), 0);
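
The two duplicate paths above replace the generic extract/permute/blend sequence with a single in-lane shuffle plus a half duplication. As a rough standalone sketch in plain AVX intrinsics (hypothetical helper name, and the mask (2,3,0,1, 2,3,0,1) is an assumed example that takes the dup-lo path):

    #include <immintrin.h>

    // dup-lo path for the mask (2,3,0,1, 2,3,0,1): permute the low
    // 128-bit half once, then mirror it into both lanes.
    __m256 swizzle_23012301(__m256 self)
    {
        __m128 lo = _mm256_castps256_ps128(self);
        // imm = ((V3 & 3) << 6) | ((V2 & 3) << 4) | ((V1 & 3) << 2) | (V0 & 3)
        //     = (1 << 6) | (0 << 4) | (3 << 2) | 2 = 0x4E
        lo = _mm_permute_ps(lo, 0x4E);
        return _mm256_set_m128(lo, lo); // same 128-bit value in both halves
    }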
@@ -1505,6 +1534,16 @@ namespace xsimd
         template <class A, uint64_t V0, uint64_t V1, uint64_t V2, uint64_t V3>
         XSIMD_INLINE batch<double, A> swizzle(batch<double, A> const& self, batch_constant<uint64_t, A, V0, V1, V2, V3>, requires_arch<avx>) noexcept
         {
+            constexpr bool is_identity = V0 == 0 && V1 == 1 && V2 == 2 && V3 == 3;
+            constexpr bool can_use_pd = V0 < 2 && V1 < 2 && V2 >= 2 && V3 >= 2; // no lane crossing
+
+            XSIMD_IF_CONSTEXPR(is_identity) { return self; }
+            XSIMD_IF_CONSTEXPR(can_use_pd)
+            {
+                // build the 4-bit immediate: bit i is 1 to pick the upper element of pair i
+                constexpr int mask = ((V0 & 1) << 0) | ((V1 & 1) << 1) | ((V2 & 1) << 2) | ((V3 & 1) << 3);
+                return _mm256_permute_pd(self, mask);
+            }
             // duplicate low and high part of input
             __m256d hi = _mm256_castpd128_pd256(_mm256_extractf128_pd(self, 1));
             __m256d hi_hi = _mm256_insertf128_pd(self, _mm256_castpd256_pd128(hi), 0);
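
_mm256_permute_pd can only choose within each adjacent pair of doubles, which is why the can_use_pd test requires V0, V1 to address the low pair and V2, V3 the high one. A minimal sketch of the immediate it builds, assuming the in-lane swap mask (1,0,3,2) (hypothetical helper, not from the commit):

    #include <immintrin.h>

    // (1,0,3,2) swaps within each 128-bit lane, so can_use_pd holds.
    // Immediate: ((1 & 1) << 0) | ((0 & 1) << 1) | ((3 & 1) << 2) | ((2 & 1) << 3)
    //          = 0b0101 = 0x5
    __m256d swizzle_1032(__m256d self)
    {
        return _mm256_permute_pd(self, 0x5);
    }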

include/xsimd/arch/xsimd_avx2.hpp

Lines changed: 68 additions & 1 deletion
@@ -942,12 +942,79 @@ namespace xsimd
         template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3, uint32_t V4, uint32_t V5, uint32_t V6, uint32_t V7>
         XSIMD_INLINE batch<float, A> swizzle(batch<float, A> const& self, batch_constant<uint32_t, A, V0, V1, V2, V3, V4, V5, V6, V7> mask, requires_arch<avx2>) noexcept
         {
-            return _mm256_permutevar8x32_ps(self, mask.as_batch());
+            // 1) identity?
+            constexpr bool is_identity = (V0 == 0 && V1 == 1 && V2 == 2 && V3 == 3 && V4 == 4 && V5 == 5 && V6 == 6 && V7 == 7);
+            // 2) all-different mask → full 8-lane permute
+            constexpr uint32_t bitmask = (1u << (V0 & 7)) | (1u << (V1 & 7)) | (1u << (V2 & 7)) | (1u << (V3 & 7)) | (1u << (V4 & 7)) | (1u << (V5 & 7)) | (1u << (V6 & 7)) | (1u << (V7 & 7));
+            constexpr bool is_all_different = (bitmask == 0xFFu);
+            // 3) duplicate-low half?
+            constexpr bool is_dup_lo = (V0 < 4 && V1 < 4 && V2 < 4 && V3 < 4) && V4 == V0 && V5 == V1 && V6 == V2 && V7 == V3;
+            // 4) duplicate-high half?
+            constexpr bool is_dup_hi = (V0 >= 4 && V0 <= 7 && V1 >= 4 && V1 <= 7 && V2 >= 4 && V2 <= 7 && V3 >= 4 && V3 <= 7) && V4 == V0 && V5 == V1 && V6 == V2 && V7 == V3;
+
+            XSIMD_IF_CONSTEXPR(is_identity) { return self; }
+            XSIMD_IF_CONSTEXPR(is_dup_lo)
+            {
+                __m128 lo = _mm256_castps256_ps128(self);
+                // if lo is not the identity, permute it before duplicating
+                XSIMD_IF_CONSTEXPR(V0 != 0 || V1 != 1 || V2 != 2 || V3 != 3)
+                {
+                    constexpr int imm = ((V3 & 3) << 6) | ((V2 & 3) << 4) | ((V1 & 3) << 2) | ((V0 & 3) << 0);
+                    lo = _mm_permute_ps(lo, imm);
+                }
+                return _mm256_set_m128(lo, lo);
+            }
+            XSIMD_IF_CONSTEXPR(is_dup_hi)
+            {
+                __m128 hi = _mm256_extractf128_ps(self, 1);
+                XSIMD_IF_CONSTEXPR(V0 != 4 || V1 != 5 || V2 != 6 || V3 != 7)
+                {
+                    constexpr int imm = ((V3 & 3) << 6) | ((V2 & 3) << 4) | ((V1 & 3) << 2) | ((V0 & 3) << 0);
+                    hi = _mm_permute_ps(hi, imm);
+                }
+                return _mm256_set_m128(hi, hi);
+            }
+            XSIMD_IF_CONSTEXPR(is_all_different)
+            {
+                // The intrinsic does NOT allow copying the same element of the source vector to more than one element of the destination vector.
+                // one-shot 8-lane permute
+                return _mm256_permutevar8x32_ps(self, mask.as_batch());
+            }
+            // duplicate low and high part of input
+            __m256 hi = _mm256_castps128_ps256(_mm256_extractf128_ps(self, 1));
+            __m256 hi_hi = _mm256_insertf128_ps(self, _mm256_castps256_ps128(hi), 0);
+
+            __m256 low = _mm256_castps128_ps256(_mm256_castps256_ps128(self));
+            __m256 low_low = _mm256_insertf128_ps(self, _mm256_castps256_ps128(low), 1);
+
+            // normalize mask
+            batch_constant<uint32_t, A, (V0 % 4), (V1 % 4), (V2 % 4), (V3 % 4), (V4 % 4), (V5 % 4), (V6 % 4), (V7 % 4)> half_mask;
+
+            // permute within each lane
+            __m256 r0 = _mm256_permutevar_ps(low_low, half_mask.as_batch());
+            __m256 r1 = _mm256_permutevar_ps(hi_hi, half_mask.as_batch());
+
+            // mask to choose the right lane
+            batch_bool_constant<uint32_t, A, (V0 >= 4), (V1 >= 4), (V2 >= 4), (V3 >= 4), (V4 >= 4), (V5 >= 4), (V6 >= 4), (V7 >= 4)> blend_mask;
+
+            // blend the two permutes
+            return _mm256_blend_ps(r0, r1, blend_mask.mask());
         }
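
When the indices repeat and cross 128-bit lanes, none of the fast paths fire and the function falls through to the general split/permute/blend sequence above. A standalone sketch of that decomposition for the assumed mask (4,4,0,0, 6,6,2,2), using raw intrinsics and a hypothetical helper instead of the batch_constant machinery:

    #include <immintrin.h>

    __m256 swizzle_44006622(__m256 self)
    {
        __m128 lo128 = _mm256_castps256_ps128(self);
        __m128 hi128 = _mm256_extractf128_ps(self, 1);
        __m256 low_low = _mm256_set_m128(lo128, lo128); // [lo | lo]
        __m256 hi_hi   = _mm256_set_m128(hi128, hi128); // [hi | hi]

        // in-lane indices V % 4 -> (0,0,0,0, 2,2,2,2)
        __m256i half_mask = _mm256_setr_epi32(0, 0, 0, 0, 2, 2, 2, 2);
        __m256 r0 = _mm256_permutevar_ps(low_low, half_mask);
        __m256 r1 = _mm256_permutevar_ps(hi_hi, half_mask);

        // blend bit i is 1 where Vi >= 4, i.e. where the element comes
        // from the high half: (4,4,0,0,6,6,2,2) -> 0b00110011 = 0x33
        return _mm256_blend_ps(r0, r1, 0x33);
    }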

         template <class A, uint64_t V0, uint64_t V1, uint64_t V2, uint64_t V3>
         XSIMD_INLINE batch<double, A> swizzle(batch<double, A> const& self, batch_constant<uint64_t, A, V0, V1, V2, V3>, requires_arch<avx2>) noexcept
         {
+            constexpr bool is_identity = (V0 == 0 && V1 == 1 && V2 == 2 && V3 == 3);
+            constexpr bool can_use_pd = (V0 < 2 && V1 < 2 && V2 >= 2 && V2 < 4 && V3 >= 2 && V3 < 4);
+
+            XSIMD_IF_CONSTEXPR(is_identity) { return self; }
+            XSIMD_IF_CONSTEXPR(can_use_pd)
+            {
+                // build the 4-bit immediate: bit i is 1 to pick the upper element of pair i
+                constexpr int mask = ((V0 & 1) << 0) | ((V1 & 1) << 1) | ((V2 & 1) << 2) | ((V3 & 1) << 3);
+                return _mm256_permute_pd(self, mask);
+            }
+            // fallback to full 4-element permute
             constexpr auto mask = detail::shuffle(V0, V1, V2, V3);
             return _mm256_permute4x64_pd(self, mask);
         }
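
Unlike the AVX version, the AVX2 fallback needs no lane splitting: _mm256_permute4x64_pd performs any 4-element double permutation in one lane-crossing instruction, with detail::shuffle building the corresponding two-bits-per-element immediate. A sketch for an assumed full reversal (3,2,1,0), which crosses lanes and so skips the cheaper in-lane _mm256_permute_pd path (hypothetical helper, not from the commit):

    #include <immintrin.h>

    // (3,2,1,0): the immediate packs two index bits per destination element,
    // (3 << 0) | (2 << 2) | (1 << 4) | (0 << 6) = 0x1B
    __m256d swizzle_3210(__m256d self)
    {
        return _mm256_permute4x64_pd(self, 0x1B);
    }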
