Skip to content

Commit 662443e

Browse files
authored
Use the new reference-based RMM API to set RMM pools in device_resources_snmg (#2972)
This PR updates the `device_resources_snmg` utility so that it uses the new reference-based RMM API to set RMM pools.

Authors:
- Victor Lafargue (https://github.com/viclafargue)

Approvers:
- Jinsol Park (https://github.com/jinsolp)
- Artem M. Chirkin (https://github.com/achirkin)

URL: #2972
1 parent 7b03b3f commit 662443e

File tree

1 file changed

+11
-21
lines changed

1 file changed

+11
-21
lines changed

cpp/include/raft/core/device_resources_snmg.hpp

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,8 @@ class device_resources_snmg : public device_resources {
8080
device_resources_snmg& operator=(device_resources_snmg&&) = delete;
8181
~device_resources_snmg()
8282
{
83-
// Restore original device memory resources
84-
if (!device_original_mrs_.empty()) {
85-
for (const auto& [device_id, original_mr] : device_original_mrs_) {
86-
rmm::cuda_device_id id(device_id);
87-
rmm::mr::set_per_device_resource(id, original_mr);
88-
}
83+
for (int device_id : pool_device_ids_) {
84+
rmm::mr::reset_per_device_resource_ref(rmm::cuda_device_id{device_id});
8985
}
9086
}
9187

@@ -94,33 +90,27 @@ class device_resources_snmg : public device_resources {
9490
*/
9591
void set_memory_pool(int percent_of_free_memory)
9692
{
97-
// Protect against repeated calls - restore original resources and clear pools
9893
if (!per_device_pools_.empty()) {
99-
for (const auto& [device_id, original_mr] : device_original_mrs_) {
100-
rmm::cuda_device_id id(device_id);
101-
rmm::mr::set_per_device_resource(id, original_mr);
94+
for (int device_id : pool_device_ids_) {
95+
rmm::mr::reset_per_device_resource_ref(rmm::cuda_device_id{device_id});
10296
}
10397
per_device_pools_.clear();
104-
device_original_mrs_.clear();
98+
pool_device_ids_.clear();
10599
}
106100

107101
int world_size = raft::resource::get_num_ranks(*this);
108102
for (int rank = 0; rank < world_size; rank++) {
109103
const raft::resources& dev_res = raft::resource::set_current_device_to_rank(*this, rank);
110104

111-
// Get the actual device ID for this rank
112105
int device_id = raft::resource::get_device_id(dev_res);
106+
pool_device_ids_.push_back(device_id);
113107

114-
// Store the original memory resource before replacing it
115-
auto old_mr = rmm::mr::get_current_device_resource();
116-
device_original_mrs_.push_back({device_id, old_mr});
117-
118-
// create a pool memory resource for each device
119108
per_device_pools_.push_back(
120109
std::make_unique<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>(
121-
old_mr, rmm::percent_of_free_device_memory(percent_of_free_memory)));
122-
rmm::cuda_device_id id(device_id);
123-
rmm::mr::set_per_device_resource(id, per_device_pools_.back().get());
110+
rmm::mr::get_current_device_resource_ref(),
111+
rmm::percent_of_free_device_memory(percent_of_free_memory)));
112+
rmm::mr::set_per_device_resource_ref(rmm::cuda_device_id{device_id},
113+
*per_device_pools_.back());
124114
}
125115
RAFT_CUDA_TRY(cudaSetDevice(main_gpu_id_));
126116
}
@@ -163,7 +153,7 @@ class device_resources_snmg : public device_resources {
163153
int main_gpu_id_;
164154
std::vector<std::unique_ptr<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>>
165155
per_device_pools_;
166-
std::vector<std::pair<int, rmm::mr::device_memory_resource*>> device_original_mrs_;
156+
std::vector<int> pool_device_ids_;
167157
}; // class device_resources_snmg
168158

169159
} // namespace raft

0 commit comments

Comments
 (0)