Commit ab55742

d4l3k authored and pytorchmergebot committed
[cca] [c10d] Refactor CUDAEventCache into separate files (pytorch#158616)
Summary: Refactored CUDAEventCache from ProcessGroupNCCL.hpp/.cpp into dedicated header and implementation files for better code organization and maintainability.

Split out CUDAEventCache into:

- New header file: CUDAEventCache.hpp
- New implementation file: CUDAEventCache.cpp
- Updated build_variables.bzl to include the new file

This change improves code maintainability and readability, and follows better code-organization practices.

---

> Generated by [Confucius Code Assist (CCA)](https://www.internalfb.com/wiki/Confucius/Analect/Shared_Analects/Confucius_Code_Assist_(CCA)/) [Session](https://www.internalfb.com/confucius?session_id=61b9029a-636b-11f0-9d9a-f1bcc55be1ce&tab=Chat), [Trace](https://www.internalfb.com/confucius?session_id=61b9029a-636b-11f0-9d9a-f1bcc55be1ce&tab=Trace)

Test Plan: Verified the build with:

```
buck build //caffe2/test/distributed:c10d
```

Pull Request resolved: pytorch#158616
Approved by: https://github.com/fduwjj
1 parent 90b082e commit ab55742
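
For context, the cache's behavior is unchanged; only its qualification moves from a nested class of `ProcessGroupNCCL` to a standalone class in the `c10d` namespace. A minimal sketch of the call-site migration, based on the diffs below (the `makeEndEvent` helper is illustrative, not part of the commit):

```cpp
#include <torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp>

// Illustrative helper, not part of the commit: one call site before and after.
std::shared_ptr<at::cuda::CUDAEvent> makeEndEvent(at::DeviceIndex device) {
  // Before: c10d::ProcessGroupNCCL::CUDAEventCache::get(device)->create(true);
  // After the refactor, the cache is a standalone class in namespace c10d:
  return c10d::CUDAEventCache::get(device)->create(/*timing=*/true);
}
```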

File tree

6 files changed: +99 -84 lines changed

- build_variables.bzl
- test/cpp/c10d/ProcessGroupNCCLTest.cpp
- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
- torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
- torch/csrc/distributed/c10d/cuda/CUDAEventCache.cpp
- torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp

build_variables.bzl

Lines changed: 1 addition & 0 deletions
```diff
@@ -738,6 +738,7 @@ libtorch_cuda_distributed_extra_sources = [
     "torch/csrc/distributed/c10d/UCCTracing.cpp",
     "torch/csrc/distributed/c10d/UCCUtils.cpp",
     "torch/csrc/distributed/c10d/cuda/AsyncMM.cu",
+    "torch/csrc/distributed/c10d/cuda/CUDAEventCache.cpp",
     "torch/csrc/distributed/c10d/cuda/utils.cpp",
     "torch/csrc/distributed/c10d/cuda/StreamBlock.cu",
     "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
```

test/cpp/c10d/ProcessGroupNCCLTest.cpp

Lines changed: 8 additions & 8 deletions
```diff
@@ -767,8 +767,8 @@ TEST_F(ProcessGroupNCCLTest, CUDAEventCache) {
   }
 
   // Test that the CUDAEventCache can be used to create CUDA events and reuse.
-  auto event1 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true);
-  auto event2 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false);
+  auto event1 = c10d::CUDAEventCache::get(1)->create(true);
+  auto event2 = c10d::CUDAEventCache::get(1)->create(false);
 
   auto event1_ptr = event1.get();
   auto event2_ptr = event2.get();
@@ -777,14 +777,14 @@ TEST_F(ProcessGroupNCCLTest, CUDAEventCache) {
   event2 = nullptr;
 
   // Test that the CUDAEventCache is indeed reused.
-  auto event3 = c10d::ProcessGroupNCCL::CUDAEventCache::get(2)->create(true);
-  auto event4 = c10d::ProcessGroupNCCL::CUDAEventCache::get(2)->create(false);
+  auto event3 = c10d::CUDAEventCache::get(2)->create(true);
+  auto event4 = c10d::CUDAEventCache::get(2)->create(false);
   // The cache has been used up, new events should be created.
-  auto event5 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true);
-  auto event6 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false);
+  auto event5 = c10d::CUDAEventCache::get(1)->create(true);
+  auto event6 = c10d::CUDAEventCache::get(1)->create(false);
   // The cache has been used up, new events should be created.
-  auto event7 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true);
-  auto event8 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false);
+  auto event7 = c10d::CUDAEventCache::get(1)->create(true);
+  auto event8 = c10d::CUDAEventCache::get(1)->create(false);
   EXPECT_NE(event1_ptr, event3.get());
   EXPECT_NE(event2_ptr, event4.get());
   EXPECT_EQ(event1_ptr, event5.get());
```
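
The final `EXPECT_EQ(event1_ptr, event5.get())` relies on the cache's round trip: dropping a cached event's `shared_ptr` runs a custom deleter that pushes the raw event back onto the per-device deque, and the next `create()` with the same timing flag on that device pops it back off. A standalone sketch of that round trip (illustrative, assumes a CUDA device 1 is available):

```cpp
#include <cassert>
#include <torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp>

void roundTripSketch() {
  auto cache = c10d::CUDAEventCache::get(/*device=*/1);
  auto first = cache->create(/*timing=*/true);
  at::cuda::CUDAEvent* raw = first.get();
  first = nullptr;                              // deleter returns the event to the pool
  auto second = cache->create(/*timing=*/true); // pops the same event back off
  assert(raw == second.get());                  // reused, not freshly allocated
}
```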

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 2 additions & 59 deletions
```diff
@@ -519,11 +519,9 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(
   // DEFAULT_FLAGS = cudaEventDisableTiming.
   if (cudaEventCacheEnabled) {
     ncclStartEvent_ = enableTiming
-        ? ProcessGroupNCCL::CUDAEventCache::get(device.index())
-              ->create(enableTiming)
+        ? CUDAEventCache::get(device.index())->create(enableTiming)
         : nullptr;
-    ncclEndEvent_ = ProcessGroupNCCL::CUDAEventCache::get(device.index())
-                        ->create(enableTiming);
+    ncclEndEvent_ = CUDAEventCache::get(device.index())->create(enableTiming);
   } else {
     ncclStartEvent_ = enableTiming
         ? std::make_shared<at::cuda::CUDAEvent>(cudaEventDefault)
@@ -860,61 +858,6 @@ void ProcessGroupNCCL::WorkNCCL::abort() {
   }
 }
 
-ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() = default;
-
-// CUDA event is used to record the start/end of one Work.
-// Instead of let the CUDA event gets destroyed, we now reuse it after the Work
-// has been erased from workMetaList_.
-// This is to avoid the potential deadlock caused by CudaEventDestroy.
-std::shared_ptr<at::cuda::CUDAEvent> ProcessGroupNCCL::CUDAEventCache::create(
-    bool timing) {
-  // Register the deleter as a callback when the WorkNCCL object is destroyed.
-  // Each deleter keeps a ref count to the cache object, so that even when
-  // the thread that creates the cache is gone, the cache object won't be
-  // destroyed until all the events in the cache are destroyed (ref number drops
-  // to zero).
-  auto deleter = [cache = shared_from_this(),
-                  timing](at::cuda::CUDAEvent* event) {
-    std::lock_guard<std::mutex> lock(cache->cacheMutex_);
-    // We put the event back to the cache deque once the WorkNCCL object is
-    // destroyed.
-    cache->eventsArray_[timing ? 1 : 0].push_back(event);
-  };
-  at::cuda::CUDAEvent* event = nullptr;
-  {
-    std::lock_guard<std::mutex> lock(cacheMutex_);
-    auto& events = eventsArray_[timing ? 1 : 0];
-    // If we still have events in the cache, we reuse it. Otherwise, we create a
-    // new one.
-    if (!events.empty()) {
-      event = events.front();
-      events.pop_front();
-    } else {
-      event = new at::cuda::CUDAEvent(
-          timing ? cudaEventDefault : cudaEventDisableTiming);
-    }
-  }
-  return std::shared_ptr<at::cuda::CUDAEvent>(event, std::move(deleter));
-}
-
-std::shared_ptr<ProcessGroupNCCL::CUDAEventCache> ProcessGroupNCCL::
-    CUDAEventCache::get(at::DeviceIndex device) {
-  // A per-thread singleton of device-to-CUDAEventCache map.
-  // Map is needed because events cannot be reused across devices.
-  // Per-thread ownership is needed to support multi-threaded case (instead of
-  // multi-process case).
-  static thread_local std::
-      map<at::DeviceIndex, std::shared_ptr<ProcessGroupNCCL::CUDAEventCache>>
-          cacheDeviceMap;
-  // Check if device has already been in the map, if not, add a new entry
-  auto it = cacheDeviceMap.find(device);
-  if (it == cacheDeviceMap.end()) {
-    cacheDeviceMap.emplace(
-        device, std::make_shared<ProcessGroupNCCL::CUDAEventCache>());
-  }
-  return cacheDeviceMap[device];
-}
-
 static std::atomic<size_t> process_group_id = 0;
 
 constexpr const char* MULTI_DEVICE_ERROR_MSG =
```

torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

Lines changed: 1 addition & 17 deletions
```diff
@@ -23,6 +23,7 @@
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
 #include <torch/csrc/distributed/c10d/PrefixStore.hpp>
 #include <torch/csrc/distributed/c10d/Store.hpp>
+#include <torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp>
 #include <torch/csrc/distributed/c10d/logger.hpp>
 #include <torch/csrc/distributed/c10d/symm_mem/intra_node_comm.hpp>
 
@@ -503,23 +504,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     friend class ProcessGroupNCCL;
   };
 
-  class CUDAEventCache
-      : public std::enable_shared_from_this<ProcessGroupNCCL::CUDAEventCache> {
-   public:
-    CUDAEventCache();
-    std::shared_ptr<at::cuda::CUDAEvent> create(bool timing);
-    static std::shared_ptr<ProcessGroupNCCL::CUDAEventCache> get(
-        at::DeviceIndex device);
-
-   private:
-    std::mutex cacheMutex_;
-    // NOTE: We intentionally store raw pointers so that
-    // we do not attempt to destroy the event objects on process exit,
-    // because cuda may be gone.
-    std::array<std::deque<at::cuda::CUDAEvent*>, 2>
-        eventsArray_; // 0 for timing=false, 1 for timing=true
-  };
-
   struct Options : Backend::Options {
     // NOTE: timeout in ProcessGroupNCCL::Options denote the timeout for
     // operations. This is only used when blockingWait_ is enabled.
```
torch/csrc/distributed/c10d/cuda/CUDAEventCache.cpp

Lines changed: 58 additions & 0 deletions

```diff
@@ -0,0 +1,58 @@
+#include <c10/cuda/CUDAStream.h>
+#include <torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp>
+#include <map>
+
+namespace c10d {
+
+CUDAEventCache::CUDAEventCache() = default;
+
+// CUDA event is used to record the start/end of one Work.
+// Instead of let the CUDA event gets destroyed, we now reuse it after the Work
+// has been erased from workMetaList_.
+// This is to avoid the potential deadlock caused by CudaEventDestroy.
+std::shared_ptr<at::cuda::CUDAEvent> CUDAEventCache::create(bool timing) {
+  // Register the deleter as a callback when the WorkNCCL object is destroyed.
+  // Each deleter keeps a ref count to the cache object, so that even when
+  // the thread that creates the cache is gone, the cache object won't be
+  // destroyed until all the events in the cache are destroyed (ref number drops
+  // to zero).
+  auto deleter = [cache = shared_from_this(),
+                  timing](at::cuda::CUDAEvent* event) {
+    std::lock_guard<std::mutex> lock(cache->cacheMutex_);
+    // We put the event back to the cache deque once the WorkNCCL object is
+    // destroyed.
+    cache->eventsArray_[timing ? 1 : 0].push_back(event);
+  };
+  at::cuda::CUDAEvent* event = nullptr;
+  {
+    std::lock_guard<std::mutex> lock(cacheMutex_);
+    auto& events = eventsArray_[timing ? 1 : 0];
+    // If we still have events in the cache, we reuse it. Otherwise, we create a
+    // new one.
+    if (!events.empty()) {
+      event = events.front();
+      events.pop_front();
+    } else {
+      event = new at::cuda::CUDAEvent(
+          timing ? cudaEventDefault : cudaEventDisableTiming);
+    }
+  }
+  return std::shared_ptr<at::cuda::CUDAEvent>(event, std::move(deleter));
+}
+
+std::shared_ptr<CUDAEventCache> CUDAEventCache::get(at::DeviceIndex device) {
+  // A per-thread singleton of device-to-CUDAEventCache map.
+  // Map is needed because events cannot be reused across devices.
+  // Per-thread ownership is needed to support multi-threaded case (instead of
+  // multi-process case).
+  static thread_local std::map<at::DeviceIndex, std::shared_ptr<CUDAEventCache>>
+      cacheDeviceMap;
+  // Check if device has already been in the map, if not, add a new entry
+  auto it = cacheDeviceMap.find(device);
+  if (it == cacheDeviceMap.end()) {
+    cacheDeviceMap.emplace(device, std::make_shared<CUDAEventCache>());
+  }
+  return cacheDeviceMap[device];
+}
+
+} // namespace c10d
```
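
The implementation above is an instance of a general C++ pattern: hand out `std::shared_ptr`s whose custom deleter recycles the raw object into a pool instead of deleting it, with the deleter capturing a `shared_ptr` to the pool (via `enable_shared_from_this`) so the pool outlives every outstanding object. A stripped-down, CUDA-free sketch of the same pattern (hypothetical `Pool`/`Widget` names, for illustration only):

```cpp
#include <deque>
#include <memory>
#include <mutex>

struct Widget {};

class Pool : public std::enable_shared_from_this<Pool> {
 public:
  std::shared_ptr<Widget> create() {
    // The deleter captures a shared_ptr to the pool, so the pool stays
    // alive until the last outstanding Widget has been returned.
    auto deleter = [pool = shared_from_this()](Widget* w) {
      std::lock_guard<std::mutex> lock(pool->mutex_);
      pool->free_.push_back(w); // recycle instead of delete
    };
    Widget* w = nullptr;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (!free_.empty()) {
        w = free_.front();
        free_.pop_front();
      } else {
        w = new Widget();
      }
    }
    return std::shared_ptr<Widget>(w, std::move(deleter));
  }

 private:
  std::mutex mutex_;
  // Raw pointers on purpose: recycled objects are never deleted here,
  // mirroring the cache's note about process-exit teardown.
  std::deque<Widget*> free_;
};
```

Usage mirrors the cache: `auto pool = std::make_shared<Pool>(); auto w = pool->create();` — even if `pool` itself is reset, the copy captured in each deleter keeps the pool alive until the last `Widget` is returned.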
torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp

Lines changed: 29 additions & 0 deletions

```diff
@@ -0,0 +1,29 @@
+#pragma once
+
+#include <array>
+#include <deque>
+#include <memory>
+#include <mutex>
+
+#include <ATen/cuda/CUDAEvent.h>
+#include <c10/macros/Export.h>
+
+namespace c10d {
+
+class TORCH_API CUDAEventCache
+    : public std::enable_shared_from_this<CUDAEventCache> {
+ public:
+  CUDAEventCache();
+  std::shared_ptr<at::cuda::CUDAEvent> create(bool timing);
+  static std::shared_ptr<CUDAEventCache> get(at::DeviceIndex device);
+
+ private:
+  std::mutex cacheMutex_;
+  // NOTE: We intentionally store raw pointers so that
+  // we do not attempt to destroy the event objects on process exit,
+  // because cuda may be gone.
+  std::array<std::deque<at::cuda::CUDAEvent*>, 2>
+      eventsArray_; // 0 for timing=false, 1 for timing=true
+};
+
+} // namespace c10d
```
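
One subtlety carried over unchanged: `get()` uses a `thread_local` map, so each thread owns its own per-device caches, and events are recycled only within the thread that created them. A small sketch of the consequence (illustrative only):

```cpp
#include <cassert>
#include <memory>
#include <thread>
#include <torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp>

void threadLocalSketch() {
  std::shared_ptr<c10d::CUDAEventCache> a, b;
  std::thread t1([&] { a = c10d::CUDAEventCache::get(0); });
  std::thread t2([&] { b = c10d::CUDAEventCache::get(0); });
  t1.join();
  t2.join();
  // Same device index, different threads: two distinct cache instances,
  // so events never migrate across threads.
  assert(a != b);
}
```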
