
Commit 7eb009b

Provide an extra as_batch() method for batch_constant and an as_batch_bool() method for batch_bool_constant
These come in addition to the implicit conversion operators, which are cumbersome to use.
1 parent c63af92 commit 7eb009b
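
For illustration, a minimal sketch of the difference (not part of the commit; it assumes an AVX2-enabled build, and the constant values are arbitrary):

#include "xsimd/xsimd.hpp"

void demo()
{
    constexpr xsimd::batch_constant<uint32_t, xsimd::avx2, 7, 6, 5, 4, 3, 2, 1, 0> mask;

    // Before this commit: the implicit conversion operator must be forced
    // with a cast, spelling out the full batch type at every call site.
    auto v0 = (xsimd::batch<uint32_t, xsimd::avx2>)mask;

    // With this commit: an explicit, self-documenting method, no type to spell out.
    auto v1 = mask.as_batch();

    // The same pattern for boolean constants:
    constexpr xsimd::batch_bool_constant<uint32_t, xsimd::avx2, true, false, true, false, true, false, true, false> flags;
    auto m = flags.as_batch_bool();
    (void)v0; (void)v1; (void)m;
}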

File tree: 8 files changed (+60, -25 lines)

include/xsimd/arch/xsimd_avx.hpp

Lines changed: 4 additions & 4 deletions
@@ -1517,8 +1517,8 @@ namespace xsimd
         batch_constant<uint32_t, A, (V0 % 4), (V1 % 4), (V2 % 4), (V3 % 4), (V4 % 4), (V5 % 4), (V6 % 4), (V7 % 4)> half_mask;

         // permute within each lane
-        __m256 r0 = _mm256_permutevar_ps(low_low, (batch<uint32_t, A>)half_mask);
-        __m256 r1 = _mm256_permutevar_ps(hi_hi, (batch<uint32_t, A>)half_mask);
+        __m256 r0 = _mm256_permutevar_ps(low_low, half_mask.as_batch());
+        __m256 r1 = _mm256_permutevar_ps(hi_hi, half_mask.as_batch());

         // mask to choose the right lane
         batch_bool_constant<uint32_t, A, (V0 >= 4), (V1 >= 4), (V2 >= 4), (V3 >= 4), (V4 >= 4), (V5 >= 4), (V6 >= 4), (V7 >= 4)> blend_mask;

@@ -1542,8 +1542,8 @@ namespace xsimd
         batch_constant<uint64_t, A, (V0 % 2) * -1, (V1 % 2) * -1, (V2 % 2) * -1, (V3 % 2) * -1> half_mask;

         // permute within each lane
-        __m256d r0 = _mm256_permutevar_pd(low_low, (batch<uint64_t, A>)half_mask);
-        __m256d r1 = _mm256_permutevar_pd(hi_hi, (batch<uint64_t, A>)half_mask);
+        __m256d r0 = _mm256_permutevar_pd(low_low, half_mask.as_batch());
+        __m256d r1 = _mm256_permutevar_pd(hi_hi, half_mask.as_batch());

         // mask to choose the right lane
         batch_bool_constant<uint64_t, A, (V0 >= 2), (V1 >= 2), (V2 >= 2), (V3 >= 2)> blend_mask;

include/xsimd/arch/xsimd_avx2.hpp

Lines changed: 2 additions & 2 deletions
@@ -914,7 +914,7 @@ namespace xsimd
     template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3, uint32_t V4, uint32_t V5, uint32_t V6, uint32_t V7>
     inline batch<float, A> swizzle(batch<float, A> const& self, batch_constant<uint32_t, A, V0, V1, V2, V3, V4, V5, V6, V7> mask, requires_arch<avx2>) noexcept
     {
-        return _mm256_permutevar8x32_ps(self, (batch<uint32_t, A>)mask);
+        return _mm256_permutevar8x32_ps(self, mask.as_batch());
     }

     template <class A, uint64_t V0, uint64_t V1, uint64_t V2, uint64_t V3>

@@ -938,7 +938,7 @@ namespace xsimd
     template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3, uint32_t V4, uint32_t V5, uint32_t V6, uint32_t V7>
     inline batch<uint32_t, A> swizzle(batch<uint32_t, A> const& self, batch_constant<uint32_t, A, V0, V1, V2, V3, V4, V5, V6, V7> mask, requires_arch<avx2>) noexcept
     {
-        return _mm256_permutevar8x32_epi32(self, (batch<uint32_t, A>)mask);
+        return _mm256_permutevar8x32_epi32(self, mask.as_batch());
     }
     template <class A, uint32_t V0, uint32_t V1, uint32_t V2, uint32_t V3, uint32_t V4, uint32_t V5, uint32_t V6, uint32_t V7>
     inline batch<int32_t, A> swizzle(batch<int32_t, A> const& self, batch_constant<uint32_t, A, V0, V1, V2, V3, V4, V5, V6, V7> mask, requires_arch<avx2>) noexcept
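
For context, a rough sketch of why passing mask.as_batch() straight to an intrinsic compiles (my simplification, not the actual xsimd definitions): a batch itself converts implicitly to the raw register type, so only the batch_constant-to-batch step needed an explicit spelling.

#include <immintrin.h>

// Simplified stand-in for xsimd::batch<uint32_t, xsimd::avx2>; the real class
// inherits an equivalent conversion from its simd_register base.
struct batch_u32_avx2
{
    __m256i reg;
    operator __m256i() const { return reg; } // batch -> raw register, implicit
};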

include/xsimd/arch/xsimd_avx512bw.hpp

Lines changed: 4 additions & 4 deletions
@@ -619,25 +619,25 @@ namespace xsimd
     template <class A, uint16_t... Vs>
     inline batch<uint16_t, A> swizzle(batch<uint16_t, A> const& self, batch_constant<uint16_t, A, Vs...> mask, requires_arch<avx512bw>) noexcept
     {
-        return swizzle(self, (batch<uint16_t, A>)mask, avx512bw {});
+        return swizzle(self, mask.as_batch(), avx512bw {});
     }

     template <class A, uint16_t... Vs>
     inline batch<int16_t, A> swizzle(batch<int16_t, A> const& self, batch_constant<uint16_t, A, Vs...> mask, requires_arch<avx512bw>) noexcept
     {
-        return swizzle(self, (batch<uint16_t, A>)mask, avx512bw {});
+        return swizzle(self, mask.as_batch(), avx512bw {});
     }

     template <class A, uint8_t... Vs>
     inline batch<uint8_t, A> swizzle(batch<uint8_t, A> const& self, batch_constant<uint8_t, A, Vs...> mask, requires_arch<avx512bw>) noexcept
     {
-        return swizzle(self, (batch<uint8_t, A>)mask, avx512bw {});
+        return swizzle(self, mask.as_batch(), avx512bw {});
     }

     template <class A, uint8_t... Vs>
     inline batch<int8_t, A> swizzle(batch<int8_t, A> const& self, batch_constant<uint8_t, A, Vs...> mask, requires_arch<avx512bw>) noexcept
     {
-        return swizzle(self, (batch<uint8_t, A>)mask, avx512bw {});
+        return swizzle(self, mask.as_batch(), avx512bw {});
     }

     // zip_hi

include/xsimd/arch/xsimd_avx512f.hpp

Lines changed: 8 additions & 8 deletions
@@ -1423,7 +1423,7 @@ namespace xsimd
     inline T reduce_max(batch<T, A> const& self, requires_arch<avx512f>) noexcept
     {
         constexpr batch_constant<uint64_t, A, 5, 6, 7, 8, 0, 0, 0, 0> mask;
-        batch<T, A> step = _mm512_permutexvar_epi64((batch<uint64_t, A>)mask, self);
+        batch<T, A> step = _mm512_permutexvar_epi64(mask.as_batch(), self);
         batch<T, A> acc = max(self, step);
         __m256i low = _mm512_castsi512_si256(acc);
         return reduce_max(batch<T, avx2>(low));

@@ -1434,7 +1434,7 @@ namespace xsimd
     inline T reduce_min(batch<T, A> const& self, requires_arch<avx512f>) noexcept
     {
         constexpr batch_constant<uint64_t, A, 5, 6, 7, 8, 0, 0, 0, 0> mask;
-        batch<T, A> step = _mm512_permutexvar_epi64((batch<uint64_t, A>)mask, self);
+        batch<T, A> step = _mm512_permutexvar_epi64(mask.as_batch(), self);
         batch<T, A> acc = min(self, step);
         __m256i low = _mm512_castsi512_si256(acc);
         return reduce_min(batch<T, avx2>(low));

@@ -1919,37 +1919,37 @@ namespace xsimd
     template <class A, uint32_t... Vs>
     inline batch<float, A> swizzle(batch<float, A> const& self, batch_constant<uint32_t, A, Vs...> mask, requires_arch<avx512f>) noexcept
     {
-        return swizzle(self, (batch<uint32_t, A>)mask, avx512f {});
+        return swizzle(self, mask.as_batch(), avx512f {});
     }

     template <class A, uint64_t... Vs>
     inline batch<double, A> swizzle(batch<double, A> const& self, batch_constant<uint64_t, A, Vs...> mask, requires_arch<avx512f>) noexcept
     {
-        return swizzle(self, (batch<uint64_t, A>)mask, avx512f {});
+        return swizzle(self, mask.as_batch(), avx512f {});
     }

     template <class A, uint64_t... Vs>
     inline batch<uint64_t, A> swizzle(batch<uint64_t, A> const& self, batch_constant<uint64_t, A, Vs...> mask, requires_arch<avx512f>) noexcept
     {
-        return swizzle(self, (batch<uint64_t, A>)mask, avx512f {});
+        return swizzle(self, mask.as_batch(), avx512f {});
     }

     template <class A, uint64_t... Vs>
     inline batch<int64_t, A> swizzle(batch<int64_t, A> const& self, batch_constant<uint64_t, A, Vs...> mask, requires_arch<avx512f>) noexcept
     {
-        return swizzle(self, (batch<uint64_t, A>)mask, avx512f {});
+        return swizzle(self, mask.as_batch(), avx512f {});
     }

     template <class A, uint32_t... Vs>
     inline batch<uint32_t, A> swizzle(batch<uint32_t, A> const& self, batch_constant<uint32_t, A, Vs...> mask, requires_arch<avx512f>) noexcept
     {
-        return swizzle(self, (batch<uint32_t, A>)mask, avx512f {});
+        return swizzle(self, mask.as_batch(), avx512f {});
     }

     template <class A, uint32_t... Vs>
     inline batch<int32_t, A> swizzle(batch<int32_t, A> const& self, batch_constant<uint32_t, A, Vs...> mask, requires_arch<avx512f>) noexcept
     {
-        return swizzle(self, (batch<uint32_t, A>)mask, avx512f {});
+        return swizzle(self, mask.as_batch(), avx512f {});
     }

     namespace detail

include/xsimd/arch/xsimd_ssse3.hpp

Lines changed: 3 additions & 3 deletions
@@ -145,7 +145,7 @@ namespace xsimd
         constexpr batch_constant<uint8_t, A, 2 * V0, 2 * V0 + 1, 2 * V1, 2 * V1 + 1, 2 * V2, 2 * V2 + 1, 2 * V3, 2 * V3 + 1,
                                  2 * V4, 2 * V4 + 1, 2 * V5, 2 * V5 + 1, 2 * V6, 2 * V6 + 1, 2 * V7, 2 * V7 + 1>
             mask8;
-        return _mm_shuffle_epi8(self, (batch<uint8_t, A>)mask8);
+        return _mm_shuffle_epi8(self, mask8.as_batch());
     }

     template <class A, uint16_t V0, uint16_t V1, uint16_t V2, uint16_t V3, uint16_t V4, uint16_t V5, uint16_t V6, uint16_t V7>

@@ -158,14 +158,14 @@ namespace xsimd
               uint8_t V8, uint8_t V9, uint8_t V10, uint8_t V11, uint8_t V12, uint8_t V13, uint8_t V14, uint8_t V15>
     inline batch<uint8_t, A> swizzle(batch<uint8_t, A> const& self, batch_constant<uint8_t, A, V0, V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, V11, V12, V13, V14, V15> mask, requires_arch<ssse3>) noexcept
     {
-        return swizzle(self, (batch<uint8_t, A>)mask, ssse3 {});
+        return swizzle(self, mask.as_batch(), ssse3 {});
     }

     template <class A, uint8_t V0, uint8_t V1, uint8_t V2, uint8_t V3, uint8_t V4, uint8_t V5, uint8_t V6, uint8_t V7,
               uint8_t V8, uint8_t V9, uint8_t V10, uint8_t V11, uint8_t V12, uint8_t V13, uint8_t V14, uint8_t V15>
     inline batch<int8_t, A> swizzle(batch<int8_t, A> const& self, batch_constant<uint8_t, A, V0, V1, V2, V3, V4, V5, V6, V7, V8, V9, V10, V11, V12, V13, V14, V15> mask, requires_arch<ssse3>) noexcept
     {
-        return swizzle(self, (batch<uint8_t, A>)mask, ssse3 {});
+        return swizzle(self, mask.as_batch(), ssse3 {});
     }

 }

include/xsimd/arch/xsimd_sve.hpp

Lines changed: 2 additions & 2 deletions
@@ -742,7 +742,7 @@ namespace xsimd
     inline batch<T, A> swizzle(batch<T, A> const& arg, batch_constant<I, A, idx...> indices, requires_arch<sve>) noexcept
     {
         static_assert(batch<T, A>::size == sizeof...(idx), "invalid swizzle indices");
-        return swizzle(arg, (batch<I, A>)indices, sve {});
+        return swizzle(arg, indices.as_batch(), sve {});
     }

     template <class A, class T, class I, I... idx>

@@ -751,7 +751,7 @@ namespace xsimd
                                requires_arch<sve>) noexcept
     {
         static_assert(batch<std::complex<T>, A>::size == sizeof...(idx), "invalid swizzle indices");
-        return swizzle(arg, (batch<I, A>)indices, sve {});
+        return swizzle(arg, indices.as_batch(), sve {});
     }

     /*************

include/xsimd/types/xsimd_batch_constant.hpp

Lines changed: 15 additions & 2 deletions
@@ -34,7 +34,15 @@ namespace xsimd
         static_assert(sizeof...(Values) == batch_type::size, "consistent batch size");

     public:
-        constexpr operator batch_bool<T, A>() const noexcept { return { Values... }; }
+        /**
+         * @brief Generate a batch of @p batch_type from this @p batch_bool_constant
+         */
+        constexpr batch_type as_batch_bool() const noexcept { return { Values... }; }
+
+        /**
+         * @brief Generate a batch of @p batch_type from this @p batch_bool_constant
+         */
+        constexpr operator batch_type() const noexcept { return as_batch_bool(); }

         constexpr bool get(size_t i) const noexcept
         {

@@ -130,7 +138,12 @@ namespace xsimd
         /**
          * @brief Generate a batch of @p batch_type from this @p batch_constant
          */
-        inline operator batch_type() const noexcept { return { Values... }; }
+        inline batch_type as_batch() const noexcept { return { Values... }; }
+
+        /**
+         * @brief Generate a batch of @p batch_type from this @p batch_constant
+         */
+        inline operator batch_type() const noexcept { return as_batch(); }

         /**
          * @brief Get the @p i th element of this @p batch_constant
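
As a usage note, a short sketch of how the new accessors read in user code; the generator types here are hypothetical, written in the style of the tests in this commit:

#include "xsimd/xsimd.hpp"
#include <cstddef>

// Hypothetical generators, mirroring the style of test_batch_constant.cpp.
struct iota
{
    static constexpr uint32_t get(std::size_t index, std::size_t /*size*/) { return static_cast<uint32_t>(index); }
};
struct always_true
{
    static constexpr bool get(std::size_t /*index*/, std::size_t /*size*/) { return true; }
};

template <class A>
void demo()
{
    constexpr auto idx = xsimd::make_batch_constant<uint32_t, A, iota>();
    xsimd::batch<uint32_t, A> v = idx.as_batch(); // was: (xsimd::batch<uint32_t, A>)idx

    constexpr auto flags = xsimd::make_batch_bool_constant<uint32_t, A, always_true>();
    xsimd::batch_bool<uint32_t, A> m = flags.as_batch_bool(); // was: a cast to batch_bool
    (void)v; (void)m;
}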

test/test_batch_constant.cpp

Lines changed: 22 additions & 0 deletions
@@ -45,6 +45,15 @@ struct constant_batch_test
         CHECK_BATCH_EQ((batch_type)b, expected);
     }

+    void test_cast() const
+    {
+        constexpr auto cst_b = xsimd::make_batch_constant<value_type, arch_type, generator>();
+        auto b0 = cst_b.as_batch();
+        auto b1 = (batch_type)cst_b;
+        CHECK_BATCH_EQ(b0, b1);
+        // The actual values are already tested in test_init_from_generator
+    }
+
     struct arange
     {
         static constexpr value_type get(size_t index, size_t /*size*/)

@@ -135,6 +144,8 @@ TEST_CASE_TEMPLATE("[constant batch]", B, BATCH_INT_TYPES)
     constant_batch_test<B> Test;
     SUBCASE("init_from_generator") { Test.test_init_from_generator(); }

+    SUBCASE("as_batch") { Test.test_cast(); }
+
     SUBCASE("init_from_generator_arange")
     {
         Test.test_init_from_generator_arange();

@@ -216,6 +227,15 @@ struct constant_bool_batch_test
         }
     };

+    void test_cast() const
+    {
+        constexpr auto all_true = xsimd::make_batch_bool_constant<value_type, arch_type, constant<true>>();
+        auto b0 = all_true.as_batch_bool();
+        auto b1 = (batch_bool_type)all_true;
+        CHECK_BATCH_EQ(b0, batch_bool_type(true));
+        CHECK_BATCH_EQ(b1, batch_bool_type(true));
+    }
+
     void test_ops() const
     {
         constexpr auto all_true = xsimd::make_batch_bool_constant<value_type, arch_type, constant<true>>();

@@ -252,6 +272,8 @@ TEST_CASE_TEMPLATE("[constant bool batch]", B, BATCH_INT_TYPES)
     constant_bool_batch_test<B> Test;
     SUBCASE("init_from_generator") { Test.test_init_from_generator(); }

+    SUBCASE("as_batch") { Test.test_cast(); }
+
     SUBCASE("init_from_generator_split")
     {
         Test.test_init_from_generator_split();
