Skip to content

Commit f6b6fa5

Browse files
Add overload for std::integral_constant to batch_constant operators
And use them where relevant.
1 parent 9a360a2 commit f6b6fa5

File tree

5 files changed

+54
-17
lines changed

5 files changed

+54
-17
lines changed

include/xsimd/arch/xsimd_avx.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1613,7 +1613,7 @@ namespace xsimd
16131613
}
16141614
return split;
16151615
}
1616-
constexpr auto lane_mask = mask % make_batch_constant<uint32_t, (mask.size / 2), A>();
1616+
constexpr auto lane_mask = mask % std::integral_constant<uint32_t, (mask.size / 2)>();
16171617
XSIMD_IF_CONSTEXPR(detail::is_only_from_lo(mask))
16181618
{
16191619
__m256 broadcast = _mm256_permute2f128_ps(self, self, 0x00); // [low | low]
@@ -1632,15 +1632,15 @@ namespace xsimd
16321632
__m256 swapped = _mm256_permute2f128_ps(self, self, 0x01); // [high | low]
16331633

16341634
// normalize mask taking modulo 4
1635-
constexpr auto half_mask = mask % make_batch_constant<uint32_t, 4, A>();
1635+
constexpr auto half_mask = mask % std::integral_constant<uint32_t, 4>();
16361636

16371637
// permute within each lane
16381638
__m256 r0 = _mm256_permutevar_ps(self, half_mask.as_batch());
16391639
__m256 r1 = _mm256_permutevar_ps(swapped, half_mask.as_batch());
16401640

16411641
// select lane by the mask index divided by 4
16421642
constexpr auto lane = batch_constant<uint32_t, A, 0, 0, 0, 0, 1, 1, 1, 1> {};
1643-
constexpr int lane_idx = ((mask / make_batch_constant<uint32_t, 4, A>()) != lane).mask();
1643+
constexpr int lane_idx = ((mask / std::integral_constant<uint32_t, 4>()) != lane).mask();
16441644

16451645
return _mm256_blend_ps(r0, r1, lane_idx);
16461646
}
@@ -1681,7 +1681,7 @@ namespace xsimd
16811681

16821682
// select lane by the mask index divided by 2
16831683
constexpr auto lane = batch_constant<uint64_t, A, 0, 0, 1, 1> {};
1684-
constexpr int lane_idx = ((mask / make_batch_constant<uint64_t, 2, A>()) != lane).mask();
1684+
constexpr int lane_idx = ((mask / std::integral_constant<uint64_t, 2>()) != lane).mask();
16851685

16861686
// blend the two permutes
16871687
return _mm256_blend_pd(r0, r1, lane_idx);

include/xsimd/arch/xsimd_avx2.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,7 +1332,7 @@ namespace xsimd
13321332
return self;
13331333
}
13341334

1335-
constexpr auto lane_mask = mask % make_batch_constant<uint8_t, (mask.size / 2), A>();
1335+
constexpr auto lane_mask = mask % std::integral_constant<uint8_t, (mask.size / 2)>();
13361336

13371337
XSIMD_IF_CONSTEXPR(!detail::is_cross_lane(mask))
13381338
{
@@ -1409,7 +1409,7 @@ namespace xsimd
14091409
}
14101410
XSIMD_IF_CONSTEXPR(!detail::is_cross_lane(mask))
14111411
{
1412-
constexpr auto lane_mask = mask % make_batch_constant<uint32_t, (mask.size / 2), A>();
1412+
constexpr auto lane_mask = mask % std::integral_constant<uint32_t, (mask.size / 2)>();
14131413
// Cheaper intrinsics when not crossing lanes
14141414
// Contrary to the uint64_t version, the limits of 8 bits for the immediate constant
14151415
// cannot make different permutations across lanes

include/xsimd/arch/xsimd_sse2.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2074,7 +2074,7 @@ namespace xsimd
20742074
__m128i hi = _mm_unpackhi_epi64(hil, hih);
20752075

20762076
// mask to choose the right lane
2077-
constexpr auto blend_mask = mask < make_batch_constant<uint16_t, 4, A>();
2077+
constexpr auto blend_mask = mask < std::integral_constant<uint16_t, 4>();
20782078

20792079
// blend the two permutes
20802080
return select(blend_mask, batch<uint16_t, A>(lo), batch<uint16_t, A>(hi));

include/xsimd/types/xsimd_batch_constant.hpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,16 @@ namespace xsimd
316316
}
317317

