Skip to content

Commit 97bdaa4

Browse files
committed
Add AVX2 swizzle with batch_constant masks for 8- and 16-bit element types
1 parent 71a344e commit 97bdaa4

File tree

1 file changed

+117
-4
lines changed

1 file changed

+117
-4
lines changed

include/xsimd/arch/xsimd_avx2.hpp

Lines changed: 117 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,19 +1141,132 @@ namespace xsimd
11411141
return bitwise_cast<T>(swizzle(bitwise_cast<uint16_t>(self), mask, req));
11421142
}
11431143

1144+
namespace detail
1145+
{
1146+
template <typename T>
1147+
constexpr T swizzle_val_none()
1148+
{
1149+
// Most significant bit of the byte must be 1
1150+
return 0x80;
1151+
}
1152+
1153+
template <typename T>
1154+
constexpr bool swizzle_val_is_cross_lane(T val, T idx, T size)
1155+
{
1156+
return (idx < (size / 2)) != (val < (size / 2));
1157+
}
1158+
1159+
template <typename T>
1160+
constexpr bool swizzle_val_is_defined(T val, T size)
1161+
{
1162+
return (0 <= val) && (val < size);
1163+
}
1164+
1165+
template <typename T>
1166+
constexpr T swizzle_self_val(T val, T idx, T size)
1167+
{
1168+
return (swizzle_val_is_defined(val, size) && !swizzle_val_is_cross_lane(val, idx, size))
1169+
? val % (size / 2)
1170+
: swizzle_val_none<T>();
1171+
}
1172+
1173+
template <typename T, typename A, T... Vals, std::size_t... Ids>
1174+
constexpr auto swizzle_make_self_batch_impl(::xsimd::detail::index_sequence<Ids...>)
1175+
-> batch_constant<T, A, swizzle_self_val(Vals, T(Ids), static_cast<T>(sizeof...(Vals)))...>
1176+
{
1177+
return {};
1178+
}
1179+
1180+
template <typename T, typename A, T... Vals>
1181+
constexpr auto swizzle_make_self_batch()
1182+
-> decltype(swizzle_make_self_batch_impl<T, A, Vals...>(::xsimd::detail::make_index_sequence<sizeof...(Vals)>()))
1183+
{
1184+
return {};
1185+
}
1186+
1187+
template <typename T>
1188+
constexpr T swizzle_cross_val(T val, T idx, T size)
1189+
{
1190+
return (swizzle_val_is_defined(val, size) && swizzle_val_is_cross_lane(val, idx, size))
1191+
? val % (size / 2)
1192+
: swizzle_val_none<T>();
1193+
}
1194+
1195+
template <typename T, typename A, T... Vals, std::size_t... Ids>
1196+
constexpr auto swizzle_make_cross_batch_impl(::xsimd::detail::index_sequence<Ids...>)
1197+
-> batch_constant<T, A, swizzle_cross_val(Vals, T(Ids), static_cast<T>(sizeof...(Vals)))...>
1198+
{
1199+
return {};
1200+
}
1201+
1202+
template <typename T, typename A, T... Vals>
1203+
constexpr auto swizzle_make_cross_batch()
1204+
-> decltype(swizzle_make_cross_batch_impl<T, A, Vals...>(::xsimd::detail::make_index_sequence<sizeof...(Vals)>()))
1205+
{
1206+
return {};
1207+
}
1208+
}
1209+
11441210
// swizzle (constant mask)
1211+
template <class A, uint8_t... Vals>
1212+
XSIMD_INLINE batch<uint8_t, A> swizzle(batch<uint8_t, A> const& self, batch_constant<uint8_t, A, Vals...> mask, requires_arch<avx2>) noexcept
1213+
{
1214+
static_assert(sizeof...(Vals) == 32, "Must contain as many uint8_t as can fit in avx register");
1215+
1216+
XSIMD_IF_CONSTEXPR(detail::is_identity(mask))
1217+
{
1218+
return self;
1219+
}
1220+
1221+
// swap lanes
1222+
__m256i swapped = _mm256_permute2x128_si256(self, self, 0x01); // [high | low]
1223+
1224+
// We can outsmart the dynamic version by creating a compile-time mask that leaves zeros
1225+
// where it does not need to select data, resulting in a simple OR merge of the two batches.
1226+
constexpr auto self_mask = detail::swizzle_make_self_batch<uint8_t, A, Vals...>();
1227+
constexpr auto cross_mask = detail::swizzle_make_cross_batch<uint8_t, A, Vals...>();
1228+
1229+
// permute bytes within each lane (AVX2 only)
1230+
__m256i r0 = _mm256_shuffle_epi8(self, self_mask.as_batch());
1231+
__m256i r1 = _mm256_shuffle_epi8(swapped, cross_mask.as_batch());
1232+
1233+
return xsimd::batch<uint8_t, A>(_mm256_or_si256(r0, r1));
1234+
}
1235+
11451236
template <class A, typename T, uint8_t... Vals, detail::enable_sized_t<T, 1> = 0>
1146-
XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch_constant<uint8_t, A, Vals...> mask, requires_arch<avx2> req) noexcept
1237+
XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch_constant<uint8_t, A, Vals...> const& mask, requires_arch<avx2> req) noexcept
11471238
{
11481239
static_assert(sizeof...(Vals) == 32, "Must contain as many uint8_t as can fit in avx register");
1149-
return swizzle(self, mask.as_batch(), req);
1240+
return bitwise_cast<T>(swizzle(bitwise_cast<uint8_t>(self), mask, req));
1241+
}
1242+
1243+
template <
1244+
class A,
1245+
uint16_t V0, uint16_t V1, uint16_t V2, uint16_t V3,
1246+
uint16_t V4, uint16_t V5, uint16_t V6, uint16_t V7,
1247+
uint16_t V8, uint16_t V9, uint16_t V10, uint16_t V11,
1248+
uint16_t V12, uint16_t V13, uint16_t V14, uint16_t V15>
1249+
XSIMD_INLINE batch<uint16_t, A> swizzle(
1250+
batch<uint16_t, A> const& self,
1251+
batch_constant<uint16_t, A, V0, V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, V11, V12, V13, V14, V15>,
1252+
requires_arch<avx2> req) noexcept
1253+
{
1254+
const auto self_bytes = bitwise_cast<uint8_t>(self);
1255+
// If a mask entry is k, we want 2k in low byte and 2k+1 in high byte
1256+
auto constexpr mask_2k_2kp1 = batch_constant<
1257+
uint8_t, A,
1258+
2 * V0, 2 * V0 + 1, 2 * V1, 2 * V1 + 1, 2 * V2, 2 * V2 + 1, 2 * V3, 2 * V3 + 1,
1259+
2 * V4, 2 * V4 + 1, 2 * V5, 2 * V5 + 1, 2 * V6, 2 * V6 + 1, 2 * V7, 2 * V7 + 1,
1260+
2 * V8, 2 * V8 + 1, 2 * V9, 2 * V9 + 1, 2 * V10, 2 * V10 + 1, 2 * V11, 2 * V11 + 1,
1261+
2 * V12, 2 * V12 + 1, 2 * V13, 2 * V13 + 1, 2 * V14, 2 * V14 + 1, 2 * V15, 2 * V15 + 1> {};
1262+
return bitwise_cast<uint16_t>(swizzle(self_bytes, mask_2k_2kp1, req));
11501263
}
11511264

11521265
template <class A, typename T, uint16_t... Vals, detail::enable_sized_t<T, 2> = 0>
1153-
XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch_constant<uint16_t, A, Vals...> mask, requires_arch<avx2> req) noexcept
1266+
XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch_constant<uint16_t, A, Vals...> const& mask, requires_arch<avx2> req) noexcept
11541267
{
11551268
static_assert(sizeof...(Vals) == 16, "Must contain as many uint16_t as can fit in avx register");
1156-
return swizzle(self, mask.as_batch(), req);
1269+
return bitwise_cast<T>(swizzle(bitwise_cast<uint16_t>(self), mask, req));
11571270
}
11581271

11591272
template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3, uint32_t V4, uint32_t V5, uint32_t V6, uint32_t V7>

0 commit comments

Comments
 (0)