Skip to content

Commit 1f579c0

Browse files
authored
feat: add function for scheduling chunked computations (PROOF-928) (#255)
* add function to schedule asynchronous chunked computations * remove print statements * doc * change signature * drop dead code
1 parent 6ee8dfb commit 1f579c0

File tree

6 files changed

+298
-3
lines changed

6 files changed

+298
-3
lines changed

sxt/execution/device/BUILD

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@ load(
33
"sxt_cc_component",
44
)
55

6+
sxt_cc_component(
7+
name = "chunk_context",
8+
with_test = False,
9+
deps = [
10+
"//sxt/execution/async:shared_future",
11+
],
12+
)
13+
614
sxt_cc_component(
715
name = "device_viewable",
816
test_deps = [
@@ -78,19 +86,24 @@ sxt_cc_component(
7886
name = "for_each",
7987
impl_deps = [
8088
":available_device",
81-
"//sxt/execution/async:future",
8289
"//sxt/execution/async:coroutine",
8390
"//sxt/base/device:active_device_guard",
8491
"//sxt/base/device:property",
92+
"//sxt/base/device:state",
8593
"//sxt/base/iterator:split",
8694
],
8795
test_deps = [
96+
"//sxt/base/error:assert",
8897
"//sxt/base/iterator:index_range",
98+
"//sxt/base/iterator:index_range_iterator",
8999
"//sxt/base/test:unit_test",
90100
"//sxt/execution/async:future",
91101
],
92102
deps = [
93-
"//sxt/execution/async:future_fwd",
103+
":chunk_context",
104+
"//sxt/base/device:stream",
105+
"//sxt/execution/async:future",
106+
"//sxt/execution/async:shared_future",
94107
"//sxt/execution/schedule:scheduler",
95108
],
96109
)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
2+
*
3+
* Copyright 2025-present Space and Time Labs, Inc.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
#include "sxt/execution/device/chunk_context.h"
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
2+
*
3+
* Copyright 2025-present Space and Time Labs, Inc.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
#pragma once
18+
19+
#include "sxt/execution/async/shared_future.h"
20+
21+
namespace sxt::xendv {
22+
//--------------------------------------------------------------------------------------------------
23+
// chunk_context
24+
//--------------------------------------------------------------------------------------------------
25+
/**
26+
* Give context for an individual chunk of a chunked computation
27+
*/
28+
struct chunk_context {
29+
// a counter tracking the processing index for the given chunk
30+
unsigned chunk_index = 0;
31+
32+
// the device used to process the chunk
33+
unsigned device_index = 0;
34+
35+
// the total number of devices used to process the collection of chunks
36+
unsigned num_devices_used = 0;
37+
38+
// When two chunks are scheduled for the same device, alt_future gives
39+
// a handle to the asynchronous computation associated with the other
40+
// chunk.
41+
//
42+
// alt_future can be used to overlap memory transfer with kernel computation. For
43+
// example, a functor to process chunks might look something like this
44+
// f(const chunk_context& ctx, index_range rng) noexcept -> xena::future<> {
45+
// ...
46+
// async_copy_memory(stream, ...);
47+
//
48+
// co_await ctx.alt_future;
49+
// // wait for the other future to finish so that we don't oversubscribe the GPU
50+
//
51+
// launch_kernel(stream, ...);
52+
// co_await synchronize_stream(stream);
53+
// }
54+
xena::shared_future<> alt_future;
55+
};
56+
} // namespace sxt::xendv

sxt/execution/device/for_each.cc

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,37 @@
1818

1919
#include "sxt/base/device/active_device_guard.h"
2020
#include "sxt/base/device/property.h"
21+
#include "sxt/base/device/state.h"
2122
#include "sxt/base/iterator/split.h"
2223
#include "sxt/execution/async/coroutine.h"
2324
#include "sxt/execution/async/future.h"
2425
#include "sxt/execution/device/available_device.h"
2526

2627
namespace sxt::xendv {
28+
//--------------------------------------------------------------------------------------------------
29+
// for_each_device_impl
30+
//--------------------------------------------------------------------------------------------------
31+
/**
 * Keep one device supplied with work: pull the next chunk from the shared iterator and launch
 * f on it, repeating until the iterator reaches last.
 *
 * ctx and ctx_p are the two contexts used for the device; they trade places after every launch.
 * On entry ctx_p->alt_future is expected to track the chunk most recently launched on the
 * device (for_each_device stores the initial launch there).
 *
 * chunk_index and iter are taken by reference and advanced here, so the counter and iterator
 * handed in by the caller are shared by every coroutine started with them.
 *
 * The returned future completes once the last chunk launched by this coroutine (or, if there
 * was nothing left to launch, the chunk tracked by ctx_p on entry) has completed.
 */
static xena::future<> for_each_device_impl(
    chunk_context* ctx, chunk_context* ctx_p, unsigned& chunk_index,
    basit::index_range_iterator& iter, basit::index_range_iterator last,
    std::function<xena::future<>(const chunk_context& ctx, basit::index_range)> f) noexcept {
  // read the device index once, before the two context pointers start swapping
  const auto device_index = ctx->device_index;

  // iter is re-checked after every resumption since other coroutines advance it as well
  while (iter != last) {
    auto chunk = *iter++;
    ctx_p->chunk_index = chunk_index++;

    // NOTE(review): the device is set before every launch, presumably because the active
    // device may have been changed while this coroutine was suspended -- confirm
    basdv::set_device(device_index);

    // Launch the new chunk before waiting on the previous one. While f runs,
    // ctx_p->alt_future still refers to the previous chunk, which is what lets f overlap its
    // own setup with the previous chunk's remaining work.
    auto fut = f(*ctx_p, chunk);
    co_await ctx_p->alt_future;

    // the chunk just launched becomes the "other" chunk for the next launch
    ctx->alt_future = std::move(fut);
    std::swap(ctx, ctx_p);
  }

  // nothing left to launch; wait for the chunk that is still outstanding on this device
  co_await ctx_p->alt_future;
}
51+
2752
//--------------------------------------------------------------------------------------------------
2853
// concurrent_for_each
2954
//--------------------------------------------------------------------------------------------------
@@ -51,4 +76,53 @@ concurrent_for_each(basit::index_range rng,
5176
auto [first, last] = basit::split(rng, split_options);
5277
return concurrent_for_each(first, last, f);
5378
}
79+
80+
//--------------------------------------------------------------------------------------------------
81+
// for_each_device
82+
//--------------------------------------------------------------------------------------------------
83+
xena::future<> for_each_device(
    basit::index_range_iterator first, basit::index_range_iterator last,
    std::function<xena::future<>(const chunk_context& ctx, basit::index_range)> f) noexcept {
  // an empty range has nothing to schedule
  if (first == last) {
    co_return;
  }

  // never use more devices than there are chunks
  auto num_chunks = static_cast<unsigned>(std::distance(first, last));
  auto num_devices = basdv::get_num_devices();
  auto num_devices_used = static_cast<unsigned>(std::min(num_chunks, num_devices));

  // running count of launched chunks; for_each_device_impl advances it by reference
  unsigned chunk_index = 0;

  basdv::active_device_guard guard;

  // every device gets a pair of contexts; both start out with an already-completed alt_future
  std::vector<chunk_context> primary(num_devices_used);
  unsigned next_device = 0;
  for (auto& ctx : primary) {
    ctx.device_index = next_device++;
    ctx.num_devices_used = num_devices_used;
    ctx.alt_future = xena::make_ready_future();
  }
  auto secondary = primary;

  // launch one chunk per device, recording its future in the device's secondary context
  for (unsigned device = 0; device != num_devices_used; ++device) {
    auto& launch_ctx = primary[device];
    launch_ctx.chunk_index = chunk_index++;
    const auto chunk = *first;
    ++first;
    basdv::set_device(device);
    secondary[device].alt_future = f(launch_ctx, chunk);
  }

  // start a coroutine per device that keeps launching until the chunks run out
  std::vector<xena::future<>> pending;
  pending.reserve(num_devices_used);
  for (unsigned device = 0; device != num_devices_used; ++device) {
    pending.push_back(for_each_device_impl(&primary[device], &secondary[device], chunk_index,
                                           first, last, f));
  }

  // the overall computation is done once every device's coroutine has finished
  for (auto& device_fut : pending) {
    co_await std::move(device_fut);
  }
}
54128
} // namespace sxt::xendv

sxt/execution/device/for_each.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@
1717
#pragma once
1818

1919
#include <functional>
20+
#include <optional>
2021

21-
#include "sxt/execution/async/future_fwd.h"
22+
#include "sxt/base/device/stream.h"
23+
#include "sxt/execution/async/future.h"
24+
#include "sxt/execution/async/shared_future.h"
25+
#include "sxt/execution/device/chunk_context.h"
2226

2327
namespace sxt::basit {
2428
class index_range;
@@ -44,4 +48,15 @@ concurrent_for_each(basit::index_range_iterator first, basit::index_range_iterat
4448
xena::future<>
4549
concurrent_for_each(basit::index_range rng,
4650
std::function<xena::future<>(const basit::index_range&)> f) noexcept;
51+
52+
//--------------------------------------------------------------------------------------------------
53+
// for_each_device
54+
//--------------------------------------------------------------------------------------------------
55+
/**
56+
* Invoke the function f on the range of chunks provided, splitting the work across available
57+
* devices.
58+
*/
59+
xena::future<> for_each_device(
60+
basit::index_range_iterator first, basit::index_range_iterator last,
61+
std::function<xena::future<>(const chunk_context& ctx, basit::index_range)> f) noexcept;
4762
} // namespace sxt::xendv

sxt/execution/device/for_each.t.cc

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,15 @@
1616
*/
1717
#include "sxt/execution/device/for_each.h"
1818

19+
#include <algorithm>
20+
#include <numeric>
21+
#include <random>
1922
#include <utility>
2023
#include <vector>
2124

25+
#include "sxt/base/error/assert.h"
2226
#include "sxt/base/iterator/index_range.h"
27+
#include "sxt/base/iterator/index_range_iterator.h"
2328
#include "sxt/base/test/unit_test.h"
2429
#include "sxt/execution/async/future.h"
2530

@@ -67,3 +72,118 @@ TEST_CASE("we can concurrently invoke code on different GPUs") {
6772
REQUIRE(t == 11);
6873
}
6974
}
75+
76+
TEST_CASE("we can manage asynchronous chunked computations") {
  // (a, b) bounds of every chunk handed to the functor, in the order the calls were made
  std::vector<std::pair<unsigned, unsigned>> ranges;

  // promises[i] controls when the computation for the chunk with chunk_index i finishes
  std::vector<xena::promise<int>> promises(10);

  SECTION("we iterate over no chunks") {
    // first == last, so the functor is never expected to run and the result is ready at once
    basit::index_range_iterator iter{basit::index_range{2, 2}, 1};
    auto fut = for_each_device(
        iter, iter,
        [&](const chunk_context& /*ctx*/, basit::index_range /*rng*/) -> xena::future<> {
          return xena::future<int>{promises[0]}.then([](int /*val*/) noexcept {});
        });
    REQUIRE(fut.ready());
  }

  SECTION("we can iterate over a single chunk") {
    basit::index_range_iterator first{basit::index_range{0, 1}, 1};
    basit::index_range_iterator last{basit::index_range{1, 1}, 1};
    auto fut = for_each_device(
        first, last, [&](const chunk_context& /*ctx*/, basit::index_range rng) -> xena::future<> {
          ranges.emplace_back(rng.a(), rng.b());
          return xena::future<int>{promises[0]}.then(
              [&](int val) noexcept { SXT_RELEASE_ASSERT(val == 123); });
        });

    // the result only becomes ready once the chunk's promise is fulfilled
    REQUIRE(!fut.ready());
    promises[0].set_value(123);
    REQUIRE(fut.ready());
    std::vector<std::pair<unsigned, unsigned>> expected = {{0, 1}};
    REQUIRE(ranges == expected);
  }

  SECTION("we can iterate over two chunks") {
    basit::index_range_iterator first{basit::index_range{0, 2}, 1};
    basit::index_range_iterator last{basit::index_range{2, 2}, 1};
    auto fut = for_each_device(
        first, last, [&](const chunk_context& ctx, basit::index_range rng) -> xena::future<> {
          ranges.emplace_back(rng.a(), rng.b());
          return xena::future<int>{promises[ctx.chunk_index]}.then(
              [chunk_index = ctx.chunk_index](int val) noexcept {
                if (chunk_index == 0) {
                  SXT_RELEASE_ASSERT(val == 123);
                } else {
                  SXT_RELEASE_ASSERT(val == 456);
                }
              });
        });

    // both chunks have to finish before the result is ready
    REQUIRE(!fut.ready());
    promises[0].set_value(123);
    REQUIRE(!fut.ready());
    promises[1].set_value(456);
    REQUIRE(fut.ready());
    std::vector<std::pair<unsigned, unsigned>> expected = {{0, 1}, {1, 2}};
    REQUIRE(ranges == expected);
  }

  SECTION("we can iterate over different chunk sizes") {
    for (unsigned k = 3; k < 10; ++k) {
      promises.clear();
      ranges.clear();
      promises.resize(k);
      basit::index_range_iterator first{basit::index_range{0, k}, 1};
      basit::index_range_iterator last{basit::index_range{k, k}, 1};
      auto fut = for_each_device(
          first, last, [&](const chunk_context& ctx, basit::index_range rng) -> xena::future<> {
            ranges.emplace_back(rng.a(), rng.b());
            return xena::future<int>{promises[ctx.chunk_index]}.then(
                [chunk_index = ctx.chunk_index](int val) noexcept {
                  SXT_RELEASE_ASSERT(val == chunk_index);
                });
          });

      // fulfil the promises in chunk order; the result must stay pending until the last one
      std::vector<std::pair<unsigned, unsigned>> expected;
      for (unsigned i = 0; i < k; ++i) {
        REQUIRE(!fut.ready());
        promises[i].set_value(i);
        expected.emplace_back(i, i + 1);
      }
      REQUIRE(fut.ready());
      REQUIRE(ranges == expected);
    }
  }

  SECTION("we can iterate over different chunks finished in an arbitrary order") {
    std::mt19937 rng{0};

    for (unsigned k = 3; k < 10; ++k) {
      promises.clear();
      promises.resize(k);

      // finished[i] is flipped by the continuation attached to chunk i
      std::vector<bool> finished(k);
      std::vector<xena::future<int>> futs;
      for (auto& ps : promises) {
        futs.emplace_back(ps);
      }
      basit::index_range_iterator first{basit::index_range{0, k}, 1};
      basit::index_range_iterator last{basit::index_range{k, k}, 1};
      auto fut = for_each_device(
          first, last,
          [&](const chunk_context& ctx, basit::index_range /*rng*/) -> xena::future<> {
            return futs[ctx.chunk_index].then(
                [&finished, chunk_index = ctx.chunk_index](int val) noexcept {
                  finished[chunk_index] = true;
                  SXT_RELEASE_ASSERT(val == chunk_index);
                });
          });

      // fulfil the promises in a shuffled order (fixed seed, so the order is reproducible)
      std::vector<unsigned> ix(k);
      std::iota(ix.begin(), ix.end(), 0);
      std::shuffle(ix.begin(), ix.end(), rng);
      for (auto i : ix) {
        REQUIRE(!fut.ready());
        promises[i].set_value(i);
      }
      REQUIRE(fut.ready());
      REQUIRE(std::count(finished.begin(), finished.end(), true) == k);
    }
  }
}

0 commit comments

Comments
 (0)