Skip to content

Commit 8fbdd6f

Browse files
committed
fix is_cross_lane
1 parent d5ed51f commit 8fbdd6f

File tree

2 files changed

+66
-6
lines changed

2 files changed

+66
-6
lines changed

include/xsimd/arch/common/xsimd_common_swizzle.hpp

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,49 @@ namespace xsimd
167167
return cross_impl<0, sizeof...(Vs), sizeof...(Vs) / 2, Vs...>::value;
168168
}
169169

170+
// 128-bit lane aware cross_impl: checks per 128-bit lane
171+
template <std::size_t I,
172+
std::size_t N,
173+
std::size_t LaneElems,
174+
typename U,
175+
U... Vs>
176+
struct cross_impl128
177+
{
178+
static constexpr std::size_t Vi = static_cast<std::size_t>(get_at<U, I, Vs...>::value);
179+
static constexpr bool curr = ((I / LaneElems) != (static_cast<std::size_t>(Vi) / LaneElems));
180+
static constexpr bool next = cross_impl128<I + 1, N, LaneElems, U, Vs...>::value;
181+
static constexpr bool value = curr || next;
182+
};
183+
template <std::size_t N, std::size_t LaneElems, typename U, U... Vs>
184+
struct cross_impl128<N, N, LaneElems, U, Vs...>
185+
{
186+
static constexpr bool value = false;
187+
};
188+
189+
template <typename ElemT, typename U, U... Vs>
190+
XSIMD_INLINE constexpr bool is_cross_lane_128() noexcept
191+
{
192+
static_assert(sizeof...(Vs) >= 1, "Need at least one lane");
193+
constexpr std::size_t N = sizeof...(Vs);
194+
constexpr std::size_t lane_elems = 16 / sizeof(ElemT);
195+
return cross_impl128<0, N, lane_elems, U, Vs...>::value;
196+
}
197+
198+
// overload accepting an element type first to compute 128-bit lane size
199+
template <typename ElemT, typename U, U... Vs>
200+
XSIMD_INLINE constexpr bool is_cross_lane() noexcept
201+
{
202+
static_assert(std::is_integral<U>::value, "swizzle mask values must be integral");
203+
return is_cross_lane_128<ElemT, U, Vs...>();
204+
}
205+
206+
// convenience overload taking element type then integer non-type parameter pack
207+
template <typename ElemT, std::size_t... Vs>
208+
XSIMD_INLINE constexpr bool is_cross_lane() noexcept
209+
{
210+
return is_cross_lane_128<ElemT, std::size_t, Vs...>();
211+
}
212+
170213
template <typename T, T... Vs>
171214
XSIMD_INLINE constexpr bool is_identity() noexcept { return detail::identity_impl<0, T, Vs...>(); }
172215
template <typename T, T... Vs>
@@ -184,7 +227,11 @@ namespace xsimd
184227
template <typename T, class A, T... Vs>
185228
XSIMD_INLINE constexpr bool is_only_from_hi(batch_constant<T, A, Vs...>) noexcept { return detail::is_only_from_hi<T, Vs...>(); }
186229
template <typename T, class A, T... Vs>
187-
XSIMD_INLINE constexpr bool is_cross_lane(batch_constant<T, A, Vs...>) noexcept { return detail::is_cross_lane<Vs...>(); }
230+
XSIMD_INLINE constexpr bool is_cross_lane(batch_constant<T, A, Vs...>) noexcept
231+
{
232+
static_assert(std::is_integral<T>::value, "swizzle mask values must be integral");
233+
return is_cross_lane_128<T, T, Vs...>();
234+
}
188235

189236
} // namespace detail
190237
} // namespace kernel

test/test_batch_manip.cpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,24 @@ namespace xsimd
5252
static_assert(is_dup_hi<std::uint32_t, 2, 3, 2, 3>(), "4-lane dup_hi failed");
5353
static_assert(!is_dup_lo<std::uint32_t, 2, 3, 2, 3>(), "4-lane dup_lo on dup_hi");
5454

55-
static_assert(is_cross_lane<0, 1, 0, 1>(), "dup-lo only → crossing");
56-
static_assert(is_cross_lane<2, 3, 2, 3>(), "dup-hi only → crossing");
57-
static_assert(is_cross_lane<0, 3, 3, 3>(), "one low + rest high → crossing");
58-
static_assert(!is_cross_lane<1, 0, 2, 3>(), "mixed low/high → no crossing");
59-
static_assert(!is_cross_lane<0, 1, 2, 3>(), "mixed low/high → no crossing");
55+
static_assert(is_cross_lane<double, 0, 1, 0, 1>(), "dup-lo only → crossing");
56+
static_assert(is_cross_lane<double, 2, 3, 2, 3>(), "dup-hi only → crossing");
57+
static_assert(is_cross_lane<double, 0, 3, 3, 3>(), "one low + rest high → crossing");
58+
static_assert(!is_cross_lane<double, 1, 0, 2, 3>(), "mixed low/high → no crossing");
59+
static_assert(!is_cross_lane<double, 0, 1, 2, 3>(), "mixed low/high → no crossing");
60+
// 8-lane 128-bit lane checks (use double/int64 for 2-elements-per-128-bit lanes)
61+
static_assert(is_cross_lane<double, 3, 2, 1, 0, 7, 6, 5, 4>(), "8-lane 128-bit swap → crossing");
62+
static_assert(!is_cross_lane<double, 0, 1, 2, 3, 4, 5, 6, 7>(), "identity 8-lane → no crossing");
63+
static_assert(is_cross_lane<std::uint64_t, 3, 2, 1, 0, 7, 6, 5, 4>(), "8-lane uint64_t swap → crossing");
64+
static_assert(is_cross_lane<std::int32_t, 4, 5, 6, 7, 0, 1, 2, 3>(), "8-lane int32_t swap → crossing");
65+
66+
// Additional compile-time checks for 16-element batches (e.g. float/int32)
67+
static_assert(is_cross_lane<float, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7>(),
68+
"16-lane 128-bit swap → crossing");
69+
static_assert(!is_cross_lane<float, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15>(),
70+
"identity 16-lane → no crossing");
71+
static_assert(is_cross_lane<std::uint32_t, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7>(),
72+
"16-lane uint32_t swap → crossing");
6073
}
6174
}
6275
}

0 commit comments

Comments
 (0)