@@ -1141,19 +1141,132 @@ namespace xsimd
             return bitwise_cast<T>(swizzle(bitwise_cast<uint16_t>(self), mask, req));
         }
 
+        namespace detail
+        {
+            template <typename T>
+            constexpr T swizzle_val_none()
+            {
+                // Most significant bit of the byte must be 1
+                return 0x80;
+            }
+
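+            // A destination slot and the element it selects are "cross-lane" when
+            // they sit in different 128-bit halves of the register.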
+            template <typename T>
+            constexpr bool swizzle_val_is_cross_lane(T val, T idx, T size)
+            {
+                return (idx < (size / 2)) != (val < (size / 2));
+            }
+
+            template <typename T>
+            constexpr bool swizzle_val_is_defined(T val, T size)
+            {
+                return (0 <= val) && (val < size);
+            }
+
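+            // Per-lane shuffle index for elements that stay within their own lane;
+            // anything else maps to the zeroing sentinel. E.g. with size == 32,
+            // swizzle_self_val(5, 2, 32) == 5 but swizzle_self_val(20, 2, 32) == 0x80.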
+            template <typename T>
+            constexpr T swizzle_self_val(T val, T idx, T size)
+            {
+                return (swizzle_val_is_defined(val, size) && !swizzle_val_is_cross_lane(val, idx, size))
+                    ? val % (size / 2)
+                    : swizzle_val_none<T>();
+            }
+
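+            // Pair every mask value with its destination position through an index
+            // sequence so that the per-lane indices are computed at compile time.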
+            template <typename T, typename A, T... Vals, std::size_t... Ids>
+            constexpr auto swizzle_make_self_batch_impl(::xsimd::detail::index_sequence<Ids...>)
+                -> batch_constant<T, A, swizzle_self_val(Vals, T(Ids), static_cast<T>(sizeof...(Vals)))...>
+            {
+                return {};
+            }
+
+            template <typename T, typename A, T... Vals>
+            constexpr auto swizzle_make_self_batch()
+                -> decltype(swizzle_make_self_batch_impl<T, A, Vals...>(::xsimd::detail::make_index_sequence<sizeof...(Vals)>()))
+            {
+                return {};
+            }
+
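+            // Counterpart of swizzle_self_val for elements sourced from the other
+            // lane: e.g. swizzle_cross_val(20, 2, 32) == 4, since byte 20 of self
+            // sits at offset 4 of the lane-swapped register.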
+            template <typename T>
+            constexpr T swizzle_cross_val(T val, T idx, T size)
+            {
+                return (swizzle_val_is_defined(val, size) && swizzle_val_is_cross_lane(val, idx, size))
+                    ? val % (size / 2)
+                    : swizzle_val_none<T>();
+            }
+
+            template <typename T, typename A, T... Vals, std::size_t... Ids>
+            constexpr auto swizzle_make_cross_batch_impl(::xsimd::detail::index_sequence<Ids...>)
+                -> batch_constant<T, A, swizzle_cross_val(Vals, T(Ids), static_cast<T>(sizeof...(Vals)))...>
+            {
+                return {};
+            }
+
+            template <typename T, typename A, T... Vals>
+            constexpr auto swizzle_make_cross_batch()
+                -> decltype(swizzle_make_cross_batch_impl<T, A, Vals...>(::xsimd::detail::make_index_sequence<sizeof...(Vals)>()))
+            {
+                return {};
+            }
+        }
+
         // swizzle (constant mask)
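+        // Usage sketch (illustrative mask): reversing all 32 bytes of a byte batch, e.g.
+        //   swizzle(bytes, batch_constant<uint8_t, avx2, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,
+        //                                 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0> {}, avx2 {});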
+        template <class A, uint8_t... Vals>
+        XSIMD_INLINE batch<uint8_t, A> swizzle(batch<uint8_t, A> const& self, batch_constant<uint8_t, A, Vals...> mask, requires_arch<avx2>) noexcept
+        {
+            static_assert(sizeof...(Vals) == 32, "Must contain as many uint8_t as can fit in avx register");
+
+            XSIMD_IF_CONSTEXPR (detail::is_identity(mask))
+            {
+                return self;
+            }
+
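+            // _mm256_shuffle_epi8 only permutes within each 128-bit lane, so the
+            // cross-lane elements must be sourced from a lane-swapped copy of self.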
+            // swap lanes
+            __m256i swapped = _mm256_permute2x128_si256(self, self, 0x01); // [high | low]
+
+            // We can outsmart the dynamic version by creating a compile-time mask that leaves zeros
+            // where it does not need to select data, resulting in a simple OR merge of the two batches.
+            constexpr auto self_mask = detail::swizzle_make_self_batch<uint8_t, A, Vals...>();
+            constexpr auto cross_mask = detail::swizzle_make_cross_batch<uint8_t, A, Vals...>();
+
+            // permute bytes within each lane (AVX2 only)
+            __m256i r0 = _mm256_shuffle_epi8(self, self_mask.as_batch());
+            __m256i r1 = _mm256_shuffle_epi8(swapped, cross_mask.as_batch());
+
+            return xsimd::batch<uint8_t, A>(_mm256_or_si256(r0, r1));
+        }
+
         template <class A, typename T, uint8_t... Vals, detail::enable_sized_t<T, 1> = 0>
-        XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch_constant<uint8_t, A, Vals...> mask, requires_arch<avx2> req) noexcept
+        XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch_constant<uint8_t, A, Vals...> const& mask, requires_arch<avx2> req) noexcept
         {
             static_assert(sizeof...(Vals) == 32, "Must contain as many uint8_t as can fit in avx register");
-            return swizzle(self, mask.as_batch(), req);
+            return bitwise_cast<T>(swizzle(bitwise_cast<uint8_t>(self), mask, req));
+        }
+
+        template <
+            class A,
+            uint16_t V0, uint16_t V1, uint16_t V2, uint16_t V3,
+            uint16_t V4, uint16_t V5, uint16_t V6, uint16_t V7,
+            uint16_t V8, uint16_t V9, uint16_t V10, uint16_t V11,
+            uint16_t V12, uint16_t V13, uint16_t V14, uint16_t V15>
+        XSIMD_INLINE batch<uint16_t, A> swizzle(
+            batch<uint16_t, A> const& self,
+            batch_constant<uint16_t, A, V0, V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, V11, V12, V13, V14, V15>,
+            requires_arch<avx2> req) noexcept
+        {
+            const auto self_bytes = bitwise_cast<uint8_t>(self);
+            // If a mask entry is k, we want 2k in the low byte and 2k+1 in the high byte
+            auto constexpr mask_2k_2kp1 = batch_constant<
+                uint8_t, A,
+                2 * V0, 2 * V0 + 1, 2 * V1, 2 * V1 + 1, 2 * V2, 2 * V2 + 1, 2 * V3, 2 * V3 + 1,
+                2 * V4, 2 * V4 + 1, 2 * V5, 2 * V5 + 1, 2 * V6, 2 * V6 + 1, 2 * V7, 2 * V7 + 1,
+                2 * V8, 2 * V8 + 1, 2 * V9, 2 * V9 + 1, 2 * V10, 2 * V10 + 1, 2 * V11, 2 * V11 + 1,
+                2 * V12, 2 * V12 + 1, 2 * V13, 2 * V13 + 1, 2 * V14, 2 * V14 + 1, 2 * V15, 2 * V15 + 1> {};
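+            // e.g. a mask entry of 9 expands to byte indices 18 and 19, so the
+            // byte-level swizzle moves the whole 16-bit element.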
+            return bitwise_cast<uint16_t>(swizzle(self_bytes, mask_2k_2kp1, req));
         }
 
         template <class A, typename T, uint16_t... Vals, detail::enable_sized_t<T, 2> = 0>
-        XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch_constant<uint16_t, A, Vals...> mask, requires_arch<avx2> req) noexcept
+        XSIMD_INLINE batch<T, A> swizzle(batch<T, A> const& self, batch_constant<uint16_t, A, Vals...> const& mask, requires_arch<avx2> req) noexcept
         {
             static_assert(sizeof...(Vals) == 16, "Must contain as many uint16_t as can fit in avx register");
-            return swizzle(self, mask.as_batch(), req);
+            return bitwise_cast<T>(swizzle(bitwise_cast<uint16_t>(self), mask, req));
         }
 
         template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3, uint32_t V4, uint32_t V5, uint32_t V6, uint32_t V7>