Skip to content

Commit 8bd4e0a

Browse files
authored
feat: rework pinned memory management (PROOF-928) (#259)
* rework pinned memory management * rework concepts * drop print statement
1 parent 365cd9a commit 8bd4e0a

File tree

11 files changed

+453
-134
lines changed

11 files changed

+453
-134
lines changed

sxt/base/device/pinned_buffer.cc

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
22
*
3-
* Copyright 2024-present Space and Time Labs, Inc.
3+
* Copyright 2025-present Space and Time Labs, Inc.
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -17,17 +17,20 @@
1717
#include "sxt/base/device/pinned_buffer.h"
1818

1919
#include "sxt/base/device/pinned_buffer_pool.h"
20+
#include "sxt/base/error/assert.h"
2021

2122
namespace sxt::basdv {
2223
//--------------------------------------------------------------------------------------------------
23-
// consructor
24+
// constructor
2425
//--------------------------------------------------------------------------------------------------
25-
pinned_buffer::pinned_buffer() noexcept : handle_{get_pinned_buffer_pool()->acquire_handle()} {}
26-
27-
pinned_buffer::pinned_buffer(pinned_buffer&& ptr) noexcept : handle_{ptr.handle_} {
28-
ptr.handle_ = nullptr;
26+
pinned_buffer::pinned_buffer(size_t size) noexcept
27+
: handle_{get_pinned_buffer_pool()->acquire_handle()}, size_{size} {
28+
SXT_RELEASE_ASSERT(size_ <= this->capacity());
2929
}
3030

31+
pinned_buffer::pinned_buffer(pinned_buffer&& other) noexcept
32+
: handle_{std::exchange(other.handle_, nullptr)}, size_{std::exchange(other.size_, 0)} {}
33+
3134
//--------------------------------------------------------------------------------------------------
3235
// destructor
3336
//--------------------------------------------------------------------------------------------------
@@ -40,17 +43,54 @@ pinned_buffer::~pinned_buffer() noexcept {
4043
//--------------------------------------------------------------------------------------------------
4144
// operator=
4245
//--------------------------------------------------------------------------------------------------
43-
pinned_buffer& pinned_buffer::operator=(pinned_buffer&& ptr) noexcept {
44-
if (handle_ != nullptr) {
45-
get_pinned_buffer_pool()->release_handle(handle_);
46-
}
47-
handle_ = ptr.handle_;
48-
ptr.handle_ = nullptr;
46+
pinned_buffer& pinned_buffer::operator=(pinned_buffer&& other) noexcept {
47+
this->reset();
48+
handle_ = std::exchange(other.handle_, nullptr);
49+
size_ = std::exchange(other.size_, 0);
4950
return *this;
5051
}
5152

5253
//--------------------------------------------------------------------------------------------------
53-
// size
54+
// capacity
5455
//--------------------------------------------------------------------------------------------------
55-
size_t pinned_buffer::size() noexcept { return pinned_buffer_size; }
56+
size_t pinned_buffer::capacity() noexcept { return pinned_buffer_size; }
57+
58+
//--------------------------------------------------------------------------------------------------
59+
// resize
60+
//--------------------------------------------------------------------------------------------------
61+
void pinned_buffer::resize(size_t size) noexcept {
62+
SXT_RELEASE_ASSERT(size <= this->capacity());
63+
if (handle_ == nullptr) {
64+
handle_ = get_pinned_buffer_pool()->acquire_handle();
65+
}
66+
size_ = size;
67+
}
68+
69+
//--------------------------------------------------------------------------------------------------
70+
// fill
71+
//--------------------------------------------------------------------------------------------------
72+
basct::cspan<std::byte> pinned_buffer::fill_from_host(basct::cspan<std::byte> src) noexcept {
73+
if (src.empty()) {
74+
return src;
75+
}
76+
if (handle_ == nullptr) {
77+
handle_ = get_pinned_buffer_pool()->acquire_handle();
78+
}
79+
auto n = std::min(src.size(), this->capacity() - size_);
80+
std::copy_n(src.data(), n, static_cast<std::byte*>(handle_->ptr) + size_);
81+
size_ += n;
82+
return src.subspan(n);
83+
}
84+
85+
//--------------------------------------------------------------------------------------------------
86+
// reset
87+
//--------------------------------------------------------------------------------------------------
88+
void pinned_buffer::reset() noexcept {
89+
if (handle_ == nullptr) {
90+
return;
91+
}
92+
get_pinned_buffer_pool()->release_handle(handle_);
93+
handle_ = nullptr;
94+
size_ = 0;
95+
}
5696
} // namespace sxt::basdv

sxt/base/device/pinned_buffer.h

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
22
*
3-
* Copyright 2024-present Space and Time Labs, Inc.
3+
* Copyright 2025-present Space and Time Labs, Inc.
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -16,6 +16,9 @@
1616
*/
1717
#pragma once
1818

19+
#include <cstddef>
20+
21+
#include "sxt/base/container/span.h"
1922
#include "sxt/base/device/pinned_buffer_handle.h"
2023

2124
namespace sxt::basdv {
@@ -24,26 +27,48 @@ namespace sxt::basdv {
2427
//--------------------------------------------------------------------------------------------------
2528
class pinned_buffer {
2629
public:
27-
pinned_buffer() noexcept;
28-
pinned_buffer(pinned_buffer&& ptr) noexcept;
30+
pinned_buffer() noexcept = default;
31+
32+
explicit pinned_buffer(size_t size) noexcept;
33+
2934
pinned_buffer(const pinned_buffer&) noexcept = delete;
35+
pinned_buffer(pinned_buffer&& other) noexcept;
3036

3137
~pinned_buffer() noexcept;
3238

33-
pinned_buffer& operator=(pinned_buffer&& ptr) noexcept;
34-
pinned_buffer& operator=(const pinned_buffer& ptr) noexcept = delete;
39+
pinned_buffer& operator=(const pinned_buffer&) noexcept = delete;
40+
pinned_buffer& operator=(pinned_buffer&& other) noexcept;
41+
42+
bool empty() const noexcept { return size_ == 0; }
43+
44+
bool full() const noexcept { return size_ == this->capacity(); }
45+
46+
size_t size() const noexcept { return size_; }
47+
48+
static size_t capacity() noexcept;
3549

36-
static size_t size() noexcept;
50+
void* data() noexcept {
51+
if (handle_ == nullptr) {
52+
return nullptr;
53+
}
54+
return handle_->ptr;
55+
}
3756

38-
void* data() noexcept { return handle_->ptr; }
57+
const void* data() const noexcept {
58+
if (handle_ == nullptr) {
59+
return nullptr;
60+
}
61+
return handle_->ptr;
62+
}
3963

40-
const void* data() const noexcept { return handle_->ptr; }
64+
void resize(size_t size) noexcept;
4165

42-
operator void*() noexcept { return handle_->ptr; }
66+
basct::cspan<std::byte> fill_from_host(basct::cspan<std::byte> src) noexcept;
4367

44-
operator const void*() const noexcept { return handle_->ptr; }
68+
void reset() noexcept;
4569

4670
private:
47-
pinned_buffer_handle* handle_;
71+
pinned_buffer_handle* handle_ = nullptr;
72+
size_t size_ = 0;
4873
};
4974
} // namespace sxt::basdv

sxt/base/device/pinned_buffer.t.cc

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
22
*
3-
* Copyright 2024-present Space and Time Labs, Inc.
3+
* Copyright 2025-present Space and Time Labs, Inc.
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
66
* you may not use this file except in compliance with the License.
@@ -16,40 +16,64 @@
1616
*/
1717
#include "sxt/base/device/pinned_buffer.h"
1818

19-
#include "sxt/base/device/pinned_buffer_pool.h"
2019
#include "sxt/base/test/unit_test.h"
2120

2221
using namespace sxt;
2322
using namespace sxt::basdv;
2423

25-
TEST_CASE("we can manage pinned buffers") {
26-
auto num_buffers = 5u;
27-
auto pool = get_pinned_buffer_pool(num_buffers);
28-
29-
SECTION("we can acquire and release a pinned buffer") {
30-
{
31-
pinned_buffer buf;
32-
REQUIRE(pool->size() == num_buffers - 1);
33-
*reinterpret_cast<char*>(buf.data()) = 1u;
34-
*(reinterpret_cast<char*>(buf.data()) + buf.size() - 1) = 2u;
35-
}
36-
REQUIRE(pool->size() == num_buffers);
24+
TEST_CASE("we can manage a buffer of pinned memory") {
25+
SECTION("we can construct and deconstruct a buffer") {
26+
pinned_buffer buf;
27+
REQUIRE(buf.size() == 0);
28+
REQUIRE(buf.empty());
29+
}
30+
31+
SECTION("we can add a single byte to a buffer") {
32+
pinned_buffer buf;
33+
std::vector<std::byte> data = {std::byte{123}};
34+
auto rest = buf.fill_from_host(data);
35+
REQUIRE(rest.empty());
36+
REQUIRE(buf.size() == 1);
37+
REQUIRE(*static_cast<std::byte*>(buf.data()) == data[0]);
38+
}
39+
40+
SECTION("we can reset a buffer") {
41+
pinned_buffer buf;
42+
std::vector<std::byte> data = {std::byte{123}};
43+
buf.fill_from_host(data);
44+
buf.reset();
45+
REQUIRE(buf.empty());
3746
}
3847

3948
SECTION("we can move construct a buffer") {
40-
pinned_buffer buf1;
41-
auto ptr = buf1.data();
42-
pinned_buffer buf{std::move(buf1)};
43-
REQUIRE(buf.data() == ptr);
44-
REQUIRE(pool->size() == num_buffers - 1);
49+
pinned_buffer buf;
50+
std::vector<std::byte> data = {static_cast<std::byte>(123)};
51+
buf.fill_from_host(data);
52+
pinned_buffer buf_p{std::move(buf)};
53+
REQUIRE(buf.empty());
54+
REQUIRE(buf_p.size() == 1);
55+
REQUIRE(*static_cast<std::byte*>(buf_p.data()) == data[0]);
56+
}
57+
58+
SECTION("we can move assign a buffer") {
59+
pinned_buffer buf;
60+
std::vector<std::byte> data = {std::byte{123}};
61+
buf.fill_from_host(data);
62+
63+
pinned_buffer buf_p;
64+
data[0] = std::byte{3};
65+
buf_p.fill_from_host(data);
66+
buf_p = std::move(buf);
67+
REQUIRE(buf.empty());
68+
REQUIRE(*static_cast<std::byte*>(buf_p.data()) == std::byte{123});
4569
}
4670

47-
SECTION("we can move-assign a buffer") {
48-
pinned_buffer buf1;
49-
auto ptr = buf1.data();
71+
SECTION("we can fill a buffer") {
5072
pinned_buffer buf;
51-
buf = std::move(buf1);
52-
REQUIRE(buf.data() == ptr);
53-
REQUIRE(pool->size() == num_buffers - 1);
73+
std::vector<std::byte> data(buf.capacity() + 1, std::byte{123});
74+
auto rest = buf.fill_from_host(data);
75+
REQUIRE(rest.size() == 1);
76+
REQUIRE(buf.size() == buf.capacity());
77+
REQUIRE(buf.full());
5478
}
5579
}

sxt/execution/device/BUILD

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,31 @@ sxt_cc_component(
1111
],
1212
)
1313

14+
sxt_cc_component(
15+
name = "to_device_copier",
16+
impl_deps = [
17+
":synchronization",
18+
"//sxt/base/device:memory_utility",
19+
"//sxt/base/functional:function_ref",
20+
"//sxt/execution/async:coroutine",
21+
],
22+
test_deps = [
23+
"//sxt/base/device:pinned_buffer_pool",
24+
"//sxt/base/device:synchronization",
25+
"//sxt/base/test:unit_test",
26+
"//sxt/execution/schedule:scheduler",
27+
"//sxt/memory/management:managed_array",
28+
"//sxt/memory/resource:managed_device_resource",
29+
],
30+
deps = [
31+
"//sxt/base/container:span",
32+
"//sxt/base/device:pinned_buffer",
33+
"//sxt/base/device:stream",
34+
"//sxt/base/error:assert",
35+
"//sxt/execution/async:future",
36+
],
37+
)
38+
1439
sxt_cc_component(
1540
name = "device_viewable",
1641
test_deps = [
@@ -192,10 +217,7 @@ sxt_cc_component(
192217

193218
sxt_cc_component(
194219
name = "generate",
195-
impl_deps = [
196-
],
197220
test_deps = [
198-
"//sxt/base/device:pinned_buffer",
199221
"//sxt/base/device:stream",
200222
"//sxt/base/device:synchronization",
201223
"//sxt/base/test:unit_test",
@@ -217,7 +239,7 @@ sxt_cc_component(
217239
sxt_cc_component(
218240
name = "copy",
219241
impl_deps = [
220-
":generate",
242+
":to_device_copier",
221243
"//sxt/base/device:memory_utility",
222244
"//sxt/base/device:stream",
223245
"//sxt/execution/async:coroutine",

0 commit comments

Comments
 (0)