Skip to content

Commit 1d17c3b

Browse files
authored
Migrate Python/Cython bindings from device_memory_resource to device_async_resource_ref (#2300)
## Summary Replaces `shared_ptr[device_memory_resource]` with per-subclass `unique_ptr[ConcreteType]` (owning) and `optional[device_async_resource_ref]` (non-owning reference) across all Python/Cython bindings. This is a part of #2011. There are **significant** opportunities to make this Cython code better over time but I have to get something that removes `device_memory_resource` from the Python/Cython side before I can finish migration on the C++ side (#2296). I welcome critique of this design, and ideas for how it can be improved, particularly from @vyasr @wence-. I would like to address any suggested improvements in follow-up PRs, because this changeset is necessary to unblock #2301. The changes in `cdef class DeviceMemoryResource` are perhaps the most significant changes here from a design perspective. The solution I'm going with for now is to keep the `DeviceMemoryResource` class around, as a base class for the Cython MRs, and let it handle allocate/deallocate. It owns a `optional[device_async_resource_ref]` which is used for allocation/deallocation. It's `optional` so that the class can be default-constructed (Cython requires nullary constructors), but it should never be `nullopt` except during initialization. Then, each MR class owns a `c_obj` like `unique_ptr[cuda_memory_resource]`. This is `unique_ptr` so it can be default-constructed for Cython's requirements. I chose `unique_ptr` over `optional` here to emphasize that this member is the thing that actually owns the resource. As with the `c_ref`, this should never be `nullptr` except during initialization. When an MR class is created, it initializes its `c_obj` and then constructs a `c_ref` (a member inherited from the `DeviceMemoryResource` base class). "Special" methods for an MR like getting the statistics counts go through `deref(self.c_obj)`, and "common" methods like allocate/deallocate go through `self.c_ref.value()`. ### Changes - **`.pxd` declarations**: Remove `device_memory_resource` class. Declare `device_async_resource_ref` and a `make_device_async_resource_ref()` inline C++ template that returns `optional` to work around Cython generating default-constructed temporaries for non-default-constructible types. All adaptor constructors take `device_async_resource_ref` instead of `device_memory_resource*`. - **`.pxd` class definitions**: `DeviceMemoryResource` base holds `optional[device_async_resource_ref] c_ref`; each concrete subclass holds `unique_ptr[ConcreteType] c_obj`. - **`.pyx` implementations**: All `__cinit__` methods construct via `unique_ptr` then set `c_ref` via `make_device_async_resource_ref`. Typed accessors (`pool_size`, `flush`, etc.) use `deref(self.c_obj)`. Per-device functions use `set_per_device_resource_ref`. - **`device_buffer.pyx`**: Passes `self.mr.c_ref.value()` instead of `self.mr.get_mr()`. Closes #2294
1 parent e4f4106 commit 1d17c3b

File tree

9 files changed

+308
-276
lines changed

9 files changed

+308
-276
lines changed

python/rmm/rmm/librmm/device_buffer.pxd

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
1+
# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: Apache-2.0
33

44
from rmm.librmm.cuda_stream_view cimport cuda_stream_view
5-
from rmm.librmm.memory_resource cimport device_memory_resource
5+
from rmm.librmm.memory_resource cimport device_async_resource_ref
66

77

88
cdef extern from "rmm/mr/per_device_resource.hpp" namespace "rmm" nogil:
@@ -26,18 +26,18 @@ cdef extern from "rmm/device_buffer.hpp" namespace "rmm" nogil:
2626
device_buffer(
2727
size_t size,
2828
cuda_stream_view stream,
29-
device_memory_resource *
29+
device_async_resource_ref mr
3030
) except +
3131
device_buffer(
3232
const void* source_data,
3333
size_t size,
3434
cuda_stream_view stream,
35-
device_memory_resource *
35+
device_async_resource_ref mr
3636
) except +
3737
device_buffer(
3838
const device_buffer buf,
3939
cuda_stream_view stream,
40-
device_memory_resource *
40+
device_async_resource_ref mr
4141
) except +
4242
void reserve(size_t new_capacity, cuda_stream_view stream) except +
4343
void resize(size_t new_size, cuda_stream_view stream) except +

python/rmm/rmm/librmm/device_uvector.pxd

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION.
1+
# SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: Apache-2.0
33

44
from rmm.librmm.cuda_stream_view cimport cuda_stream_view
55
from rmm.librmm.device_buffer cimport device_buffer
6-
from rmm.librmm.memory_resource cimport device_memory_resource
76

87

98
cdef extern from "rmm/device_buffer.hpp" namespace "rmm" nogil:
@@ -25,4 +24,3 @@ cdef extern from "rmm/device_buffer.hpp" namespace "rmm" nogil:
2524
size_t capacity()
2625
T* data()
2726
size_t size()
28-
device_memory_resource* memory_resource()

python/rmm/rmm/librmm/memory_resource.pxd

Lines changed: 106 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -14,35 +14,72 @@ from libcpp.pair cimport pair
1414
from libcpp.string cimport string
1515

1616
from rmm.librmm.cuda_stream_view cimport cuda_stream_view
17-
from rmm.librmm.memory_resource cimport device_memory_resource
1817

1918

20-
cdef extern from "rmm/mr/device_memory_resource.hpp" \
21-
namespace "rmm::mr" nogil:
22-
cdef cppclass device_memory_resource:
23-
# Legacy functions
24-
void* allocate(size_t bytes) except +
25-
void* allocate(size_t bytes, cuda_stream_view stream) except +
26-
void deallocate(void* ptr, size_t bytes) noexcept
27-
void deallocate(
28-
void* ptr,
29-
size_t bytes,
30-
cuda_stream_view stream
31-
) noexcept
32-
# End legacy functions
33-
34-
void* allocate_sync(size_t bytes) except +
35-
void deallocate_sync(void* ptr, size_t bytes) noexcept
36-
void* allocate(
37-
cuda_stream_view stream,
38-
size_t bytes
39-
) except +
19+
cdef extern from "rmm/resource_ref.hpp" namespace "rmm" nogil:
20+
cdef cppclass device_async_resource_ref:
21+
void* allocate(cuda_stream_view stream, size_t bytes) except +
4022
void deallocate(
4123
cuda_stream_view stream,
4224
void* ptr,
4325
size_t bytes
4426
) noexcept
4527

28+
29+
# Inline C++ helper to construct optional[device_async_resource_ref] from any
30+
# concrete resource type. Returns optional so that Cython assignment
31+
# (self.c_ref = make_device_async_resource_ref(...)) uses optional's
32+
# default-constructible temporary instead of device_async_resource_ref's
33+
# non-default-constructible one.
34+
cdef extern from *:
35+
"""
36+
#include <optional>
37+
#include <rmm/resource_ref.hpp>
38+
template <typename T>
39+
std::optional<rmm::device_async_resource_ref>
40+
make_device_async_resource_ref(T& r) {
41+
return std::optional<rmm::device_async_resource_ref>(
42+
rmm::device_async_resource_ref(r));
43+
}
44+
"""
45+
optional[device_async_resource_ref] make_device_async_resource_ref(
46+
cuda_memory_resource&) except +
47+
optional[device_async_resource_ref] make_device_async_resource_ref(
48+
managed_memory_resource&) except +
49+
optional[device_async_resource_ref] make_device_async_resource_ref(
50+
system_memory_resource&) except +
51+
optional[device_async_resource_ref] make_device_async_resource_ref(
52+
pinned_host_memory_resource&) except +
53+
optional[device_async_resource_ref] make_device_async_resource_ref(
54+
sam_headroom_memory_resource&) except +
55+
optional[device_async_resource_ref] make_device_async_resource_ref(
56+
cuda_async_memory_resource&) except +
57+
optional[device_async_resource_ref] make_device_async_resource_ref(
58+
cuda_async_view_memory_resource&) except +
59+
optional[device_async_resource_ref] make_device_async_resource_ref(
60+
cuda_async_managed_memory_resource&) except +
61+
optional[device_async_resource_ref] make_device_async_resource_ref(
62+
pool_memory_resource&) except +
63+
optional[device_async_resource_ref] make_device_async_resource_ref(
64+
arena_memory_resource&) except +
65+
optional[device_async_resource_ref] make_device_async_resource_ref(
66+
fixed_size_memory_resource&) except +
67+
optional[device_async_resource_ref] make_device_async_resource_ref(
68+
binning_memory_resource&) except +
69+
optional[device_async_resource_ref] make_device_async_resource_ref(
70+
callback_memory_resource&) except +
71+
optional[device_async_resource_ref] make_device_async_resource_ref(
72+
limiting_resource_adaptor&) except +
73+
optional[device_async_resource_ref] make_device_async_resource_ref(
74+
logging_resource_adaptor&) except +
75+
optional[device_async_resource_ref] make_device_async_resource_ref(
76+
statistics_resource_adaptor&) except +
77+
optional[device_async_resource_ref] make_device_async_resource_ref(
78+
tracking_resource_adaptor&) except +
79+
optional[device_async_resource_ref] make_device_async_resource_ref(
80+
prefetch_resource_adaptor&) except +
81+
82+
4683
cdef extern from "rmm/cuda_device.hpp" namespace "rmm" nogil:
4784
size_t percent_of_free_device_memory(int percent) except +
4885
pair[size_t, size_t] available_device_memory() except +
@@ -87,33 +124,33 @@ cdef extern from *:
87124

88125
cdef extern from "rmm/mr/cuda_memory_resource.hpp" \
89126
namespace "rmm::mr" nogil:
90-
cdef cppclass cuda_memory_resource(device_memory_resource):
127+
cdef cppclass cuda_memory_resource:
91128
cuda_memory_resource() except +
92129

93130
cdef extern from "rmm/mr/managed_memory_resource.hpp" \
94131
namespace "rmm::mr" nogil:
95-
cdef cppclass managed_memory_resource(device_memory_resource):
132+
cdef cppclass managed_memory_resource:
96133
managed_memory_resource() except +
97134

98135
cdef extern from "rmm/mr/system_memory_resource.hpp" \
99136
namespace "rmm::mr" nogil:
100-
cdef cppclass system_memory_resource(device_memory_resource):
137+
cdef cppclass system_memory_resource:
101138
system_memory_resource() except +
102139

103140
cdef extern from "rmm/mr/pinned_host_memory_resource.hpp" \
104141
namespace "rmm::mr" nogil:
105-
cdef cppclass pinned_host_memory_resource(device_memory_resource):
142+
cdef cppclass pinned_host_memory_resource:
106143
pinned_host_memory_resource() except +
107144

108145
cdef extern from "rmm/mr/sam_headroom_memory_resource.hpp" \
109146
namespace "rmm::mr" nogil:
110-
cdef cppclass sam_headroom_memory_resource(device_memory_resource):
147+
cdef cppclass sam_headroom_memory_resource:
111148
sam_headroom_memory_resource(size_t headroom) except +
112149

113150
cdef extern from "rmm/mr/cuda_async_memory_resource.hpp" \
114151
namespace "rmm::mr" nogil:
115152

116-
cdef cppclass cuda_async_memory_resource(device_memory_resource):
153+
cdef cppclass cuda_async_memory_resource:
117154
cuda_async_memory_resource(
118155
optional[size_t] initial_pool_size,
119156
optional[size_t] release_threshold,
@@ -122,15 +159,15 @@ cdef extern from "rmm/mr/cuda_async_memory_resource.hpp" \
122159
cdef extern from "rmm/mr/cuda_async_view_memory_resource.hpp" \
123160
namespace "rmm::mr" nogil:
124161

125-
cdef cppclass cuda_async_view_memory_resource(device_memory_resource):
162+
cdef cppclass cuda_async_view_memory_resource:
126163
cuda_async_view_memory_resource(
127164
cudaMemPool_t pool_handle) except +
128165
cudaMemPool_t pool_handle() const
129166

130167
cdef extern from "rmm/mr/cuda_async_managed_memory_resource.hpp" \
131168
namespace "rmm::mr" nogil:
132169

133-
cdef cppclass cuda_async_managed_memory_resource(device_memory_resource):
170+
cdef cppclass cuda_async_managed_memory_resource:
134171
cuda_async_managed_memory_resource() except +
135172
cudaMemPool_t pool_handle() const
136173

@@ -148,27 +185,27 @@ cdef extern from "rmm/mr/cuda_async_memory_resource.hpp" \
148185

149186
cdef extern from "rmm/mr/pool_memory_resource.hpp" \
150187
namespace "rmm::mr" nogil:
151-
cdef cppclass pool_memory_resource(device_memory_resource):
188+
cdef cppclass pool_memory_resource:
152189
pool_memory_resource(
153-
device_memory_resource* upstream_mr,
190+
device_async_resource_ref upstream_mr,
154191
size_t initial_pool_size,
155192
optional[size_t] maximum_pool_size) except +
156193
size_t pool_size()
157194

158195
cdef extern from "rmm/mr/arena_memory_resource.hpp" \
159196
namespace "rmm::mr" nogil:
160-
cdef cppclass arena_memory_resource(device_memory_resource):
197+
cdef cppclass arena_memory_resource:
161198
arena_memory_resource(
162-
device_memory_resource* upstream_mr,
199+
device_async_resource_ref upstream_mr,
163200
optional[size_t] arena_size,
164201
bool dump_log_on_failure
165202
) except +
166203

167204
cdef extern from "rmm/mr/fixed_size_memory_resource.hpp" \
168205
namespace "rmm::mr" nogil:
169-
cdef cppclass fixed_size_memory_resource(device_memory_resource):
206+
cdef cppclass fixed_size_memory_resource:
170207
fixed_size_memory_resource(
171-
device_memory_resource* upstream_mr,
208+
device_async_resource_ref upstream_mr,
172209
size_t block_size,
173210
size_t block_to_preallocate) except +
174211

@@ -177,7 +214,7 @@ cdef extern from "rmm/mr/callback_memory_resource.hpp" \
177214
ctypedef void* (*allocate_callback_t)(size_t, cuda_stream_view, void*)
178215
ctypedef void (*deallocate_callback_t)(void*, size_t, cuda_stream_view, void*)
179216

180-
cdef cppclass callback_memory_resource(device_memory_resource):
217+
cdef cppclass callback_memory_resource:
181218
callback_memory_resource(
182219
allocate_callback_t allocate_callback,
183220
deallocate_callback_t deallocate_callback,
@@ -187,48 +224,50 @@ cdef extern from "rmm/mr/callback_memory_resource.hpp" \
187224

188225
cdef extern from "rmm/mr/binning_memory_resource.hpp" \
189226
namespace "rmm::mr" nogil:
190-
cdef cppclass binning_memory_resource(device_memory_resource):
191-
binning_memory_resource(device_memory_resource* upstream_mr) except +
227+
cdef cppclass binning_memory_resource:
228+
binning_memory_resource(
229+
device_async_resource_ref upstream_mr) except +
192230
binning_memory_resource(
193-
device_memory_resource* upstream_mr,
231+
device_async_resource_ref upstream_mr,
194232
int8_t min_size_exponent,
195233
int8_t max_size_exponent) except +
196234

197-
void add_bin(size_t allocation_size) except +
198235
void add_bin(
199236
size_t allocation_size,
200-
device_memory_resource* bin_resource) except +
237+
optional[device_async_resource_ref] bin_resource
238+
) except +
201239

202240
cdef extern from "rmm/mr/limiting_resource_adaptor.hpp" \
203241
namespace "rmm::mr" nogil:
204-
cdef cppclass limiting_resource_adaptor(device_memory_resource):
242+
cdef cppclass limiting_resource_adaptor:
205243
limiting_resource_adaptor(
206-
device_memory_resource* upstream_mr,
244+
device_async_resource_ref upstream_mr,
207245
size_t allocation_limit) except +
208246

209247
size_t get_allocated_bytes() except +
210248
size_t get_allocation_limit() except +
211249

212250
cdef extern from "rmm/mr/logging_resource_adaptor.hpp" \
213251
namespace "rmm::mr" nogil:
214-
cdef cppclass logging_resource_adaptor(device_memory_resource):
252+
cdef cppclass logging_resource_adaptor:
215253
logging_resource_adaptor(
216-
device_memory_resource* upstream_mr,
254+
device_async_resource_ref upstream_mr,
217255
string filename) except +
218256

219257
void flush() except +
220258

221259
cdef extern from "rmm/mr/statistics_resource_adaptor.hpp" \
222260
namespace "rmm::mr" nogil:
223-
cdef cppclass statistics_resource_adaptor(device_memory_resource):
261+
cdef cppclass statistics_resource_adaptor:
224262
struct counter:
225263
counter()
226264

227265
int64_t value
228266
int64_t peak
229267
int64_t total
230268

231-
statistics_resource_adaptor(device_memory_resource* upstream_mr) except +
269+
statistics_resource_adaptor(
270+
device_async_resource_ref upstream_mr) except +
232271

233272
counter get_bytes_counter() except +
234273
counter get_allocations_counter() except +
@@ -237,9 +276,9 @@ cdef extern from "rmm/mr/statistics_resource_adaptor.hpp" \
237276

238277
cdef extern from "rmm/mr/tracking_resource_adaptor.hpp" \
239278
namespace "rmm::mr" nogil:
240-
cdef cppclass tracking_resource_adaptor(device_memory_resource):
279+
cdef cppclass tracking_resource_adaptor:
241280
tracking_resource_adaptor(
242-
device_memory_resource* upstream_mr,
281+
device_async_resource_ref upstream_mr,
243282
bool capture_stacks) except +
244283

245284
size_t get_allocated_bytes() except +
@@ -253,16 +292,28 @@ cdef extern from "rmm/error.hpp" namespace "rmm" nogil:
253292
cdef extern from "rmm/mr/failure_callback_resource_adaptor.hpp" \
254293
namespace "rmm::mr" nogil:
255294
ctypedef bool (*failure_callback_t)(size_t, void*)
256-
cdef cppclass failure_callback_resource_adaptor[ExceptionType](
257-
device_memory_resource
258-
):
295+
cdef cppclass failure_callback_resource_adaptor[ExceptionType]:
259296
failure_callback_resource_adaptor(
260-
device_memory_resource* upstream_mr,
297+
device_async_resource_ref upstream_mr,
261298
failure_callback_t callback,
262299
void* callback_arg
263300
) except +
264301

302+
ctypedef failure_callback_resource_adaptor[out_of_memory] \
303+
failure_callback_resource_adaptor_oom
304+
305+
# The make_device_async_resource_ref template (declared above) also covers
306+
# failure_callback_resource_adaptor_oom; just declare the overload here
307+
# since the typedef is only available after the class is declared.
308+
cdef extern from *:
309+
"""
310+
// already defined above via template
311+
"""
312+
optional[device_async_resource_ref] make_device_async_resource_ref(
313+
failure_callback_resource_adaptor_oom&) except +
314+
265315
cdef extern from "rmm/mr/prefetch_resource_adaptor.hpp" \
266316
namespace "rmm::mr" nogil:
267-
cdef cppclass prefetch_resource_adaptor(device_memory_resource):
268-
prefetch_resource_adaptor(device_memory_resource* upstream_mr) except +
317+
cdef cppclass prefetch_resource_adaptor:
318+
prefetch_resource_adaptor(
319+
device_async_resource_ref upstream_mr) except +
Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
1+
# SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: Apache-2.0
3-
from rmm.librmm.memory_resource cimport device_memory_resource
3+
from rmm.librmm.memory_resource cimport device_async_resource_ref
44

55

66
cdef extern from "rmm/mr/per_device_resource.hpp" namespace "rmm" nogil:
@@ -13,13 +13,9 @@ cdef extern from "rmm/mr/per_device_resource.hpp" namespace "rmm" nogil:
1313

1414
cdef extern from "rmm/mr/per_device_resource.hpp" \
1515
namespace "rmm::mr" nogil:
16-
cdef device_memory_resource* set_current_device_resource(
17-
device_memory_resource* new_mr
16+
cdef void set_current_device_resource_ref(
17+
device_async_resource_ref new_mr
1818
)
19-
cdef device_memory_resource* get_current_device_resource()
20-
cdef device_memory_resource* set_per_device_resource(
21-
cuda_device_id id, device_memory_resource* new_mr
22-
)
23-
cdef device_memory_resource* get_per_device_resource (
24-
cuda_device_id id
19+
cdef void set_per_device_resource_ref(
20+
cuda_device_id id, device_async_resource_ref new_mr
2521
)

0 commit comments

Comments
 (0)