Skip to content

Commit 856395e

Browse files
committed
Remove device_memory_resource inheritance from all resources and adaptors
Remove the device_memory_resource virtual base class inheritance from all production memory resources, adaptors, and stream_ordered_memory_resource. Resources now derive publicly from cuda::mr::shared_resource<Impl> (for stateful/adaptor types) or stand alone with direct CCCL concept methods (for stateless types). The legacy do_allocate/do_deallocate/do_is_equal virtual overrides and pointer-based per-device-resource APIs are removed. stream_ordered_memory_resource provides allocate/deallocate/allocate_sync/deallocate_sync directly instead of through the DMR virtual dispatch. All 103 C++ tests and 1165 Python tests pass.
1 parent 20f2b6a commit 856395e

File tree

65 files changed

+354
-1887
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

65 files changed

+354
-1887
lines changed

cpp/include/rmm/cuda_stream.hpp

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2020-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -98,6 +98,13 @@ class cuda_stream {
9898
*/
9999
operator cuda_stream_view() const;
100100

101+
/**
102+
* @brief Implicit conversion to cuda::stream_ref
103+
*
104+
* @return A stream_ref of the owned stream
105+
*/
106+
operator cuda::stream_ref() const;
107+
101108
/**
102109
* @brief Synchronize the owned CUDA stream.
103110
*

cpp/include/rmm/detail/cccl_adaptors.hpp

Lines changed: 71 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,34 @@ inline constexpr bool is_specialization_of_v = false;
3232
template <template <class...> class Template, class... Args>
3333
inline constexpr bool is_specialization_of_v<Template<Args...>, Template> = true;
3434

35+
/**
36+
* @brief For a type that publicly derives from shared_resource<Impl>, extracts the
37+
* shared_resource<Impl> base type and provides a reference cast.
38+
*
39+
* CCCL's basic_any-based resource_ref can type-erase shared_resource<Impl> directly but not
40+
* classes that inherit from it. This helper extracts the base and casts to it.
41+
*/
42+
template <typename T, typename = void>
43+
struct shared_resource_cast {
44+
static constexpr bool value = false;
45+
};
46+
47+
template <typename T>
48+
struct shared_resource_cast<
49+
T,
50+
std::void_t<
51+
decltype(std::declval<std::remove_cv_t<T>&>().get()),
52+
std::enable_if_t<std::is_base_of_v<cuda::mr::shared_resource<std::remove_reference_t<
53+
decltype(std::declval<T&>().get())>>,
54+
std::remove_cv_t<T>> and
55+
not is_specialization_of_v<std::remove_cv_t<T>, cuda::mr::shared_resource>>>> {
56+
static constexpr bool value = true;
57+
using impl_type = std::remove_reference_t<decltype(std::declval<T&>().get())>;
58+
using base_type = cuda::mr::shared_resource<impl_type>;
59+
60+
static base_type& cast(T& ref) noexcept { return static_cast<base_type&>(ref); }
61+
};
62+
3563
// Forward declarations for use in enable_if constraints
3664
template <typename ResourceType>
3765
class cccl_resource_ref;
@@ -146,14 +174,32 @@ class cccl_resource_ref {
146174
{
147175
}
148176

177+
/**
178+
* @brief Construct a ref from a shared_resource-derived type.
179+
*
180+
* CCCL's basic_any-based resource_ref can type-erase shared_resource<T> directly but not
181+
* types that publicly inherit from it. This constructor casts to the shared_resource base.
182+
*
183+
* @tparam OtherResourceType A type that publicly derives from shared_resource<Impl>
184+
* @param other The shared_resource-derived resource to construct a ref from
185+
*/
186+
template <typename OtherResourceType,
187+
std::enable_if_t<shared_resource_cast<OtherResourceType>::value>* = nullptr>
188+
cccl_resource_ref(OtherResourceType& other)
189+
: view_{}, ref_{ResourceType{shared_resource_cast<OtherResourceType>::cast(other)}}
190+
{
191+
}
192+
149193
/**
150194
* @brief Construct a ref from a resource.
151195
*
152196
* This constructor accepts CCCL resource types but NOT CCCL resource_ref types,
153-
* our own wrapper types, or device_memory_resource derived types. The exclusions
154-
* are checked FIRST to prevent recursive constraint satisfaction.
197+
* our own wrapper types, device_memory_resource derived types, or
198+
* shared_resource-derived types (handled by dedicated constructor above).
199+
* The exclusions are checked FIRST to prevent recursive constraint satisfaction.
155200
*
156-
* @tparam OtherResourceType A CCCL resource type (not a resource_ref, wrapper, or DMR)
201+
* @tparam OtherResourceType A CCCL resource type (not a resource_ref, wrapper, DMR,
202+
* or shared_resource)
157203
* @param other The resource to construct a ref from
158204
*/
159205
template <typename OtherResourceType,
@@ -165,6 +211,7 @@ class cccl_resource_ref {
165211
::rmm::detail::cccl_resource_ref> and
166212
not is_specialization_of_v<std::remove_cv_t<OtherResourceType>,
167213
::rmm::detail::cccl_async_resource_ref> and
214+
not shared_resource_cast<OtherResourceType>::value and
168215
not std::is_base_of_v<rmm::mr::device_memory_resource,
169216
std::remove_cv_t<OtherResourceType>> and
170217
cuda::mr::synchronous_resource<OtherResourceType>>* = nullptr>
@@ -400,15 +447,32 @@ class cccl_async_resource_ref {
400447
{
401448
}
402449

450+
/**
451+
* @brief Construct a ref from a shared_resource-derived type.
452+
*
453+
* CCCL's basic_any-based resource_ref can type-erase shared_resource<T> directly but not
454+
* types that publicly inherit from it. This constructor casts to the shared_resource base.
455+
*
456+
* @tparam OtherResourceType A type that publicly derives from shared_resource<Impl>
457+
* @param other The shared_resource-derived resource to construct a ref from
458+
*/
459+
template <typename OtherResourceType,
460+
std::enable_if_t<shared_resource_cast<OtherResourceType>::value>* = nullptr>
461+
cccl_async_resource_ref(OtherResourceType& other)
462+
: view_{}, ref_{ResourceType{shared_resource_cast<OtherResourceType>::cast(other)}}
463+
{
464+
}
465+
403466
/**
404467
* @brief Construct a ref from a resource.
405468
*
406469
* This constructor accepts CCCL resource types but NOT CCCL resource_ref types,
407-
* our own wrapper types, any_resource types, or device_memory_resource derived types.
470+
* our own wrapper types, any_resource types, device_memory_resource derived types,
471+
* or shared_resource-derived types (handled by dedicated constructor above).
408472
* The exclusions are checked FIRST to prevent recursive constraint satisfaction.
409473
*
410-
* @tparam OtherResourceType A CCCL resource type (not a resource_ref, wrapper, any_resource, or
411-
* DMR)
474+
* @tparam OtherResourceType A CCCL resource type (not a resource_ref, wrapper, any_resource,
475+
* DMR, or shared_resource)
412476
* @param other The resource to construct a ref from
413477
*/
414478
template <
@@ -422,6 +486,7 @@ class cccl_async_resource_ref {
422486
::rmm::detail::cccl_resource_ref> and
423487
not is_specialization_of_v<std::remove_cv_t<OtherResourceType>,
424488
::rmm::detail::cccl_async_resource_ref> and
489+
not shared_resource_cast<OtherResourceType>::value and
425490
not std::is_base_of_v<rmm::mr::device_memory_resource,
426491
std::remove_cv_t<OtherResourceType>> and
427492
cuda::mr::resource<OtherResourceType>>* = nullptr>

cpp/include/rmm/mr/aligned_resource_adaptor.hpp

Lines changed: 1 addition & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -5,10 +5,8 @@
55
#pragma once
66

77
#include <rmm/aligned.hpp>
8-
#include <rmm/cuda_stream_view.hpp>
98
#include <rmm/detail/export.hpp>
109
#include <rmm/mr/detail/aligned_resource_adaptor_impl.hpp>
11-
#include <rmm/mr/device_memory_resource.hpp>
1210
#include <rmm/resource_ref.hpp>
1311

1412
#include <cuda/memory_resource>
@@ -33,40 +31,10 @@ namespace mr {
3331
* `cuda::mr::shared_resource`.
3432
*/
3533
class RMM_EXPORT aligned_resource_adaptor
36-
: public device_memory_resource,
37-
private cuda::mr::shared_resource<detail::aligned_resource_adaptor_impl> {
34+
: public cuda::mr::shared_resource<detail::aligned_resource_adaptor_impl> {
3835
using shared_base = cuda::mr::shared_resource<detail::aligned_resource_adaptor_impl>;
3936

4037
public:
41-
// Begin legacy device_memory_resource compatibility layer
42-
using device_memory_resource::allocate;
43-
using device_memory_resource::allocate_sync;
44-
using device_memory_resource::deallocate;
45-
using device_memory_resource::deallocate_sync;
46-
47-
/**
48-
* @brief Compare two adaptors for equality (shared-impl identity).
49-
*
50-
* @param other The other adaptor to compare against.
51-
* @return true if both adaptors share the same underlying impl.
52-
*/
53-
[[nodiscard]] bool operator==(aligned_resource_adaptor const& other) const noexcept
54-
{
55-
return static_cast<shared_base const&>(*this) == static_cast<shared_base const&>(other);
56-
}
57-
58-
/**
59-
* @brief Compare two adaptors for inequality.
60-
*
61-
* @param other The other adaptor to compare against.
62-
* @return true if the adaptors do not share the same underlying impl.
63-
*/
64-
[[nodiscard]] bool operator!=(aligned_resource_adaptor const& other) const noexcept
65-
{
66-
return !(*this == other);
67-
}
68-
// End legacy device_memory_resource compatibility layer
69-
7038
/**
7139
* @brief Enables the `cuda::mr::device_accessible` property
7240
*/
@@ -102,15 +70,6 @@ class RMM_EXPORT aligned_resource_adaptor
10270
* @briefreturn{rmm::device_async_resource_ref to the upstream resource}
10371
*/
10472
[[nodiscard]] device_async_resource_ref get_upstream_resource() const noexcept;
105-
106-
// Begin legacy device_memory_resource compatibility layer
107-
private:
108-
void* do_allocate(std::size_t bytes, cuda_stream_view stream) override;
109-
110-
void do_deallocate(void* ptr, std::size_t bytes, cuda_stream_view stream) noexcept override;
111-
112-
[[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override;
113-
// End legacy device_memory_resource compatibility layer
11473
};
11574

11675
static_assert(cuda::mr::resource_with<aligned_resource_adaptor, cuda::mr::device_accessible>,

cpp/include/rmm/mr/arena_memory_resource.hpp

Lines changed: 1 addition & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,8 @@
44
*/
55
#pragma once
66

7-
#include <rmm/cuda_stream_view.hpp>
87
#include <rmm/detail/export.hpp>
98
#include <rmm/mr/detail/arena_memory_resource_impl.hpp>
10-
#include <rmm/mr/device_memory_resource.hpp>
119
#include <rmm/resource_ref.hpp>
1210

1311
#include <cuda/memory_resource>
@@ -63,40 +61,10 @@ namespace mr {
6361
* \see https://github.com/google/tcmalloc
6462
*/
6563
class RMM_EXPORT arena_memory_resource
66-
: public device_memory_resource,
67-
private cuda::mr::shared_resource<detail::arena_memory_resource_impl> {
64+
: public cuda::mr::shared_resource<detail::arena_memory_resource_impl> {
6865
using shared_base = cuda::mr::shared_resource<detail::arena_memory_resource_impl>;
6966

7067
public:
71-
// Begin legacy device_memory_resource compatibility layer
72-
using device_memory_resource::allocate;
73-
using device_memory_resource::allocate_sync;
74-
using device_memory_resource::deallocate;
75-
using device_memory_resource::deallocate_sync;
76-
77-
/**
78-
* @brief Compare two resources for equality (shared-impl identity).
79-
*
80-
* @param other The other arena_memory_resource to compare against.
81-
* @return true if both resources share the same underlying state.
82-
*/
83-
[[nodiscard]] bool operator==(arena_memory_resource const& other) const noexcept
84-
{
85-
return static_cast<shared_base const&>(*this) == static_cast<shared_base const&>(other);
86-
}
87-
88-
/**
89-
* @brief Compare two resources for inequality.
90-
*
91-
* @param other The other arena_memory_resource to compare against.
92-
* @return true if the resources do not share the same underlying state.
93-
*/
94-
[[nodiscard]] bool operator!=(arena_memory_resource const& other) const noexcept
95-
{
96-
return !(*this == other);
97-
}
98-
// End legacy device_memory_resource compatibility layer
99-
10068
/**
10169
* @brief Enables the `cuda::mr::device_accessible` property
10270
*/
@@ -118,15 +86,6 @@ class RMM_EXPORT arena_memory_resource
11886
bool dump_log_on_failure = false);
11987

12088
~arena_memory_resource() = default;
121-
122-
// Begin legacy device_memory_resource compatibility layer
123-
private:
124-
void* do_allocate(std::size_t bytes, cuda_stream_view stream) override;
125-
126-
void do_deallocate(void* ptr, std::size_t bytes, cuda_stream_view stream) noexcept override;
127-
128-
[[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override;
129-
// End legacy device_memory_resource compatibility layer
13089
};
13190

13291
static_assert(cuda::mr::resource_with<arena_memory_resource, cuda::mr::device_accessible>,

cpp/include/rmm/mr/binning_memory_resource.hpp

Lines changed: 2 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,8 @@
44
*/
55
#pragma once
66

7-
#include <rmm/cuda_stream_view.hpp>
87
#include <rmm/detail/export.hpp>
98
#include <rmm/mr/detail/binning_memory_resource_impl.hpp>
10-
#include <rmm/mr/device_memory_resource.hpp>
119
#include <rmm/resource_ref.hpp>
1210

1311
#include <cuda/memory_resource>
@@ -29,41 +27,11 @@ namespace mr {
2927
* This class is copyable and shares ownership of its internal state, allowing
3028
* multiple instances to safely reference the same underlying bins.
3129
*/
32-
class RMM_EXPORT binning_memory_resource final
33-
: public device_memory_resource,
34-
private cuda::mr::shared_resource<detail::binning_memory_resource_impl> {
30+
class RMM_EXPORT binning_memory_resource
31+
: public cuda::mr::shared_resource<detail::binning_memory_resource_impl> {
3532
using shared_base = cuda::mr::shared_resource<detail::binning_memory_resource_impl>;
3633

3734
public:
38-
// Begin legacy device_memory_resource compatibility layer
39-
using device_memory_resource::allocate;
40-
using device_memory_resource::allocate_sync;
41-
using device_memory_resource::deallocate;
42-
using device_memory_resource::deallocate_sync;
43-
44-
/**
45-
* @brief Equality comparison operator.
46-
*
47-
* @param other The other binning_memory_resource to compare against.
48-
* @return true if both resources share the same underlying state.
49-
*/
50-
[[nodiscard]] bool operator==(binning_memory_resource const& other) const noexcept
51-
{
52-
return static_cast<shared_base const&>(*this) == static_cast<shared_base const&>(other);
53-
}
54-
55-
/**
56-
* @brief Inequality comparison operator.
57-
*
58-
* @param other The other binning_memory_resource to compare against.
59-
* @return true if the resources do not share the same underlying state.
60-
*/
61-
[[nodiscard]] bool operator!=(binning_memory_resource const& other) const noexcept
62-
{
63-
return !(*this == other);
64-
}
65-
// End legacy device_memory_resource compatibility layer
66-
6735
/**
6836
* @brief Enables the `cuda::mr::device_accessible` property
6937
*
@@ -124,15 +92,6 @@ class RMM_EXPORT binning_memory_resource final
12492
*/
12593
void add_bin(std::size_t allocation_size,
12694
std::optional<device_async_resource_ref> bin_resource = std::nullopt);
127-
128-
// Begin legacy device_memory_resource compatibility layer
129-
private:
130-
void* do_allocate(std::size_t bytes, cuda_stream_view stream) override;
131-
132-
void do_deallocate(void* ptr, std::size_t bytes, cuda_stream_view stream) noexcept override;
133-
134-
[[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override;
135-
// End legacy device_memory_resource compatibility layer
13695
};
13796

13897
static_assert(cuda::mr::resource_with<binning_memory_resource, cuda::mr::device_accessible>,

0 commit comments

Comments (0)