Commit 18a7a04

kwen2501 authored and pytorchmergebot committed
[c10d] Add NCCL memory allocator (pytorch#145675)
This PR implements a small UI improvement over pytorch#133603. It prepares an NCCL memory allocator on the C++ side of torch and exposes it via pybind so that users can use it directly. UI:

```
pool = torch.cuda.MemPool(backend.mem_allocator)
with torch.cuda.use_mem_pool(pool):
    tensor = torch.arange(1024 * 1024 * 2, device=device)
```

Pull Request resolved: pytorch#145675
Approved by: https://github.com/syed-ahmed, https://github.com/wconstab
1 parent b60120d commit 18a7a04
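For context, a fuller end-to-end sketch of the new UI, assembled from the commit message and the updated test below. The process-group setup (env rendezvous, single rank, device choice) is illustrative and not part of this commit:

```python
import os

import torch
import torch.distributed as dist

# Illustrative single-process setup; rendezvous details are placeholders.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)

device = torch.device("cuda:0")
torch.cuda.set_device(device)

# Fetch the NCCL backend of the default process group, as the test does.
pg = dist.distributed_c10d._get_default_group()
backend = pg._get_backend(device)

# New in this commit: a MemPool backed by the NCCL allocator, so tensors
# allocated inside the pool context come from ncclMemAlloc.
pool = torch.cuda.MemPool(backend.mem_allocator)
with torch.cuda.use_mem_pool(pool):
    tensor = torch.arange(1024 * 1024 * 2, device=device)

dist.destroy_process_group()
```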

File tree: 7 files changed (+64 −43 lines)

test/distributed/test_c10d_nccl.py

Lines changed: 3 additions & 42 deletions
@@ -67,7 +67,6 @@
     TEST_WITH_ROCM,
     TestCase,
 )
-from torch.utils.cpp_extension import load_inline


 if TEST_WITH_DEV_DBG_ASAN:
@@ -3104,40 +3103,6 @@ def test_nccl_timeout(self):


 class NcclUserBufferRegistrationTest(MultiProcessTestCase):
-    def createNcclAllocator(self):
-        nccl_allocator_source = """
-        #include <torch/extension.h>
-        #include <nccl.h>
-        #include <iostream>
-
-        extern "C" {
-
-        // Note that windows needs __declspec(dllexport): https://stackoverflow.com/a/24575865
-        C10_EXPORT void* nccl_alloc(size_t size, int device, void* stream) {
-          std::cout << "Using ncclMemAlloc" << std::endl;
-          void* ptr;
-          ncclResult_t err = ncclMemAlloc(&ptr, size);
-          return ptr;
-        }
-
-        C10_EXPORT void nccl_free(void* ptr, size_t size, int device, void* stream) {
-          std::cout << "Using ncclMemFree" << std::endl;
-          ncclResult_t err = ncclMemFree(ptr);
-        }
-        }
-        """
-        nccl_allocator_libname = "nccl_allocator"
-        nccl_allocator = load_inline(
-            name=nccl_allocator_libname,
-            cpp_sources=nccl_allocator_source,
-            with_cuda=True,
-            extra_ldflags=["-lnccl"],
-            is_python_module=False,
-            keep_intermediates=False,
-            verbose=True,
-        )
-        return nccl_allocator
-
     def setUp(self):
         super().setUp()
         # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
@@ -3172,13 +3137,9 @@ def test_nccl_user_buffer_registration(self):
         torch.cuda.set_device(self.rank)
         pg = c10d.distributed_c10d._get_default_group()
         backend = pg._get_backend(torch.device(device))
-        allocator_path = self.createNcclAllocator()
-        allocator = torch.cuda.memory.CUDAPluggableAllocator(
-            allocator_path,
-            "nccl_alloc",
-            "nccl_free",
-        )
-        pool = torch.cuda.MemPool(allocator.allocator())
+
+        # Use NCCL memory allocator
+        pool = torch.cuda.MemPool(backend.mem_allocator)

         # allocate memory with ncclMemAlloc
         with torch.cuda.use_mem_pool(pool):

torch/_C/_distributed_c10d.pyi

Lines changed: 2 additions & 0 deletions
@@ -296,6 +296,8 @@ class Backend:
     def _set_sequence_number_for_group(self) -> None: ...
     def _set_default_timeout(self, timeout: timedelta) -> None: ...
     def get_error(self) -> ErrorType: ...
+    @property
+    def mem_allocator(self) -> Any: ...

 class ProcessGroup:
     class BackendType(Enum):

torch/csrc/distributed/c10d/Backend.hpp

Lines changed: 8 additions & 0 deletions
@@ -5,6 +5,7 @@
 #include <vector>

 #include <ATen/ATen.h>
+#include <c10/core/Allocator.h>
 #include <c10/macros/Macros.h>

 #include <torch/csrc/distributed/c10d/Types.hpp>
@@ -409,6 +410,13 @@ class TORCH_API Backend : public torch::CustomClassHolder {
         c10::str("Backend ", getBackendName(), " does not support getError"));
   }

+  virtual std::shared_ptr<c10::Allocator> getMemAllocator() {
+    TORCH_CHECK(
+        false,
+        c10::str(
+            "Backend ", getBackendName(), " does not support getMemAllocator"));
+  }
+
  protected:
   // Implementations of this interface need to call this to setup
   // appropriate logging etc.
torch/csrc/distributed/c10d/NCCLUtils.hpp

Lines changed: 4 additions & 0 deletions
@@ -83,6 +83,10 @@ static_assert(
 #define NCCL_HAS_COMM_REGISTER
 #endif

+#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 19, 0)
+#define NCCL_HAS_MEM_ALLOC
+#endif
+
 // Macro to throw on a non-successful NCCL return value.
 #define C10D_NCCL_CHECK(cmd, failureReason) \
   do { \
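`ncclMemAlloc`/`ncclMemFree` first appeared in NCCL 2.19, so the feature is compiled in only when building against that version or newer. A runtime mirror of this compile-time gate, assuming `torch.cuda.nccl.version()` reports the NCCL your build links against:

```python
import torch

# Mirrors the compile-time NCCL_HAS_MEM_ALLOC gate at runtime.
if torch.cuda.nccl.version() >= (2, 19, 0):
    print("ncclMemAlloc/ncclMemFree should be available")
else:
    print("NCCL too old; the NCCL mem allocator will raise when used")
```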

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 42 additions & 0 deletions
@@ -20,6 +20,7 @@
 #include <c10/util/WaitCounter.h>
 #include <c10/util/irange.h>
 #include <c10/util/thread_name.h>
+#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
 #include <torch/csrc/cuda/nccl.h>
 #include <torch/csrc/distributed/c10d/FlightRecorder.hpp>
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
@@ -5249,6 +5250,47 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
       avoidRecordStreams);
 }

+// Create a memory allocator for NCCL. This allocator is used to allocate
+// memory that supports NVLink SHARP functionality. It is later exposed to
+// Python via pybind so that users can use it to create a MemPool. For example:
+// >>> pool = torch.cuda.MemPool(backend.mem_allocator)
+
+// Allocate function
+static void* _ncclMemAlloc(size_t size, int device, void* stream) {
+#ifndef NCCL_HAS_MEM_ALLOC
+  TORCH_CHECK(
+      false, "NCCL mem allocator is not supported in this NCCL version");
+#endif // NCCL_HAS_MEM_ALLOC
+
+  LOG(INFO) << "NCCL mem allocator: allocating " << size << " bytes";
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
+  void* ptr = nullptr;
+  TORCH_CHECK(ncclMemAlloc(&ptr, size) == ncclSuccess, "ncclMemAlloc failed");
+  return ptr;
+}
+
+// Free function
+static void _ncclMemFree(void* ptr, size_t size, int device, void* stream) {
+#ifndef NCCL_HAS_MEM_ALLOC
+  TORCH_CHECK(
+      false, "NCCL mem allocator is not supported in this NCCL version");
+#endif // NCCL_HAS_MEM_ALLOC
+
+  LOG(INFO) << "NCCL mem allocator: freeing " << size << " bytes";
+  at::cuda::OptionalCUDAGuard gpuGuard(device);
+  TORCH_CHECK(ncclMemFree(ptr) == ncclSuccess, "ncclMemFree failed");
+}
+
+// Create a `CUDAPluggableAllocator` that uses the above functions.
+std::shared_ptr<c10::Allocator> ProcessGroupNCCL::getMemAllocator() {
+  C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.getMemAllocator");
+  static std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
+      ncclMemAllocator =
+          torch::cuda::CUDAPluggableAllocator::createCustomAllocator(
+              _ncclMemAlloc, _ncclMemFree);
+  return ncclMemAllocator;
+}
+
 } // namespace c10d

 #endif // USE_C10D_NCCL
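Note the function-local `static`: the pluggable allocator is constructed once and shared by all subsequent calls. On builds where `NCCL_HAS_MEM_ALLOC` is undefined, the failure is deferred until the allocator is actually exercised; a sketch of what that looks like from Python, assuming an NCCL process group is already initialized as in the first sketch:

```python
import torch
import torch.distributed as dist

backend = dist.distributed_c10d._get_default_group()._get_backend(
    torch.device("cuda")
)

pool = torch.cuda.MemPool(backend.mem_allocator)
try:
    with torch.cuda.use_mem_pool(pool):
        t = torch.empty(1024, device="cuda")  # triggers _ncclMemAlloc
except RuntimeError as e:
    # Raised by _ncclMemAlloc when NCCL_HAS_MEM_ALLOC is undefined:
    # "NCCL mem allocator is not supported in this NCCL version"
    print(e)
```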

torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

Lines changed: 2 additions & 0 deletions
@@ -768,6 +768,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {

   ErrorType getError() override;

+  std::shared_ptr<c10::Allocator> getMemAllocator() override;
+
   // Performs NCCL user buffer registration for all buffers in
   // the given MemPool
   void registerMemPool(c10::cuda::MemPool* pool);

torch/csrc/distributed/c10d/init.cpp

Lines changed: 3 additions & 1 deletion
@@ -2765,7 +2765,9 @@ The hook must have the following signature:
       .def(
           "_end_coalescing",
           &::c10d::Backend::endCoalescing,
-          py::call_guard<py::gil_scoped_release>());
+          py::call_guard<py::gil_scoped_release>())
+      .def_property_readonly(
+          "mem_allocator", &::c10d::Backend::getMemAllocator);

   // base Backend::Options binding
   // TODO: Maybe we can consider how to merge this with
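Since the binding is a `def_property_readonly` on the base `Backend` class, every backend exposes a read-only `mem_allocator` attribute (non-NCCL backends raise on access, as noted above), and assignment is rejected by pybind11. A small sketch, again assuming an initialized NCCL group:

```python
import torch
import torch.distributed as dist

backend = dist.distributed_c10d._get_default_group()._get_backend(
    torch.device("cuda")
)

print(backend.mem_allocator)  # the shared NCCL pluggable allocator

try:
    backend.mem_allocator = None
except AttributeError as e:
    # def_property_readonly installs no setter, so assignment fails.
    print(e)
```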
