
Commit ba8a4bd

improved sse2 swizzle
1 parent e089814

File tree

1 file changed: +62 -34 lines changed


include/xsimd/arch/xsimd_sse2.hpp

Lines changed: 62 additions & 34 deletions
@@ -1641,55 +1641,84 @@ namespace xsimd
         }
 
         template <class A, uint16_t V0, uint16_t V1, uint16_t V2, uint16_t V3, uint16_t V4, uint16_t V5, uint16_t V6, uint16_t V7>
-        XSIMD_INLINE batch<uint16_t, A> swizzle(batch<uint16_t, A> const& self, batch_constant<uint16_t, A, V0, V1, V2, V3, V4, V5, V6, V7>, requires_arch<sse2>) noexcept
+        XSIMD_INLINE batch<int16_t, A> swizzle(batch<int16_t, A> const& self, batch_constant<uint16_t, A, V0, V1, V2, V3, V4, V5, V6, V7>, requires_arch<sse2>) noexcept
         {
-            __m128i v = self;
-
+            // 0) identity?
+            constexpr bool is_identity = (V0 == 0 && V1 == 1 && V2 == 2 && V3 == 3 && V4 == 4 && V5 == 5 && V6 == 6 && V7 == 7);
+            XSIMD_IF_CONSTEXPR(is_identity)
+            {
+                return self;
+            }
+            // 1) duplicate-low-half? (lanes 0-3 from the low half, lanes 4-7 the same)
+            constexpr bool is_dup_lo = (V0 < 4 && V1 < 4 && V2 < 4 && V3 < 4) && V4 == V0 && V5 == V1 && V6 == V2 && V7 == V3;
+            XSIMD_IF_CONSTEXPR(is_dup_lo)
+            {
+                // permute the low half
+                constexpr int imm = detail::mod_shuffle(V0, V1, V2, V3);
+                const auto lo = _mm_shufflelo_epi16(self, imm);
+                // broadcast that 64-bit low half into both halves
+                const auto lo_all = _mm_unpacklo_epi64(lo, lo);
+                return lo_all;
+            }
+            // 2) duplicate-high-half? (lanes 0-3 from the high half, lanes 4-7 the same)
+            constexpr bool is_dup_hi = (V0 >= 4 && V0 < 8 && V1 >= 4 && V1 < 8 && V2 >= 4 && V2 < 8 && V3 >= 4 && V3 < 8) && V4 == V0 && V5 == V1 && V6 == V2 && V7 == V3;
+            XSIMD_IF_CONSTEXPR(is_dup_hi)
+            {
+                // permute the high half (indices % 4)
+                constexpr int imm = detail::mod_shuffle(V0 - 4, V1 - 4, V2 - 4, V3 - 4);
+                const auto hi = _mm_shufflehi_epi16(self, imm);
+                // broadcast that 64-bit high half into both halves
+                const auto hi_all = _mm_unpackhi_epi64(hi, hi);
+                return hi_all;
+            }
             // 1) Shuffle the low 64-bit half for lanes 0-3 and 4-7:
             constexpr int imm_lo0 = detail::mod_shuffle(V0, V1, V2, V3);
             constexpr int imm_lo1 = detail::mod_shuffle(V4, V5, V6, V7);
-            __m128i lo0 = _mm_shufflelo_epi16(v, imm_lo0);
-            __m128i lo1 = _mm_shufflelo_epi16(v, imm_lo1);
-
+            const auto lo0 = _mm_shufflelo_epi16(self, imm_lo0);
+            const auto lo1 = _mm_shufflelo_epi16(self, imm_lo1);
             // Broadcast each low-half permutation across both 64-bit halves:
-            __m128i lo0_all = _mm_unpacklo_epi64(lo0, lo0);
-            __m128i lo1_all = _mm_unpacklo_epi64(lo1, lo1);
-
+            const auto lo0_all = _mm_unpacklo_epi64(lo0, lo0);
+            const auto lo1_all = _mm_unpacklo_epi64(lo1, lo1);
             // 2) Shuffle the high 64-bit half for lanes 0-3 and 4-7:
-            constexpr int imm_hi0 = detail::mod_shuffle(V0 - 4, V1 - 4, V2 - 4, V3 - 4);
-            constexpr int imm_hi1 = detail::mod_shuffle(V4 - 4, V5 - 4, V6 - 4, V7 - 4);
-            __m128i hi0 = _mm_shufflehi_epi16(v, imm_hi0);
-            __m128i hi1 = _mm_shufflehi_epi16(v, imm_hi1);
 
             // Broadcast each high-half permutation across both 64-bit halves:
-            __m128i hi0_all = _mm_unpackhi_epi64(hi0, hi0);
-            __m128i hi1_all = _mm_unpackhi_epi64(hi1, hi1);
+            // hi0_all: only instantiated if any of V0..V3 >= 4
+            const auto hi0_all = [&]
+            {
+                XSIMD_IF_CONSTEXPR(!(V0 < 4 && V1 < 4 && V2 < 4 && V3 < 4))
+                {
+                    constexpr int imm_hi0 = detail::mod_shuffle(V0 - 4, V1 - 4, V2 - 4, V3 - 4);
+                    __m128i hi0 = _mm_shufflehi_epi16(self, imm_hi0);
+                    return _mm_unpackhi_epi64(hi0, hi0);
+                }
+                // not used whenever all V0..V3 < 4
+                return _mm_setzero_si128();
+            }();
 
+            // hi1_all: only instantiated if any of V4..V7 >= 4
+            const auto hi1_all = [&]
+            {
+                XSIMD_IF_CONSTEXPR(!(V4 < 4 && V5 < 4 && V6 < 4 && V7 < 4))
+                {
+                    constexpr int imm_hi1 = detail::mod_shuffle(V4 - 4, V5 - 4, V6 - 4, V7 - 4);
+                    __m128i hi1 = _mm_shufflehi_epi16(self, imm_hi1);
+                    return _mm_unpackhi_epi64(hi1, hi1);
+                }
+                return _mm_setzero_si128();
+            }();
             // 3) Merge the two "low" broadcasts into one vector (lanes 0-3 <- lo0_all, lanes 4-7 <- lo1_all)
-            __m128i low_all = _mm_unpacklo_epi64(lo0_all, lo1_all); // { lo0, lo1 }
-
-            // constexpr batch_bool_constant<uint16_t, A, false, false, false, false, true, true, true, true> group_mask {};
-            // auto low_all = select(group_mask, batch<uint16_t, A>(lo1_all), batch<uint16_t, A>(lo0_all));
-
+            const auto low_all = _mm_unpacklo_epi64(lo0_all, lo1_all); // { lo0, lo1 }
             // Likewise merge the two "high" broadcasts:
-            __m128i high_all = _mm_unpacklo_epi64(hi0_all, hi1_all); // { hi0, hi1 }
-
-            // auto high_all = select(group_mask, batch<uint16_t, A>(hi1_all), batch<uint16_t, A>(hi0_all));
-
+            const auto high_all = _mm_unpacklo_epi64(hi0_all, hi1_all); // { hi0, hi1 }
             // 4) Finally, pick per-lane: if Vn < 4 -> take from low_all, else from high_all
-            constexpr batch_bool_constant<uint16_t, A, (V0 < 4), (V1 < 4), (V2 < 4), (V3 < 4), (V4 < 4), (V5 < 4), (V6 < 4), (V7 < 4)> lane_mask {};
-            return select(lane_mask, // mask[i] ? low_all[i] : high_all[i]
-                          batch<uint16_t, A>(low_all),
-                          batch<uint16_t, A>(high_all));
-            // return select(lane_mask, low_all, high_all);
+            constexpr batch_bool_constant<int16_t, A, (V0 < 4), (V1 < 4), (V2 < 4), (V3 < 4), (V4 < 4), (V5 < 4), (V6 < 4), (V7 < 4)> lane_mask {};
+            return select(lane_mask, batch<int16_t, A>(low_all), batch<int16_t, A>(high_all));
         }
-
         template <class A, uint16_t V0, uint16_t V1, uint16_t V2, uint16_t V3, uint16_t V4, uint16_t V5, uint16_t V6, uint16_t V7>
-        XSIMD_INLINE batch<int16_t, A> swizzle(batch<int16_t, A> const& self, batch_constant<uint16_t, A, V0, V1, V2, V3, V4, V5, V6, V7> mask, requires_arch<sse2>) noexcept
+        XSIMD_INLINE batch<uint16_t, A> swizzle(batch<uint16_t, A> const& self, batch_constant<uint16_t, A, V0, V1, V2, V3, V4, V5, V6, V7> mask, requires_arch<sse2>) noexcept
         {
-            return bitwise_cast<int16_t>(swizzle(bitwise_cast<uint16_t>(self), mask, sse2 {}));
+            return bitwise_cast<uint16_t>(swizzle(bitwise_cast<int16_t>(self), mask, sse2 {}));
         }
-
         // transpose
         template <class A>
         XSIMD_INLINE void transpose(batch<float, A>* matrix_begin, batch<float, A>* matrix_end, requires_arch<sse2>) noexcept
@@ -1854,7 +1883,6 @@ namespace xsimd
                 return {};
             }
         }
-
     }
 }
 
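
For reference, the kernel changed above implements a compile-time 8-lane permutation within one register: out[i] = in[V[i]], with every V[i] in [0, 8). A minimal scalar sketch of those semantics (the helper name is illustrative, not part of xsimd):

#include <array>
#include <cstddef>
#include <cstdint>

// Scalar model of the 16-bit SSE2 swizzle above: out[i] = in[idx[i]].
// Hypothetical reference for illustration; the xsimd kernel does the same
// job in registers with shuffles, unpacks and a constant-mask select.
std::array<int16_t, 8> swizzle_ref(std::array<int16_t, 8> const& in,
                                   std::array<uint16_t, 8> const& idx)
{
    std::array<int16_t, 8> out {};
    for (std::size_t i = 0; i < 8; ++i)
        out[i] = in[idx[i]]; // each idx[i] must be in [0, 8)
    return out;
}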

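Both the new fast paths and the generic path lean on detail::mod_shuffle to fold four lane indices into the 8-bit immediate consumed by _mm_shufflelo_epi16 / _mm_shufflehi_epi16. The sketch below shows what such a helper computes, assuming the usual _MM_SHUFFLE-style 2-bits-per-lane encoding; the authoritative definition lives elsewhere in xsimd_sse2.hpp and may differ in spelling:

// Sketch of an immediate-packing helper in the spirit of detail::mod_shuffle.
// Each index is reduced mod 4 and packed into 2 bits, with the index for
// lane 0 in the lowest bits, the layout _mm_shufflelo_epi16 expects.
constexpr int mod_shuffle_sketch(unsigned w, unsigned x, unsigned y, unsigned z)
{
    return static_cast<int>(((z % 4) << 6) | ((y % 4) << 4) | ((x % 4) << 2) | (w % 4));
}

static_assert(mod_shuffle_sketch(0, 1, 2, 3) == 0xE4, "identity permutation");
static_assert(mod_shuffle_sketch(1, 0, 3, 2) == 0xB1, "pairwise swap");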
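
Finally, a small usage sketch through the public API, assuming an SSE2-capable x86 build and a recent xsimd with the batch_constant<T, A, Vs...> spelling used in this diff: a mirrored pattern such as {1,0,3,2, 1,0,3,2} now takes the dup-lo fast path (one shuffle plus one unpack), while a full reversal falls through to the generic shuffle/unpack/select sequence.

#include <cstdint>
#include <cstdio>
#include <xsimd/xsimd.hpp>

int main()
{
    using batch16 = xsimd::batch<int16_t, xsimd::sse2>;
    alignas(16) int16_t data[8] = { 10, 11, 12, 13, 14, 15, 16, 17 };
    const batch16 v = batch16::load_aligned(data);

    // dup-lo pattern: lanes 4-7 repeat lanes 0-3, so the improved kernel
    // needs only one _mm_shufflelo_epi16 plus one _mm_unpacklo_epi64.
    const auto mirrored = xsimd::swizzle(
        v, xsimd::batch_constant<uint16_t, xsimd::sse2, 1, 0, 3, 2, 1, 0, 3, 2> {});

    // General pattern: a full reversal crosses both halves and takes the
    // generic shufflelo/shufflehi + unpack + select path.
    const auto reversed = xsimd::swizzle(
        v, xsimd::batch_constant<uint16_t, xsimd::sse2, 7, 6, 5, 4, 3, 2, 1, 0> {});

    alignas(16) int16_t out[8];
    reversed.store_aligned(out);
    for (int i = 0; i < 8; ++i)
        std::printf("%d ", out[i]); // prints: 17 16 15 14 13 12 11 10
    std::printf("\n");
    (void)mirrored;
    return 0;
}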