Skip to content

Commit ce12df2

Browse files
committed
fixing sse
1 parent 2aa1420 commit ce12df2

File tree

1 file changed

+36
-11
lines changed

1 file changed

+36
-11
lines changed

include/xsimd/arch/xsimd_sse2.hpp

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1690,20 +1690,45 @@ namespace xsimd
16901690
template <class A, uint16_t V0, uint16_t V1, uint16_t V2, uint16_t V3, uint16_t V4, uint16_t V5, uint16_t V6, uint16_t V7>
16911691
XSIMD_INLINE batch<uint16_t, A> swizzle(batch<uint16_t, A> const& self, batch_constant<uint16_t, A, V0, V1, V2, V3, V4, V5, V6, V7>, requires_arch<sse2>) noexcept
16921692
{
1693-
// permute within each lane
1694-
constexpr auto mask_lo = detail::mod_shuffle(V0, V1, V2, V3);
1695-
constexpr auto mask_hi = detail::mod_shuffle(V4, V5, V6, V7);
1696-
__m128i lo = _mm_shufflelo_epi16(self, mask_lo);
1697-
__m128i hi = _mm_shufflehi_epi16(self, mask_hi);
1693+
__m128i v = self;
16981694

1699-
__m128i lo_lo = _mm_castpd_si128(_mm_shuffle_pd(_mm_castsi128_pd(lo), _mm_castsi128_pd(lo), _MM_SHUFFLE2(0, 0)));
1700-
__m128i hi_hi = _mm_castpd_si128(_mm_shuffle_pd(_mm_castsi128_pd(hi), _mm_castsi128_pd(hi), _MM_SHUFFLE2(1, 1)));
1695+
// 1) Shuffle the low 64-bit half for lanes 0–3 and 4–7:
1696+
constexpr int imm_lo0 = detail::mod_shuffle(V0, V1, V2, V3);
1697+
constexpr int imm_lo1 = detail::mod_shuffle(V4, V5, V6, V7);
1698+
__m128i lo0 = _mm_shufflelo_epi16(v, imm_lo0);
1699+
__m128i lo1 = _mm_shufflelo_epi16(v, imm_lo1);
17011700

1702-
// mask to choose the right lane
1703-
batch_bool_constant<uint16_t, A, (V0 < 4), (V1 < 4), (V2 < 4), (V3 < 4), (V4 < 4), (V5 < 4), (V6 < 4), (V7 < 4)> blend_mask;
1701+
// Broadcast each low-half permutation across both 64-bit halves:
1702+
__m128i lo0_all = _mm_unpacklo_epi64(lo0, lo0);
1703+
__m128i lo1_all = _mm_unpacklo_epi64(lo1, lo1);
17041704

1705-
// blend the two permutes
1706-
return select(blend_mask, batch<uint16_t, A>(lo_lo), batch<uint16_t, A>(hi_hi));
1705+
// 2) Shuffle the high 64-bit half for lanes 0–3 and 4–7:
1706+
constexpr int imm_hi0 = detail::mod_shuffle(V0 - 4, V1 - 4, V2 - 4, V3 - 4);
1707+
constexpr int imm_hi1 = detail::mod_shuffle(V4 - 4, V5 - 4, V6 - 4, V7 - 4);
1708+
__m128i hi0 = _mm_shufflehi_epi16(v, imm_hi0);
1709+
__m128i hi1 = _mm_shufflehi_epi16(v, imm_hi1);
1710+
1711+
// Broadcast each high-half permutation across both 64-bit halves:
1712+
__m128i hi0_all = _mm_unpackhi_epi64(hi0, hi0);
1713+
__m128i hi1_all = _mm_unpackhi_epi64(hi1, hi1);
1714+
1715+
// 3) Merge the two “low” broadcasts into one vector (lanes 0–3 ← lo0_all, lanes 4–7 ← lo1_all)
1716+
__m128i low_all = _mm_unpacklo_epi64(lo0_all, lo1_all); // { lo0, lo1 }
1717+
1718+
// constexpr batch_bool_constant<uint16_t, A, false, false, false, false, true, true, true, true> group_mask {};
1719+
// auto low_all = select(group_mask, batch<uint16_t, A>(lo1_all), batch<uint16_t, A>(lo0_all));
1720+
1721+
// Likewise merge the two “high” broadcasts:
1722+
__m128i high_all = _mm_unpacklo_epi64(hi0_all, hi1_all); // { hi0, hi1 }
1723+
1724+
// auto high_all = select(group_mask, batch<uint16_t, A>(hi1_all), batch<uint16_t, A>(hi0_all));
1725+
1726+
// 4) Finally, pick per-lane: if Vn<4 → take from low_all, else from high_all
1727+
constexpr batch_bool_constant<uint16_t, A, (V0 < 4), (V1 < 4), (V2 < 4), (V3 < 4), (V4 < 4), (V5 < 4), (V6 < 4), (V7 < 4)> lane_mask {};
1728+
return select(lane_mask, // mask[i] ? low_all[i] : high_all[i]
1729+
batch<uint16_t, A>(low_all),
1730+
batch<uint16_t, A>(high_all));
1731+
// return select(lane_mask, low_all, high_all);
17071732
}
17081733

17091734
template <class A, uint16_t V0, uint16_t V1, uint16_t V2, uint16_t V3, uint16_t V4, uint16_t V5, uint16_t V6, uint16_t V7>

0 commit comments

Comments
 (0)