318318
public:
319-
#define MAKE_BINARY_OP(OP, NAME) \
320-
template <T... OtherValues> \
321-
constexpr auto operator OP(batch_constant<T, A, OtherValues...> other) const \
322-
{ \
323-
return apply<NAME<void>>(*this, other); \
319+
#define MAKE_BINARY_OP(OP, NAME) \
320+
template <T... OtherValues> \
321+
constexpr auto operator OP(batch_constant<T, A, OtherValues...> other) const \
322+
{ \
323+
return apply<NAME<void>>(*this, other); \
324+
} \
325+
template <T OtherValue> \
326+
constexpr batch_constant<T, A, (Values OP OtherValue)...> operator OP(std::integral_constant<T, OtherValue>) const \
327+
{ \
328+
return {}; \
324329
}
325330

326331
MAKE_BINARY_OP(+, std::plus)
@@ -350,11 +355,16 @@ namespace xsimd
350355
return apply_bool<F, std::tuple<std::integral_constant<T, Values>...>, std::tuple<std::integral_constant<T, OtherValues>...>>(std::make_index_sequence<sizeof...(Values)>());
351356
}
352357

353-
#define MAKE_BINARY_BOOL_OP(OP, NAME) \
354-
template <T... OtherValues> \
355-
constexpr auto operator OP(batch_constant<T, A, OtherValues...> other) const \
356-
{ \
357-
return apply_bool<NAME<void>>(*this, other); \
358+
#define MAKE_BINARY_BOOL_OP(OP, NAME) \
359+
template <T... OtherValues> \
360+
constexpr auto operator OP(batch_constant<T, A, OtherValues...> other) const \
361+
{ \
362+
return apply_bool<NAME<void>>(*this, other); \
363+
} \
364+
template <T OtherValue> \
365+
constexpr batch_bool_constant<T, A, (Values OP OtherValue)...> operator OP(std::integral_constant<T, OtherValue>) const \
366+
{ \
367+
return {}; \
358368
}
359369

360370
MAKE_BINARY_BOOL_OP(==, std::equal_to)

test/test_batch_constant.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,43 +110,64 @@ struct constant_batch_test
110110
{
111111
constexpr auto n12 = xsimd::make_batch_constant<value_type, constant<12>, arch_type>();
112112
constexpr auto n3 = xsimd::make_batch_constant<value_type, constant<3>, arch_type>();
113+
constexpr std::integral_constant<value_type, 3> c3;
113114

114115
constexpr auto n12_add_n3 = n12 + n3;
115116
constexpr auto n15 = xsimd::make_batch_constant<value_type, constant<15>, arch_type>();
116117
static_assert(std::is_same<decltype(n12_add_n3), decltype(n15)>::value, "n12 + n3 == n15");
118+
constexpr auto n12_add_c3 = n12 + c3;
119+
static_assert(std::is_same<decltype(n12_add_c3), decltype(n15)>::value, "n12 + c3 == n15");
117120

118121
constexpr auto n12_sub_n3 = n12 - n3;
119122
constexpr auto n9 = xsimd::make_batch_constant<value_type, constant<9>, arch_type>();
120123
static_assert(std::is_same<decltype(n12_sub_n3), decltype(n9)>::value, "n12 - n3 == n9");
124+
constexpr auto n12_sub_c3 = n12 - c3;
125+
static_assert(std::is_same<decltype(n12_sub_c3), decltype(n9)>::value, "n12 - c3 == n9");
121126

122127
constexpr auto n12_mul_n3 = n12 * n3;
123128
constexpr auto n36 = xsimd::make_batch_constant<value_type, constant<36>, arch_type>();
124129
static_assert(std::is_same<decltype(n12_mul_n3), decltype(n36)>::value, "n12 * n3 == n36");
130+
constexpr auto n12_mul_c3 = n12 * c3;
131+
static_assert(std::is_same<decltype(n12_mul_c3), decltype(n36)>::value, "n12 - c3 == n36");
125132

126133
constexpr auto n12_div_n3 = n12 / n3;
127134
constexpr auto n4 = xsimd::make_batch_constant<value_type, constant<4>, arch_type>();
128135
static_assert(std::is_same<decltype(n12_div_n3), decltype(n4)>::value, "n12 / n3 == n4");
136+
constexpr auto n12_div_c3 = n12 / c3;
137+
static_assert(std::is_same<decltype(n12_div_c3), decltype(n4)>::value, "n12 / c3 == n4");
129138

130139
constexpr auto n12_mod_n3 = n12 % n3;
131140
constexpr auto n0 = xsimd::make_batch_constant<value_type, constant<0>, arch_type>();
132141
static_assert(std::is_same<decltype(n12_mod_n3), decltype(n0)>::value, "n12 % n3 == n0");
142+
constexpr auto n12_mod_c3 = n12 % c3;
143+
static_assert(std::is_same<decltype(n12_mod_c3), decltype(n0)>::value, "n12 % c3 == n0");
133144

134145
constexpr auto n12_land_n3 = n12 & n3;
135146
static_assert(std::is_same<decltype(n12_land_n3), decltype(n0)>::value, "n12 & n3 == n0");
147+
constexpr auto n12_land_c3 = n12 & c3;
148+
static_assert(std::is_same<decltype(n12_land_c3), decltype(n0)>::value, "n12 & c3 == n0");
136149

137150
constexpr auto n12_lor_n3 = n12 | n3;
138151
static_assert(std::is_same<decltype(n12_lor_n3), decltype(n15)>::value, "n12 | n3 == n15");
152+
constexpr auto n12_lor_c3 = n12 | c3;
153+
static_assert(std::is_same<decltype(n12_lor_c3), decltype(n15)>::value, "n12 | c3 == n15");
139154

140155
constexpr auto n12_lxor_n3 = n12 ^ n3;
141156
static_assert(std::is_same<decltype(n12_lxor_n3), decltype(n15)>::value, "n12 ^ n3 == n15");
157+
constexpr auto n12_lxor_c3 = n12 ^ c3;
158+
static_assert(std::is_same<decltype(n12_lxor_c3), decltype(n15)>::value, "n12 ^ c3 == n15");
142159

143160
constexpr auto n96 = xsimd::make_batch_constant<value_type, constant<96>, arch_type>();
144161
constexpr auto n12_lshift_n3 = n12 << n3;
145162
static_assert(std::is_same<decltype(n12_lshift_n3), decltype(n96)>::value, "n12 << n3 == n96");
163+
constexpr auto n12_lshift_c3 = n12 << c3;
164+
static_assert(std::is_same<decltype(n12_lshift_c3), decltype(n96)>::value, "n12 << c3 == n96");
146165

147166
constexpr auto n1 = xsimd::make_batch_constant<value_type, constant<1>, arch_type>();
148167
constexpr auto n12_rshift_n3 = n12 >> n3;
149168
static_assert(std::is_same<decltype(n12_rshift_n3), decltype(n1)>::value, "n12 >> n3 == n1");
169+
constexpr auto n12_rshift_c3 = n12 >> c3;
170+
static_assert(std::is_same<decltype(n12_rshift_c3), decltype(n1)>::value, "n12 >> c3 == n1");
150171

151172
constexpr auto n12_uadd = +n12;
152173
static_assert(std::is_same<decltype(n12_uadd), decltype(n12)>::value, "+n12 == n12");
@@ -167,21 +188,27 @@ struct constant_batch_test
167188

168189
static_assert(std::is_same<decltype(n12 == n12), true_batch_type>::value, "n12 == n12");
169190
static_assert(std::is_same<decltype(n12 == n3), false_batch_type>::value, "n12 == n3");
191+
static_assert(std::is_same<decltype(n12 == c3), false_batch_type>::value, "n12 == c3");
170192

171193
static_assert(std::is_same<decltype(n12 != n12), false_batch_type>::value, "n12 != n12");
172194
static_assert(std::is_same<decltype(n12 != n3), true_batch_type>::value, "n12 != n3");
195+
static_assert(std::is_same<decltype(n12 != c3), true_batch_type>::value, "n12 != c3");
173196

174197
static_assert(std::is_same<decltype(n12 < n12), false_batch_type>::value, "n12 < n12");
175198
static_assert(std::is_same<decltype(n12 < n3), false_batch_type>::value, "n12 < n3");
199+
static_assert(std::is_same<decltype(n12 < c3), false_batch_type>::value, "n12 < c3");
176200

177201
static_assert(std::is_same<decltype(n12 > n12), false_batch_type>::value, "n12 > n12");
178202
static_assert(std::is_same<decltype(n12 > n3), true_batch_type>::value, "n12 > n3");
203+
static_assert(std::is_same<decltype(n12 > c3), true_batch_type>::value, "n12 > c3");
179204

180205
static_assert(std::is_same<decltype(n12 <= n12), true_batch_type>::value, "n12 <= n12");
181206
static_assert(std::is_same<decltype(n12 <= n3), false_batch_type>::value, "n12 <= n3");
207+
static_assert(std::is_same<decltype(n12 <= c3), false_batch_type>::value, "n12 <= c3");
182208

183209
static_assert(std::is_same<decltype(n12 >= n12), true_batch_type>::value, "n12 >= n12");
184210
static_assert(std::is_same<decltype(n12 >= n3), true_batch_type>::value, "n12 >= n3");
211+
static_assert(std::is_same<decltype(n12 >= c3), true_batch_type>::value, "n12 >= c3");
185212
}
186213
};
187214

0 commit comments

Comments
 (0)