Skip to content
Merged
55 changes: 47 additions & 8 deletions cpp/include/raft/core/device_resources_snmg.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -9,6 +9,12 @@
#include <raft/core/resource/multi_gpu.hpp>
#include <raft/core/resource/resource_types.hpp>

#include <rmm/cuda_device.hpp>
#include <rmm/mr/device_memory_resource.hpp>
#include <rmm/mr/per_device_resource.hpp>
#include <rmm/mr/pool_memory_resource.hpp>

#include <memory>
#include <unordered_set>
#include <vector>

Expand Down Expand Up @@ -72,19 +78,49 @@ class device_resources_snmg : public device_resources {

// Move construction and move assignment are explicitly disabled.
device_resources_snmg(device_resources_snmg&&) = delete;
device_resources_snmg& operator=(device_resources_snmg&&) = delete;
/**
 * @brief Reinstall the per-device memory resources recorded by set_memory_pool().
 *
 * Every (device id, resource) pair stored in device_original_mrs_ is handed back to
 * rmm::mr::set_per_device_resource(). When set_memory_pool() was never called the list
 * is empty and nothing happens.
 *
 * The pools owned by per_device_pools_ are destroyed after this body has run (data
 * members are destroyed once the destructor body completes), i.e. after they have been
 * uninstalled as the per-device resources.
 */
~device_resources_snmg()
{
  // Iterating an empty vector is a no-op, so no separate emptiness check is needed.
  for (const auto& [device_id, original_mr] : device_original_mrs_) {
    rmm::mr::set_per_device_resource(rmm::cuda_device_id{device_id}, original_mr);
  }
}

/**
* @brief Set a memory pool on all GPUs of the multi-gpu world
*/
void set_memory_pool(int percent_of_free_memory) const
void set_memory_pool(int percent_of_free_memory)
{
// Protect against repeated calls - restore original resources and clear pools
if (!per_device_pools_.empty()) {
for (const auto& [device_id, original_mr] : device_original_mrs_) {
rmm::cuda_device_id id(device_id);
rmm::mr::set_per_device_resource(id, original_mr);
}
per_device_pools_.clear();
device_original_mrs_.clear();
}

int world_size = raft::resource::get_num_ranks(*this);
for (int gpu_id = 0; gpu_id < world_size; gpu_id++) {
const raft::resources& dev_res = raft::resource::set_current_device_to_rank(*this, gpu_id);
// check limit for each device
size_t limit = rmm::percent_of_free_device_memory(percent_of_free_memory);
raft::resource::set_workspace_to_pool_resource(dev_res, limit);
for (int rank = 0; rank < world_size; rank++) {
const raft::resources& dev_res = raft::resource::set_current_device_to_rank(*this, rank);

// Get the actual device ID for this rank
int device_id = raft::resource::get_device_id(dev_res);

// Store the original memory resource before replacing it
auto old_mr = rmm::mr::get_current_device_resource();
device_original_mrs_.push_back({device_id, old_mr});

// create a pool memory resource for each device
per_device_pools_.push_back(
std::make_unique<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>(
old_mr, rmm::percent_of_free_device_memory(percent_of_free_memory)));
rmm::cuda_device_id id(device_id);
rmm::mr::set_per_device_resource(id, per_device_pools_.back().get());
}
RAFT_CUDA_TRY(cudaSetDevice(main_gpu_id_));
}
Expand Down Expand Up @@ -125,6 +161,9 @@ class device_resources_snmg : public device_resources {
}
}
// Device made current again (via cudaSetDevice) once set_memory_pool() has visited all ranks.
int main_gpu_id_;
// Pool resources created by set_memory_pool(), one per rank; this vector owns them.
std::vector<std::unique_ptr<rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>>>
per_device_pools_;
// (device id, resource that was current before the pool was installed) pairs recorded by
// set_memory_pool(); the destructor passes them back to rmm::mr::set_per_device_resource().
std::vector<std::pair<int, rmm::mr::device_memory_resource*>> device_original_mrs_;
}; // class device_resources_snmg

} // namespace raft