Skip to content

Commit 286aae0

Browse files
committed
bench: add benchmark for multi-axis roll
1 parent eec6188 commit 286aae0

File tree

2 files changed

+387
-0
lines changed

2 files changed

+387
-0
lines changed

benchmark/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ set(XTENSOR_BENCHMARK
111111
benchmark_math.cpp
112112
benchmark_random.cpp
113113
benchmark_reducer.cpp
114+
benchmark_roll.cpp
114115
benchmark_views.cpp
115116
benchmark_xshape.cpp
116117
benchmark_view_access.cpp

benchmark/benchmark_roll.cpp

Lines changed: 386 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,386 @@
1+
/***************************************************************************
2+
* Copyright (c) 2016, Johan Mabille, Sylvain Corlay and Wolf Vollprecht *
3+
* *
4+
* Distributed under the terms of the BSD 3-Clause License. *
5+
* *
6+
* The full license is in the file LICENSE, distributed with this software. *
7+
****************************************************************************/
8+
9+
#include <benchmark/benchmark.h>
10+
11+
#include "xtensor/containers/xtensor.hpp"
12+
#include "xtensor/core/xmath.hpp"
13+
#include "xtensor/generators/xrandom.hpp"
14+
#include "xtensor/misc/xmanipulation.hpp"
15+
16+
namespace xt
17+
{
18+
namespace roll_bench
19+
{
20+
namespace detail
21+
{
22+
/*********************************************
23+
* Correctness verification helper
24+
*********************************************/
25+
26+
template <class T>
27+
bool verify_correctness(const T& multi_result, const T& sequential_result)
28+
{
29+
return xt::allclose(multi_result, sequential_result);
30+
}
31+
32+
/*********************************************
33+
* 2D roll benchmarks (2 axes)
34+
* shift = size * ratio for each axis
35+
*********************************************/
36+
37+
/**
 * Baseline for the 2D case: rolls an h x w tensor along axis 0 and then
 * along axis 1 with two chained single-axis xt::roll calls.
 *
 * @param state benchmark state that drives the measured loop
 * @param h extent of axis 0
 * @param w extent of axis 1
 * @param ratio shift as a fraction of each axis extent (truncated toward zero)
 */
inline void roll_2d_sequential(benchmark::State& state, std::size_t h, std::size_t w, double ratio)
{
    xt::xtensor<double, 2> source = xt::random::rand<double>({h, w});
    std::ptrdiff_t row_shift = static_cast<std::ptrdiff_t>(h * ratio);
    std::ptrdiff_t col_shift = static_cast<std::ptrdiff_t>(w * ratio);

    for (auto _ : state)
    {
        // Each single-axis roll materializes a full intermediate tensor.
        xt::xtensor<double, 2> rows_rolled = xt::roll(source, row_shift, 0);
        xt::xtensor<double, 2> rolled = xt::roll(rows_rolled, col_shift, 1);
        benchmark::DoNotOptimize(rolled.data());
        benchmark::ClobberMemory();
    }

    // Counters are reported as one pass over the tensor per iteration.
    const auto iterations = state.iterations();
    state.SetItemsProcessed(iterations * h * w);
    state.SetBytesProcessed(iterations * h * w * sizeof(double));
}
54+
55+
/**
 * Multi-axis variant of the 2D case: rolls an h x w tensor along axes 0
 * and 1 with a single xt::roll call.
 *
 * Before the measured loop the result is compared once against the chained
 * single-axis reference; on mismatch the benchmark is skipped with an error.
 *
 * @param state benchmark state that drives the measured loop
 * @param h extent of axis 0
 * @param w extent of axis 1
 * @param ratio shift as a fraction of each axis extent (truncated toward zero)
 */
inline void roll_2d_multi(benchmark::State& state, std::size_t h, std::size_t w, double ratio)
{
    xt::xtensor<double, 2> source = xt::random::rand<double>({h, w});
    std::ptrdiff_t row_shift = static_cast<std::ptrdiff_t>(h * ratio);
    std::ptrdiff_t col_shift = static_cast<std::ptrdiff_t>(w * ratio);

    // Sanity check, kept in its own scope so the temporaries are released
    // before the measured loop starts.
    {
        xt::xtensor<double, 2> rows_rolled = xt::roll(source, row_shift, 0);
        xt::xtensor<double, 2> reference = xt::roll(rows_rolled, col_shift, 1);
        auto candidate = xt::roll(source, {row_shift, col_shift}, {0, 1});
        const bool matches = verify_correctness(candidate, reference);
        if (!matches)
        {
            state.SkipWithError("Correctness check failed!");
            return;
        }
    }

    for (auto _ : state)
    {
        auto rolled = xt::roll(source, {row_shift, col_shift}, {0, 1});
        benchmark::DoNotOptimize(rolled.data());
        benchmark::ClobberMemory();
    }

    // Counters are reported as one pass over the tensor per iteration.
    const auto iterations = state.iterations();
    state.SetItemsProcessed(iterations * h * w);
    state.SetBytesProcessed(iterations * h * w * sizeof(double));
}
83+
84+
/*********************************************
85+
* 3D roll benchmarks (2 axes - image spatial roll)
86+
* shift = size * ratio for H and W axes
87+
*********************************************/
88+
89+
/**
 * Baseline for the 3D / two-axis case: rolls an h x w x c tensor along
 * axis 0 and then axis 1 with chained single-axis xt::roll calls. Axis 2
 * is never rolled.
 *
 * @param state benchmark state that drives the measured loop
 * @param h extent of axis 0
 * @param w extent of axis 1
 * @param c extent of axis 2
 * @param ratio shift as a fraction of the h and w extents (truncated toward zero)
 */
inline void roll_3d_2axes_sequential(benchmark::State& state, std::size_t h, std::size_t w, std::size_t c, double ratio)
{
    xt::xtensor<double, 3> source = xt::random::rand<double>({h, w, c});
    std::ptrdiff_t row_shift = static_cast<std::ptrdiff_t>(h * ratio);
    std::ptrdiff_t col_shift = static_cast<std::ptrdiff_t>(w * ratio);

    for (auto _ : state)
    {
        // Each single-axis roll materializes a full intermediate tensor.
        xt::xtensor<double, 3> rows_rolled = xt::roll(source, row_shift, 0);
        xt::xtensor<double, 3> rolled = xt::roll(rows_rolled, col_shift, 1);
        benchmark::DoNotOptimize(rolled.data());
        benchmark::ClobberMemory();
    }

    // Counters are reported as one pass over the tensor per iteration.
    const auto iterations = state.iterations();
    state.SetItemsProcessed(iterations * h * w * c);
    state.SetBytesProcessed(iterations * h * w * c * sizeof(double));
}
106+
107+
/**
 * Multi-axis variant of the 3D / two-axis case: rolls an h x w x c tensor
 * along axes 0 and 1 with a single xt::roll call. Axis 2 is never rolled.
 *
 * Before the measured loop the result is compared once against the chained
 * single-axis reference; on mismatch the benchmark is skipped with an error.
 *
 * @param state benchmark state that drives the measured loop
 * @param h extent of axis 0
 * @param w extent of axis 1
 * @param c extent of axis 2
 * @param ratio shift as a fraction of the h and w extents (truncated toward zero)
 */
inline void roll_3d_2axes_multi(benchmark::State& state, std::size_t h, std::size_t w, std::size_t c, double ratio)
{
    xt::xtensor<double, 3> source = xt::random::rand<double>({h, w, c});
    std::ptrdiff_t row_shift = static_cast<std::ptrdiff_t>(h * ratio);
    std::ptrdiff_t col_shift = static_cast<std::ptrdiff_t>(w * ratio);

    // Sanity check, kept in its own scope so the temporaries are released
    // before the measured loop starts.
    {
        xt::xtensor<double, 3> rows_rolled = xt::roll(source, row_shift, 0);
        xt::xtensor<double, 3> reference = xt::roll(rows_rolled, col_shift, 1);
        auto candidate = xt::roll(source, {row_shift, col_shift}, {0, 1});
        const bool matches = verify_correctness(candidate, reference);
        if (!matches)
        {
            state.SkipWithError("Correctness check failed!");
            return;
        }
    }

    for (auto _ : state)
    {
        auto rolled = xt::roll(source, {row_shift, col_shift}, {0, 1});
        benchmark::DoNotOptimize(rolled.data());
        benchmark::ClobberMemory();
    }

    // Counters are reported as one pass over the tensor per iteration.
    const auto iterations = state.iterations();
    state.SetItemsProcessed(iterations * h * w * c);
    state.SetBytesProcessed(iterations * h * w * c * sizeof(double));
}
135+
136+
/*********************************************
137+
* 3D roll benchmarks (3 axes - cube roll)
138+
* shift = size * ratio for all axes
139+
*********************************************/
140+
141+
/**
 * Baseline for the cube case: rolls a size^3 tensor along axes 0, 1 and 2
 * in turn with chained single-axis xt::roll calls, using the same shift on
 * every axis.
 *
 * @param state benchmark state that drives the measured loop
 * @param size extent of every axis
 * @param ratio shift as a fraction of the axis extent (truncated toward zero)
 */
inline void roll_3d_3axes_sequential(benchmark::State& state, std::size_t size, double ratio)
{
    xt::xtensor<double, 3> source = xt::random::rand<double>({size, size, size});
    std::ptrdiff_t offset = static_cast<std::ptrdiff_t>(size * ratio);

    for (auto _ : state)
    {
        // Each single-axis roll materializes a full intermediate tensor.
        xt::xtensor<double, 3> after_axis0 = xt::roll(source, offset, 0);
        xt::xtensor<double, 3> after_axis1 = xt::roll(after_axis0, offset, 1);
        xt::xtensor<double, 3> rolled = xt::roll(after_axis1, offset, 2);
        benchmark::DoNotOptimize(rolled.data());
        benchmark::ClobberMemory();
    }

    // Counters are reported as one pass over the tensor per iteration.
    const std::size_t element_count = size * size * size;
    const auto iterations = state.iterations();
    state.SetItemsProcessed(iterations * element_count);
    state.SetBytesProcessed(iterations * element_count * sizeof(double));
}
159+
160+
/**
 * Multi-axis variant of the cube case: rolls a size^3 tensor along axes
 * 0, 1 and 2 with a single xt::roll call, using the same shift on every axis.
 *
 * Before the measured loop the result is compared once against the chained
 * single-axis reference; on mismatch the benchmark is skipped with an error.
 *
 * @param state benchmark state that drives the measured loop
 * @param size extent of every axis
 * @param ratio shift as a fraction of the axis extent (truncated toward zero)
 */
inline void roll_3d_3axes_multi(benchmark::State& state, std::size_t size, double ratio)
{
    xt::xtensor<double, 3> source = xt::random::rand<double>({size, size, size});
    std::ptrdiff_t offset = static_cast<std::ptrdiff_t>(size * ratio);

    // Sanity check, kept in its own scope so the temporaries are released
    // before the measured loop starts.
    {
        xt::xtensor<double, 3> after_axis0 = xt::roll(source, offset, 0);
        xt::xtensor<double, 3> after_axis1 = xt::roll(after_axis0, offset, 1);
        xt::xtensor<double, 3> reference = xt::roll(after_axis1, offset, 2);
        auto candidate = xt::roll(source, {offset, offset, offset}, {0, 1, 2});
        const bool matches = verify_correctness(candidate, reference);
        if (!matches)
        {
            state.SkipWithError("Correctness check failed!");
            return;
        }
    }

    for (auto _ : state)
    {
        auto rolled = xt::roll(source, {offset, offset, offset}, {0, 1, 2});
        benchmark::DoNotOptimize(rolled.data());
        benchmark::ClobberMemory();
    }

    // Counters are reported as one pass over the tensor per iteration.
    const std::size_t element_count = size * size * size;
    const auto iterations = state.iterations();
    state.SetItemsProcessed(iterations * element_count);
    state.SetBytesProcessed(iterations * element_count * sizeof(double));
}
189+
190+
/*********************************************
191+
* 4D roll benchmarks (4 axes)
192+
* shift = size * ratio for all axes
193+
*********************************************/
194+
195+
/**
 * Baseline for the 4D case: rolls a size^4 tensor along axes 0 through 3
 * in turn with chained single-axis xt::roll calls, using the same shift on
 * every axis.
 *
 * @param state benchmark state that drives the measured loop
 * @param size extent of every axis
 * @param ratio shift as a fraction of the axis extent (truncated toward zero)
 */
inline void roll_4d_4axes_sequential(benchmark::State& state, std::size_t size, double ratio)
{
    xt::xtensor<double, 4> source = xt::random::rand<double>({size, size, size, size});
    std::ptrdiff_t offset = static_cast<std::ptrdiff_t>(size * ratio);

    for (auto _ : state)
    {
        // Each single-axis roll materializes a full intermediate tensor.
        xt::xtensor<double, 4> after_axis0 = xt::roll(source, offset, 0);
        xt::xtensor<double, 4> after_axis1 = xt::roll(after_axis0, offset, 1);
        xt::xtensor<double, 4> after_axis2 = xt::roll(after_axis1, offset, 2);
        xt::xtensor<double, 4> rolled = xt::roll(after_axis2, offset, 3);
        benchmark::DoNotOptimize(rolled.data());
        benchmark::ClobberMemory();
    }

    // Counters are reported as one pass over the tensor per iteration.
    const std::size_t element_count = size * size * size * size;
    const auto iterations = state.iterations();
    state.SetItemsProcessed(iterations * element_count);
    state.SetBytesProcessed(iterations * element_count * sizeof(double));
}
214+
215+
/**
 * Multi-axis variant of the 4D case: rolls a size^4 tensor along axes 0
 * through 3 with a single xt::roll call, using the same shift on every axis.
 *
 * Before the measured loop the result is compared once against the chained
 * single-axis reference; on mismatch the benchmark is skipped with an error.
 *
 * @param state benchmark state that drives the measured loop
 * @param size extent of every axis
 * @param ratio shift as a fraction of the axis extent (truncated toward zero)
 */
inline void roll_4d_4axes_multi(benchmark::State& state, std::size_t size, double ratio)
{
    xt::xtensor<double, 4> source = xt::random::rand<double>({size, size, size, size});
    std::ptrdiff_t offset = static_cast<std::ptrdiff_t>(size * ratio);

    // Sanity check, kept in its own scope so the temporaries are released
    // before the measured loop starts.
    {
        xt::xtensor<double, 4> after_axis0 = xt::roll(source, offset, 0);
        xt::xtensor<double, 4> after_axis1 = xt::roll(after_axis0, offset, 1);
        xt::xtensor<double, 4> after_axis2 = xt::roll(after_axis1, offset, 2);
        xt::xtensor<double, 4> reference = xt::roll(after_axis2, offset, 3);
        auto candidate = xt::roll(source, {offset, offset, offset, offset}, {0, 1, 2, 3});
        const bool matches = verify_correctness(candidate, reference);
        if (!matches)
        {
            state.SkipWithError("Correctness check failed!");
            return;
        }
    }

    for (auto _ : state)
    {
        auto rolled = xt::roll(source, {offset, offset, offset, offset}, {0, 1, 2, 3});
        benchmark::DoNotOptimize(rolled.data());
        benchmark::ClobberMemory();
    }

    // Counters are reported as one pass over the tensor per iteration.
    const std::size_t element_count = size * size * size * size;
    const auto iterations = state.iterations();
    state.SetItemsProcessed(iterations * element_count);
    state.SetBytesProcessed(iterations * element_count * sizeof(double));
}
245+
246+
/*********************************************
247+
* 5D roll benchmarks (5 axes)
248+
* shift = size * ratio for all axes
249+
*********************************************/
250+
251+
/**
 * Baseline for the 5D case: rolls a size^5 tensor along axes 0 through 4
 * in turn with chained single-axis xt::roll calls, using the same shift on
 * every axis.
 *
 * @param state benchmark state that drives the measured loop
 * @param size extent of every axis
 * @param ratio shift as a fraction of the axis extent (truncated toward zero)
 */
inline void roll_5d_5axes_sequential(benchmark::State& state, std::size_t size, double ratio)
{
    xt::xtensor<double, 5> source = xt::random::rand<double>({size, size, size, size, size});
    std::ptrdiff_t offset = static_cast<std::ptrdiff_t>(size * ratio);

    for (auto _ : state)
    {
        // Each single-axis roll materializes a full intermediate tensor.
        xt::xtensor<double, 5> after_axis0 = xt::roll(source, offset, 0);
        xt::xtensor<double, 5> after_axis1 = xt::roll(after_axis0, offset, 1);
        xt::xtensor<double, 5> after_axis2 = xt::roll(after_axis1, offset, 2);
        xt::xtensor<double, 5> after_axis3 = xt::roll(after_axis2, offset, 3);
        xt::xtensor<double, 5> rolled = xt::roll(after_axis3, offset, 4);
        benchmark::DoNotOptimize(rolled.data());
        benchmark::ClobberMemory();
    }

    // Counters are reported as one pass over the tensor per iteration.
    const std::size_t element_count = size * size * size * size * size;
    const auto iterations = state.iterations();
    state.SetItemsProcessed(iterations * element_count);
    state.SetBytesProcessed(iterations * element_count * sizeof(double));
}
271+
272+
/**
 * Multi-axis variant of the 5D case: rolls a size^5 tensor along axes 0
 * through 4 with a single xt::roll call, using the same shift on every axis.
 *
 * Before the measured loop the result is compared once against the chained
 * single-axis reference; on mismatch the benchmark is skipped with an error.
 *
 * @param state benchmark state that drives the measured loop
 * @param size extent of every axis
 * @param ratio shift as a fraction of the axis extent (truncated toward zero)
 */
inline void roll_5d_5axes_multi(benchmark::State& state, std::size_t size, double ratio)
{
    xt::xtensor<double, 5> source = xt::random::rand<double>({size, size, size, size, size});
    std::ptrdiff_t offset = static_cast<std::ptrdiff_t>(size * ratio);

    // Sanity check, kept in its own scope so the temporaries are released
    // before the measured loop starts.
    {
        xt::xtensor<double, 5> after_axis0 = xt::roll(source, offset, 0);
        xt::xtensor<double, 5> after_axis1 = xt::roll(after_axis0, offset, 1);
        xt::xtensor<double, 5> after_axis2 = xt::roll(after_axis1, offset, 2);
        xt::xtensor<double, 5> after_axis3 = xt::roll(after_axis2, offset, 3);
        xt::xtensor<double, 5> reference = xt::roll(after_axis3, offset, 4);
        auto candidate = xt::roll(source, {offset, offset, offset, offset, offset}, {0, 1, 2, 3, 4});
        const bool matches = verify_correctness(candidate, reference);
        if (!matches)
        {
            state.SkipWithError("Correctness check failed!");
            return;
        }
    }

    for (auto _ : state)
    {
        auto rolled = xt::roll(source, {offset, offset, offset, offset, offset}, {0, 1, 2, 3, 4});
        benchmark::DoNotOptimize(rolled.data());
        benchmark::ClobberMemory();
    }

    // Counters are reported as one pass over the tensor per iteration.
    const std::size_t element_count = size * size * size * size * size;
    const auto iterations = state.iterations();
    state.SetItemsProcessed(iterations * element_count);
    state.SetBytesProcessed(iterations * element_count * sizeof(double));
}
303+
}
304+
305+
/*********************************************
306+
* Rate variation test (3D cube 128, 3 axes)
307+
* Demonstrates that the shift ratio does not affect performance
308+
* Ratios: 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, -0.3
309+
*********************************************/
310+
311+
// Every ratio is registered twice: once for the chained single-axis baseline
// and once for the multi-axis variant, so the two appear side by side in the
// report. The arguments after the case name are (size, ratio); the case name
// itself is only stringified into the benchmark name.
// The last pair uses a negative ratio, which yields a negative shift.
BENCHMARK_CAPTURE(detail::roll_3d_3axes_sequential, 128/r0.01, 128, 0.01);
312+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_multi, 128/r0.01, 128, 0.01);
313+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_sequential, 128/r0.05, 128, 0.05);
314+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_multi, 128/r0.05, 128, 0.05);
315+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_sequential, 128/r0.1, 128, 0.1);
316+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_multi, 128/r0.1, 128, 0.1);
317+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_sequential, 128/r0.2, 128, 0.2);
318+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_multi, 128/r0.2, 128, 0.2);
319+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_sequential, 128/r0.3, 128, 0.3);
320+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_multi, 128/r0.3, 128, 0.3);
321+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_sequential, 128/r0.4, 128, 0.4);
322+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_multi, 128/r0.4, 128, 0.4);
323+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_sequential, 128/r-0.3, 128, -0.3);
324+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_multi, 128/r-0.3, 128, -0.3);
325+
326+
/*********************************************
327+
* Main benchmarks (ratio = 0.3)
328+
*********************************************/
329+
330+
// 2D square tensors; arguments after the case name are (h, w, ratio).
// Each shape is registered for the sequential baseline and the multi-axis variant.
331+
BENCHMARK_CAPTURE(detail::roll_2d_sequential, 64x64, 64, 64, 0.3);
332+
BENCHMARK_CAPTURE(detail::roll_2d_multi, 64x64, 64, 64, 0.3);
333+
BENCHMARK_CAPTURE(detail::roll_2d_sequential, 256x256, 256, 256, 0.3);
334+
BENCHMARK_CAPTURE(detail::roll_2d_multi, 256x256, 256, 256, 0.3);
335+
BENCHMARK_CAPTURE(detail::roll_2d_sequential, 1024x1024, 1024, 1024, 0.3);
336+
BENCHMARK_CAPTURE(detail::roll_2d_multi, 1024x1024, 1024, 1024, 0.3);
337+
338+
// 3D cube, rolled along axes 0 and 1 only; arguments are (h, w, c, ratio).
339+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_sequential, 64x64x64, 64, 64, 64, 0.3);
340+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_multi, 64x64x64, 64, 64, 64, 0.3);
341+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_sequential, 128x128x128, 128, 128, 128, 0.3);
342+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_multi, 128x128x128, 128, 128, 128, 0.3);
343+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_sequential, 256x256x256, 256, 256, 256, 0.3);
344+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_multi, 256x256x256, 256, 256, 256, 0.3);
345+
346+
// 3D cube, rolled along all three axes; arguments are (size, ratio) and the
// case name is the per-axis extent.
347+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_sequential, 16, 16, 0.3);
348+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_multi, 16, 16, 0.3);
349+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_sequential, 32, 32, 0.3);
350+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_multi, 32, 32, 0.3);
351+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_sequential, 64, 64, 0.3);
352+
BENCHMARK_CAPTURE(detail::roll_3d_3axes_multi, 64, 64, 0.3);
353+
354+
// 4D hypercube, rolled along all four axes; arguments are (size, ratio).
355+
BENCHMARK_CAPTURE(detail::roll_4d_4axes_sequential, 16, 16, 0.3);
356+
BENCHMARK_CAPTURE(detail::roll_4d_4axes_multi, 16, 16, 0.3);
357+
BENCHMARK_CAPTURE(detail::roll_4d_4axes_sequential, 32, 32, 0.3);
358+
BENCHMARK_CAPTURE(detail::roll_4d_4axes_multi, 32, 32, 0.3);
359+
BENCHMARK_CAPTURE(detail::roll_4d_4axes_sequential, 64, 64, 0.3);
360+
BENCHMARK_CAPTURE(detail::roll_4d_4axes_multi, 64, 64, 0.3);
361+
362+
// 5D hypercube, rolled along all five axes; arguments are (size, ratio).
// Only extents 16 and 32 are registered here (32^5 is already ~33.5M elements).
363+
BENCHMARK_CAPTURE(detail::roll_5d_5axes_sequential, 16, 16, 0.3);
364+
BENCHMARK_CAPTURE(detail::roll_5d_5axes_multi, 16, 16, 0.3);
365+
BENCHMARK_CAPTURE(detail::roll_5d_5axes_sequential, 32, 32, 0.3);
366+
BENCHMARK_CAPTURE(detail::roll_5d_5axes_multi, 32, 32, 0.3);
367+
368+
// 3D RGB images (H x W x 3), rolled along H and W only; arguments are
// (h, w, c, ratio). Named video resolutions come first, then square sizes.
// NOTE(review): rgb_8K holds 4320 * 7680 * 3 doubles per tensor (roughly
// 0.8 GB assuming 8-byte double), and the sequential variant keeps several
// such tensors alive at once - confirm the target machine has enough memory.
369+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_sequential, rgb_1080p, 1080, 1920, 3, 0.3);
370+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_multi, rgb_1080p, 1080, 1920, 3, 0.3);
371+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_sequential, rgb_2K, 1440, 2560, 3, 0.3);
372+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_multi, rgb_2K, 1440, 2560, 3, 0.3);
373+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_sequential, rgb_4K, 2160, 3840, 3, 0.3);
374+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_multi, rgb_4K, 2160, 3840, 3, 0.3);
375+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_sequential, rgb_8K, 4320, 7680, 3, 0.3);
376+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_multi, rgb_8K, 4320, 7680, 3, 0.3);
377+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_sequential, rgb_256x256, 256, 256, 3, 0.3);
378+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_multi, rgb_256x256, 256, 256, 3, 0.3);
379+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_sequential, rgb_512x512, 512, 512, 3, 0.3);
380+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_multi, rgb_512x512, 512, 512, 3, 0.3);
381+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_sequential, rgb_1024x1024, 1024, 1024, 3, 0.3);
382+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_multi, rgb_1024x1024, 1024, 1024, 3, 0.3);
383+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_sequential, rgb_2048x2048, 2048, 2048, 3, 0.3);
384+
BENCHMARK_CAPTURE(detail::roll_3d_2axes_multi, rgb_2048x2048, 2048, 2048, 3, 0.3);
385+
}
386+
}

0 commit comments

Comments
 (0)