|
20 | 20 | #include <c10/util/WaitCounter.h> |
21 | 21 | #include <c10/util/irange.h> |
22 | 22 | #include <c10/util/thread_name.h> |
| 23 | +#include <torch/csrc/cuda/CUDAPluggableAllocator.h> |
23 | 24 | #include <torch/csrc/cuda/nccl.h> |
24 | 25 | #include <torch/csrc/distributed/c10d/FlightRecorder.hpp> |
25 | 26 | #include <torch/csrc/distributed/c10d/NCCLUtils.hpp> |
@@ -5249,6 +5250,47 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base( |
5249 | 5250 | avoidRecordStreams); |
5250 | 5251 | } |
5251 | 5252 |
|
// Create a memory allocator for NCCL. This allocator is used to allocate
// memory that supports NVLink SHARP functionality. It is later exposed to
// Python via pybind, so that users can use it to create a MemPool. For
// example:
// >>> pool = torch.cuda.MemPool(backend.mem_allocator)
| 5257 | + |
| 5258 | +// Allocate function |
| 5259 | +static void* _ncclMemAlloc(size_t size, int device, void* stream) { |
| 5260 | +#ifndef NCCL_HAS_MEM_ALLOC |
| 5261 | + TORCH_CHECK( |
| 5262 | + false, "NCCL mem allocator is not supported in this NCCL version"); |
| 5263 | +#endif // NCCL_HAS_MEM_ALLOC |
| 5264 | + |
| 5265 | + LOG(INFO) << "NCCL mem allocator: allocating " << size << " bytes"; |
| 5266 | + at::cuda::OptionalCUDAGuard gpuGuard(device); |
| 5267 | + void* ptr = nullptr; |
| 5268 | + TORCH_CHECK(ncclMemAlloc(&ptr, size) == ncclSuccess, "ncclMemAlloc failed"); |
| 5269 | + return ptr; |
| 5270 | +} |
| 5271 | + |
| 5272 | +// Free function |
| 5273 | +static void _ncclMemFree(void* ptr, size_t size, int device, void* stream) { |
| 5274 | +#ifndef NCCL_HAS_MEM_ALLOC |
| 5275 | + TORCH_CHECK( |
| 5276 | + false, "NCCL mem allocator is not supported in this NCCL version"); |
| 5277 | +#endif // NCCL_HAS_MEM_ALLOC |
| 5278 | + |
| 5279 | + LOG(INFO) << "NCCL mem allocator: freeing " << size << " bytes"; |
| 5280 | + at::cuda::OptionalCUDAGuard gpuGuard(device); |
| 5281 | + TORCH_CHECK(ncclMemFree(ptr) == ncclSuccess, "ncclMemFree failed"); |
| 5282 | +} |
| 5283 | + |
| 5284 | +// Create a `CUDAPluggableAllocator` that uses the above functions. |
| 5285 | +std::shared_ptr<c10::Allocator> ProcessGroupNCCL::getMemAllocator() { |
| 5286 | + C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.getMemAllocator"); |
| 5287 | + static std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> |
| 5288 | + ncclMemAllocator = |
| 5289 | + torch::cuda::CUDAPluggableAllocator::createCustomAllocator( |
| 5290 | + _ncclMemAlloc, _ncclMemFree); |
| 5291 | + return ncclMemAllocator; |
| 5292 | +} |
| 5293 | + |
5252 | 5294 | } // namespace c10d |
5253 | 5295 |
|
5254 | 5296 | #endif // USE_C10D_NCCL |
0 commit comments