@@ -80,12 +80,8 @@ class device_resources_snmg : public device_resources {
8080 device_resources_snmg& operator =(device_resources_snmg&&) = delete ;
8181 ~device_resources_snmg ()
8282 {
83- // Restore original device memory resources
84- if (!device_original_mrs_.empty ()) {
85- for (const auto & [device_id, original_mr] : device_original_mrs_) {
86- rmm::cuda_device_id id (device_id);
87- rmm::mr::set_per_device_resource (id, original_mr);
88- }
83+ for (int device_id : pool_device_ids_) {
84+ rmm::mr::reset_per_device_resource_ref (rmm::cuda_device_id{device_id});
8985 }
9086 }
9187
@@ -94,33 +90,27 @@ class device_resources_snmg : public device_resources {
9490 */
9591 void set_memory_pool (int percent_of_free_memory)
9692 {
97- // Protect against repeated calls - restore original resources and clear pools
9893 if (!per_device_pools_.empty ()) {
99- for (const auto & [device_id, original_mr] : device_original_mrs_) {
100- rmm::cuda_device_id id (device_id);
101- rmm::mr::set_per_device_resource (id, original_mr);
94+ for (int device_id : pool_device_ids_) {
95+ rmm::mr::reset_per_device_resource_ref (rmm::cuda_device_id{device_id});
10296 }
10397 per_device_pools_.clear ();
104- device_original_mrs_ .clear ();
98+ pool_device_ids_ .clear ();
10599 }
106100
107101 int world_size = raft::resource::get_num_ranks (*this );
108102 for (int rank = 0 ; rank < world_size; rank++) {
109103 const raft::resources& dev_res = raft::resource::set_current_device_to_rank (*this , rank);
110104
111- // Get the actual device ID for this rank
112105 int device_id = raft::resource::get_device_id (dev_res);
106+ pool_device_ids_.push_back (device_id);
113107
114- // Store the original memory resource before replacing it
115- auto old_mr = rmm::mr::get_current_device_resource ();
116- device_original_mrs_.push_back ({device_id, old_mr});
117-
118- // create a pool memory resource for each device
119108 per_device_pools_.push_back (
120109 std::make_unique<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>(
121- old_mr, rmm::percent_of_free_device_memory (percent_of_free_memory)));
122- rmm::cuda_device_id id (device_id);
123- rmm::mr::set_per_device_resource (id, per_device_pools_.back ().get ());
110+ rmm::mr::get_current_device_resource_ref (),
111+ rmm::percent_of_free_device_memory (percent_of_free_memory)));
112+ rmm::mr::set_per_device_resource_ref (rmm::cuda_device_id{device_id},
113+ *per_device_pools_.back ());
124114 }
125115 RAFT_CUDA_TRY (cudaSetDevice (main_gpu_id_));
126116 }
@@ -163,7 +153,7 @@ class device_resources_snmg : public device_resources {
163153 int main_gpu_id_;
164154 std::vector<std::unique_ptr<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>>
165155 per_device_pools_;
166- std::vector<std::pair< int , rmm::mr::device_memory_resource*>> device_original_mrs_ ;
156+ std::vector<int > pool_device_ids_ ;
167157}; // class device_resources_snmg
168158
169159} // namespace raft
0 commit comments