Skip to content

Commit b0a84e7

Browse files
committed
fixed and improved swizzles
1 parent fe1028d commit b0a84e7

File tree

7 files changed

+555
-388
lines changed

7 files changed

+555
-388
lines changed
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
/***************************************************************************
2+
* Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and *
3+
* Martin Renou *
4+
* Copyright (c) QuantStack *
5+
* Copyright (c) Serge Guelton *
6+
* Copyright (c) Marco Barbone *
7+
* *
8+
* Distributed under the terms of the BSD 3-Clause License. *
9+
* *
10+
* The full license is in the file LICENSE, distributed with this software.*
11+
****************************************************************************/
12+
#ifndef XSIMD_COMMON_SWIZZLE_HPP
13+
#define XSIMD_COMMON_SWIZZLE_HPP
14+
15+
#include <cstddef>
16+
#include <cstdint>
17+
#include <type_traits>
18+
19+
namespace xsimd
20+
{
21+
// Forward declaration only: this header needs the name so that the overloads
// taking a batch_constant can deduce its value pack. The definition is not
// part of this header.
template <typename T, class A, T... Values>
struct batch_constant;
23+
24+
namespace kernel
25+
{
26+
namespace detail
27+
{
28+
// ────────────────────────────────────────────────────────────────────────
29+
// get_at<T, I, Values...>::value is the I-th (zero-based) element of the
// non-type pack Values..., found by dropping one leading element per
// recursion step until the index reaches 0.
template <typename T, std::size_t I, T V0, T... Vs>
struct get_at
{
    // The pack holds 1 + sizeof...(Vs) elements, so sizeof...(Vs) is the last
    // valid index; report anything larger with a readable message.
    static_assert(I <= sizeof...(Vs), "get_at: index out of range");
    static constexpr T value = get_at<T, I - 1, Vs...>::value;
};

// Recursion anchor: index 0 selects the head of the pack.
template <typename T, T V0, T... Vs>
struct get_at<T, 0, V0, Vs...>
{
    static constexpr T value = V0;
};
40+
41+
// ────────────────────────────────────────────────────────────────────────
42+
// identity_impl<I, T, Vs...>(): true when the pack Vs... reads I, I + 1,
// I + 2, ... from its first element onwards, i.e. every element equals its
// own position when the first element is taken to sit at position I.
// An empty pack is an identity.
template <std::size_t /*I*/, typename T>
XSIMD_INLINE constexpr bool identity_impl() noexcept
{
    return true;
}

template <std::size_t I, typename T, T V0, T... Vs>
XSIMD_INLINE constexpr bool identity_impl() noexcept
{
    // Compare the head against its position, then check the tail one position further.
    return V0 == static_cast<T>(I)
        && identity_impl<I + 1, T, Vs...>();
}
51+
52+
// ────────────────────────────────────────────────────────────────────────
53+
// bitmask_impl<I, N, T, Vs...>(): ORs together one bit per pack element;
// element V sets bit (V & (N - 1)). I only counts recursion steps and does not
// influence the result.
// Notes: `& (N - 1)` equals `% N` only when N is a power of two, and the bit is
// produced by shifting `1u`, so the reduced index has to stay below the width
// of unsigned int.
template <std::size_t /*I*/, std::size_t /*N*/, typename T>
XSIMD_INLINE constexpr std::uint32_t bitmask_impl() noexcept
{
    return 0u;
}

template <std::size_t I, std::size_t N, typename T, T V0, T... Vs>
XSIMD_INLINE constexpr std::uint32_t bitmask_impl() noexcept
{
    return (1u << (static_cast<std::uint32_t>(V0) & (N - 1)))
        | bitmask_impl<I + 1, N, T, Vs...>();
}
62+
63+
// ────────────────────────────────────────────────────────────────────────
64+
// 3) dup_lo_impl
65+
template <std::size_t I, std::size_t N, typename T,
66+
T... Vs, typename std::enable_if<I == N / 2, int>::type = 0>
67+
XSIMD_INLINE constexpr bool dup_lo_impl() noexcept { return true; }
68+
69+
template <std::size_t I, std::size_t N, typename T,
70+
T... Vs, typename std::enable_if<(I < N / 2), int>::type = 0>
71+
XSIMD_INLINE constexpr bool dup_lo_impl() noexcept
72+
{
73+
return get_at<T, I, Vs...>::value < static_cast<T>(N / 2)
74+
&& get_at<T, I + N / 2, Vs...>::value == get_at<T, I, Vs...>::value
75+
&& dup_lo_impl<I + 1, N, T, Vs...>();
76+
}
77+
78+
// ────────────────────────────────────────────────────────────────────────
79+
// 4) dup_hi_impl
80+
template <std::size_t I, std::size_t N, typename T,
81+
T... Vs, typename std::enable_if<I == N / 2, int>::type = 0>
82+
XSIMD_INLINE constexpr bool dup_hi_impl() noexcept { return true; }
83+
84+
template <std::size_t I, std::size_t N, typename T,
85+
T... Vs, typename std::enable_if<(I < N / 2), int>::type = 0>
86+
XSIMD_INLINE constexpr bool dup_hi_impl() noexcept
87+
{
88+
return get_at<T, I, Vs...>::value >= static_cast<T>(N / 2)
89+
&& get_at<T, I, Vs...>::value < static_cast<T>(N)
90+
&& get_at<T, I + N / 2, Vs...>::value == get_at<T, I, Vs...>::value
91+
&& dup_hi_impl<I + 1, N, T, Vs...>();
92+
}
93+
94+
// ────────────────────────────────────────────────────────────────────────
95+
// get_nth_value<I, Values...>::value is the I-th (zero-based) element of a
// pack of 32-bit unsigned values, found by dropping one leading element per
// recursion step until the index reaches 0.
template <std::size_t I, std::uint32_t Head, std::uint32_t... Tail>
struct get_nth_value
{
    // The pack holds 1 + sizeof...(Tail) elements, so sizeof...(Tail) is the
    // last valid index; report anything larger with a readable message.
    static_assert(I <= sizeof...(Tail), "get_nth_value: index out of range");
    static constexpr std::uint32_t value = get_nth_value<I - 1, Tail...>::value;
};

// Recursion anchor: index 0 selects the head of the pack.
template <std::uint32_t Head, std::uint32_t... Tail>
struct get_nth_value<0, Head, Tail...>
{
    static constexpr std::uint32_t value = Head;
};
106+
107+
// ────────────────────────────────────────────────────────────────────────
108+
// 2) recursive cross‐lane test: true if any output‐lane i pulls from the opposite half
109+
template <std::size_t I,
110+
std::size_t N,
111+
std::size_t H,
112+
uint32_t... Vs>
113+
struct cross_impl
114+
{
115+
// does element I cross? (i.e. i<H but V>=H) or (i>=H but V<H)
116+
static constexpr uint32_t Vi = get_nth_value<I, Vs...>::value;
117+
static constexpr bool curr = (I < H ? (Vi >= H) : (Vi < H));
118+
static constexpr bool next = cross_impl<I + 1, N, H, Vs...>::value;
119+
static constexpr bool value = curr || next;
120+
};
121+
template <std::size_t N, std::size_t H, uint32_t... Vs>
122+
struct cross_impl<N, N, H, Vs...>
123+
{
124+
static constexpr bool value = false;
125+
};
126+
template <std::size_t I, std::size_t N, typename T,
127+
T... Vs>
128+
XSIMD_INLINE constexpr bool no_duplicates_impl() noexcept
129+
{
130+
// build the bitmask of (Vs & (N-1)) across all lanes
131+
return detail::bitmask_impl<0, N, T, Vs...>() == ((1u << N) - 1u);
132+
}
133+
template <uint32_t... Vs>
134+
XSIMD_INLINE constexpr bool no_duplicates_v() noexcept
135+
{
136+
// forward to your existing no_duplicates_impl
137+
return no_duplicates_impl<0, sizeof...(Vs), uint32_t, Vs...>();
138+
}
139+
template <uint32_t... Vs>
140+
XSIMD_INLINE constexpr bool is_cross_lane() noexcept
141+
{
142+
static_assert(sizeof...(Vs) >= 1, "Need at least one lane");
143+
return cross_impl<0, sizeof...(Vs), sizeof...(Vs) / 2, Vs...>::value;
144+
}
145+
template <typename T, T... Vs>
146+
XSIMD_INLINE constexpr bool is_identity() noexcept { return detail::identity_impl<0, T, Vs...>(); }
147+
template <typename T, T... Vs>
148+
XSIMD_INLINE constexpr bool is_all_different() noexcept
149+
{
150+
return detail::bitmask_impl<0, sizeof...(Vs), T, Vs...>() == ((1u << sizeof...(Vs)) - 1);
151+
}
152+
153+
template <typename T, T... Vs>
154+
XSIMD_INLINE constexpr bool is_dup_lo() noexcept { return detail::dup_lo_impl<0, sizeof...(Vs), T, Vs...>(); }
155+
template <typename T, T... Vs>
156+
XSIMD_INLINE constexpr bool is_dup_hi() noexcept { return detail::dup_hi_impl<0, sizeof...(Vs), T, Vs...>(); }
157+
template <typename T, class A, T... Vs>
158+
XSIMD_INLINE constexpr bool is_identity(batch_constant<T, A, Vs...>) noexcept { return is_identity<T, Vs...>(); }
159+
template <typename T, class A, T... Vs>
160+
XSIMD_INLINE constexpr bool is_all_different(batch_constant<T, A, Vs...>) noexcept { return is_all_different<T, Vs...>(); }
161+
template <typename T, class A, T... Vs>
162+
XSIMD_INLINE constexpr bool is_dup_lo(batch_constant<T, A, Vs...>) noexcept { return is_dup_lo<T, Vs...>(); }
163+
template <typename T, class A, T... Vs>
164+
XSIMD_INLINE constexpr bool is_dup_hi(batch_constant<T, A, Vs...>) noexcept { return is_dup_hi<T, Vs...>(); }
165+
template <typename T, class A, T... Vs>
166+
XSIMD_INLINE constexpr bool is_cross_lane(batch_constant<T, A, Vs...>) noexcept { return detail::is_cross_lane<Vs...>(); }
167+
template <typename T, class A, T... Vs>
168+
XSIMD_INLINE constexpr bool no_duplicates(batch_constant<T, A, Vs...>) noexcept { return no_duplicates_impl<0, sizeof...(Vs), T, Vs...>(); }
169+
// ────────────────────────────────────────────────────────────────────────
170+
// compile-time tests (identity, all-different, dup-lo, dup-hi)
171+
// 8-lane identity
172+
static_assert(is_identity<std::uint32_t, 0, 1, 2, 3, 4, 5, 6, 7>(), "identity failed");
173+
// 8-lane reverse is all-different but not identity
174+
static_assert(is_all_different<std::uint32_t, 7, 6, 5, 4, 3, 2, 1, 0>(), "all-diff failed");
175+
static_assert(!is_identity<std::uint32_t, 7, 6, 5, 4, 3, 2, 1, 0>(), "identity on reverse");
176+
// 8-lane dup-lo (repeat 0..3 twice)
177+
static_assert(is_dup_lo<std::uint32_t, 0, 1, 2, 3, 0, 1, 2, 3>(), "dup_lo failed");
178+
static_assert(!is_dup_hi<std::uint32_t, 0, 1, 2, 3, 0, 1, 2, 3>(), "dup_hi on dup_lo");
179+
// 8-lane dup-hi (repeat 4..7 twice)
180+
static_assert(is_dup_hi<std::uint32_t, 4, 5, 6, 7, 4, 5, 6, 7>(), "dup_hi failed");
181+
static_assert(!is_dup_lo<std::uint32_t, 4, 5, 6, 7, 4, 5, 6, 7>(), "dup_lo on dup_hi");
182+
// ────────────────────────────────────────────────────────────────────────
183+
// 4-lane identity
184+
static_assert(is_identity<std::uint32_t, 0, 1, 2, 3>(), "4-lane identity failed");
185+
// 4-lane reverse all-different but not identity
186+
static_assert(is_all_different<std::uint32_t, 3, 2, 1, 0>(), "4-lane all-diff failed");
187+
static_assert(!is_identity<std::uint32_t, 3, 2, 1, 0>(), "4-lane identity on reverse");
188+
// 4-lane dup-lo (repeat 0..1 twice)
189+
static_assert(is_dup_lo<std::uint32_t, 0, 1, 0, 1>(), "4-lane dup_lo failed");
190+
static_assert(!is_dup_hi<std::uint32_t, 0, 1, 0, 1>(), "4-lane dup_hi on dup_lo");
191+
// 4-lane dup-hi (repeat 2..3 twice)
192+
static_assert(is_dup_hi<std::uint32_t, 2, 3, 2, 3>(), "4-lane dup_hi failed");
193+
static_assert(!is_dup_lo<std::uint32_t, 2, 3, 2, 3>(), "4-lane dup_lo on dup_hi");
194+
195+
static_assert(is_cross_lane<0, 1, 0, 1>(), "dup-lo only → crossing");
196+
static_assert(is_cross_lane<2, 3, 2, 3>(), "dup-hi only → crossing");
197+
static_assert(is_cross_lane<0, 3, 3, 3>(), "one low + rest high → crossing");
198+
static_assert(!is_cross_lane<1, 0, 2, 3>(), "mixed low/high → no crossing");
199+
static_assert(!is_cross_lane<0, 1, 2, 3>(), "mixed low/high → no crossing");
200+
201+
static_assert(no_duplicates_v<0, 1, 2, 3>(), "N=4: [0,1,2,3] → distinct");
202+
static_assert(!no_duplicates_v<0, 1, 2, 2>(), "N=4: [0,1,2,2] → dup");
203+
204+
static_assert(no_duplicates_v<0, 1, 2, 3, 4, 5, 6, 7>(), "N=8: [0..7] → distinct");
205+
static_assert(!no_duplicates_v<0, 1, 2, 3, 4, 5, 6, 0>(), "N=8: last repeats 0");
206+
207+
// ────────────────────────────────────────────────────────────────────────
208+
// log2_c<N>::value is the base-2 logarithm of N for N a power of two.
// Any other N (including 0) is rejected by the static_assert.
template <std::size_t N>
struct log2_c
{
    static_assert(N > 0 && (N & (N - 1)) == 0, "N must be power of 2");
    // Halve N per step. The argument is clamped to 1 so that an invalid
    // N == 0 stops at the log2_c<1> anchor instead of referring to log2_c<0>
    // from inside its own definition.
    static constexpr std::size_t value = 1 + log2_c<(N > 1 ? N / 2 : 1)>::value;
};

// Recursion anchor: log2(1) == 0.
template <>
struct log2_c<1>
{
    static constexpr std::size_t value = 0;
};
220+
221+
// ────── Recursive encoder ──────
222+
template <std::size_t I, std::size_t N, std::size_t SHIFT, uint32_t... Values>
223+
struct shuffle_impl
224+
{
225+
static constexpr uint32_t value = (get_nth_value<I, Values...>::value << (I * SHIFT)) | shuffle_impl<I + 1, N, SHIFT, Values...>::value;
226+
};
227+
template <std::size_t N, std::size_t SHIFT, uint32_t... Values>
228+
struct shuffle_impl<N, N, SHIFT, Values...>
229+
{
230+
static constexpr uint32_t value = 0;
231+
};
232+
template <uint32_t... Values>
233+
XSIMD_INLINE constexpr std::uint32_t shuffle() noexcept
234+
{
235+
return shuffle_impl<0,
236+
sizeof...(Values),
237+
log2_c<sizeof...(Values)>::value,
238+
Values...>::value;
239+
}
240+
241+
template <uint32_t... Values>
242+
XSIMD_INLINE constexpr std::uint32_t mod_shuffle() noexcept
243+
{
244+
return shuffle<(Values % sizeof...(Values))...>();
245+
}
246+
} // namespace detail
247+
} // namespace kernel
248+
} // namespace xsimd
249+
250+
#endif // XSIMD_COMMON_SWIZZLE_HPP

0 commit comments

Comments
 (0)