Skip to content

Commit e089814

Browse files
committed
fixing sse
1 parent b10fc1d commit e089814

File tree

1 file changed

+36
-11
lines changed

1 file changed

+36
-11
lines changed

include/xsimd/arch/xsimd_sse2.hpp

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1643,20 +1643,45 @@ namespace xsimd
16431643
template <class A, uint16_t V0, uint16_t V1, uint16_t V2, uint16_t V3, uint16_t V4, uint16_t V5, uint16_t V6, uint16_t V7>
16441644
XSIMD_INLINE batch<uint16_t, A> swizzle(batch<uint16_t, A> const& self, batch_constant<uint16_t, A, V0, V1, V2, V3, V4, V5, V6, V7>, requires_arch<sse2>) noexcept
16451645
{
1646-
// permute within each lane
1647-
constexpr auto mask_lo = detail::mod_shuffle(V0, V1, V2, V3);
1648-
constexpr auto mask_hi = detail::mod_shuffle(V4, V5, V6, V7);
1649-
__m128i lo = _mm_shufflelo_epi16(self, mask_lo);
1650-
__m128i hi = _mm_shufflehi_epi16(self, mask_hi);
1646+
__m128i v = self;
16511647

1652-
__m128i lo_lo = _mm_castpd_si128(_mm_shuffle_pd(_mm_castsi128_pd(lo), _mm_castsi128_pd(lo), _MM_SHUFFLE2(0, 0)));
1653-
__m128i hi_hi = _mm_castpd_si128(_mm_shuffle_pd(_mm_castsi128_pd(hi), _mm_castsi128_pd(hi), _MM_SHUFFLE2(1, 1)));
1648+
// 1) Shuffle the low 64-bit half for lanes 0–3 and 4–7:
1649+
constexpr int imm_lo0 = detail::mod_shuffle(V0, V1, V2, V3);
1650+
constexpr int imm_lo1 = detail::mod_shuffle(V4, V5, V6, V7);
1651+
__m128i lo0 = _mm_shufflelo_epi16(v, imm_lo0);
1652+
__m128i lo1 = _mm_shufflelo_epi16(v, imm_lo1);
16541653

1655-
// mask to choose the right lane
1656-
batch_bool_constant<uint16_t, A, (V0 < 4), (V1 < 4), (V2 < 4), (V3 < 4), (V4 < 4), (V5 < 4), (V6 < 4), (V7 < 4)> blend_mask;
1654+
// Broadcast each low-half permutation across both 64-bit halves:
1655+
__m128i lo0_all = _mm_unpacklo_epi64(lo0, lo0);
1656+
__m128i lo1_all = _mm_unpacklo_epi64(lo1, lo1);
16571657

1658-
// blend the two permutes
1659-
return select(blend_mask, batch<uint16_t, A>(lo_lo), batch<uint16_t, A>(hi_hi));
1658+
// 2) Shuffle the high 64-bit half for lanes 0–3 and 4–7:
1659+
constexpr int imm_hi0 = detail::mod_shuffle(V0 - 4, V1 - 4, V2 - 4, V3 - 4);
1660+
constexpr int imm_hi1 = detail::mod_shuffle(V4 - 4, V5 - 4, V6 - 4, V7 - 4);
1661+
__m128i hi0 = _mm_shufflehi_epi16(v, imm_hi0);
1662+
__m128i hi1 = _mm_shufflehi_epi16(v, imm_hi1);
1663+
1664+
// Broadcast each high-half permutation across both 64-bit halves:
1665+
__m128i hi0_all = _mm_unpackhi_epi64(hi0, hi0);
1666+
__m128i hi1_all = _mm_unpackhi_epi64(hi1, hi1);
1667+
1668+
// 3) Merge the two “low” broadcasts into one vector (lanes 0–3 ← lo0_all, lanes 4–7 ← lo1_all)
1669+
__m128i low_all = _mm_unpacklo_epi64(lo0_all, lo1_all); // { lo0, lo1 }
1670+
1671+
// constexpr batch_bool_constant<uint16_t, A, false, false, false, false, true, true, true, true> group_mask {};
1672+
// auto low_all = select(group_mask, batch<uint16_t, A>(lo1_all), batch<uint16_t, A>(lo0_all));
1673+
1674+
// Likewise merge the two “high” broadcasts:
1675+
__m128i high_all = _mm_unpacklo_epi64(hi0_all, hi1_all); // { hi0, hi1 }
1676+
1677+
// auto high_all = select(group_mask, batch<uint16_t, A>(hi1_all), batch<uint16_t, A>(hi0_all));
1678+
1679+
// 4) Finally, pick per-lane: if Vn<4 → take from low_all, else from high_all
1680+
constexpr batch_bool_constant<uint16_t, A, (V0 < 4), (V1 < 4), (V2 < 4), (V3 < 4), (V4 < 4), (V5 < 4), (V6 < 4), (V7 < 4)> lane_mask {};
1681+
return select(lane_mask, // mask[i] ? low_all[i] : high_all[i]
1682+
batch<uint16_t, A>(low_all),
1683+
batch<uint16_t, A>(high_all));
1684+
// return select(lane_mask, low_all, high_all);
16601685
}
16611686

16621687
template <class A, uint16_t V0, uint16_t V1, uint16_t V2, uint16_t V3, uint16_t V4, uint16_t V5, uint16_t V6, uint16_t V7>

0 commit comments

Comments
 (0)