
Commit 25ca05e

kwen2501 authored and pytorchmergebot committed
[PGNCCL] Correct some ifdef's (pytorch#145893)
The `create` function supporting `ncclConfig_t` should be wrapped inside `NCCL_HAS_CONFIG` instead of `NCCL_HAS_COMM_NONBLOCKING`.

Pull Request resolved: pytorch#145893
Approved by: https://github.com/c-p-i-o
1 parent 73dde45 commit 25ca05e
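
The rule the commit restores, as a minimal self-contained sketch (editorial, not code from the patch; `Comm`, `Config`, and `HAS_CONFIG` are stand-ins for `NCCLComm`, `ncclConfig_t`, and `NCCL_HAS_CONFIG`): an overload that names a feature's type in its signature must be gated by the macro that defines that type, not by the macro of an unrelated capability such as `NCCL_HAS_COMM_NONBLOCKING`.

// Editorial sketch; stand-in types, compiles as-is.
#include <memory>

#define HAS_CONFIG 1 // pretend NCCL >= 2.17

struct Comm {};
#if HAS_CONFIG
struct Config {
  int blocking = 1; // mirrors the real ncclConfig_t field of the same name
};
#endif

// Baseline creator: available regardless of feature macros.
std::shared_ptr<Comm> create(int numRanks, int rank) {
  (void)numRanks;
  (void)rank;
  return std::make_shared<Comm>();
}

#if HAS_CONFIG
// The config-taking overload names Config in its signature, so it lives
// behind the same macro that defines Config -- the rule this commit restores.
std::shared_ptr<Comm> create(int numRanks, int rank, Config& config) {
  (void)config;
  return create(numRanks, rank);
}
#endif

int main() {
#if HAS_CONFIG
  Config cfg;
  cfg.blocking = 0; // e.g. request nonblocking communicator init
  auto comm = create(8, 0, cfg);
#else
  auto comm = create(8, 0); // fallback, as initNCCLComm does below
#endif
  (void)comm;
  return 0;
}

The `#else` branch in `main()` mirrors the call-site fallback that `initNCCLComm` keeps in ProcessGroupNCCL.cpp further down.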

File tree

5 files changed: +17 −12 lines changed


torch/csrc/distributed/c10d/NCCLUtils.cpp

Lines changed: 2 additions & 2 deletions
@@ -62,7 +62,7 @@ std::shared_ptr<NCCLComm> NCCLComm::create(
   return comm;
 }
 
-#ifdef NCCL_HAS_COMM_NONBLOCKING
+#ifdef NCCL_HAS_CONFIG
 std::shared_ptr<NCCLComm> NCCLComm::create(
     int numRanks,
     int rank,
@@ -87,7 +87,7 @@ std::shared_ptr<NCCLComm> NCCLComm::create(
   comm->initialized_ = !comm->nonBlocking_;
   return comm;
 }
-#endif
+#endif // NCCL_HAS_CONFIG
 
 ncclComm_t NCCLComm::getNcclComm() {
   LockType lock(mutex_);

torch/csrc/distributed/c10d/NCCLUtils.hpp

Lines changed: 9 additions & 4 deletions
@@ -68,11 +68,14 @@ static_assert(
 #define ENABLE_NCCL_PREMUL_SUM_SUPPORT
 #endif
 
+// Note: the first version that supports ncclConfig_t is 2.14. Here we
+// fast-forward the version requirement to 2.17 where ncclConfig_t has CTA and
+// CGA fields because they have already been pybinded out.
 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
     (NCCL_MINOR >= 17)
-#define NCCL_HAS_COMM_CTA_CGA
+#define NCCL_HAS_CONFIG
 #elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
-#define NCCL_HAS_COMM_CTA_CGA
+#define NCCL_HAS_CONFIG
 #endif
 
 #if defined(NCCL_REGISTRATION_SUPPORTED) || \
@@ -230,21 +233,23 @@ class NCCLComm {
       ncclUniqueId commId,
       at::DeviceIndex deviceIndex);
 
-#ifdef NCCL_HAS_COMM_NONBLOCKING
+#ifdef NCCL_HAS_CONFIG
   static std::shared_ptr<NCCLComm> create(
       int numRanks,
       int rank,
       ncclUniqueId commId,
       at::DeviceIndex deviceIndex,
       ncclConfig_t& config);
+#endif // NCCL_HAS_CONFIG
 
+#ifdef NCCL_HAS_COMM_SPLIT
   static std::shared_ptr<NCCLComm> split(
       NCCLComm* source,
       int color_id,
       int rank,
       ncclConfig_t& config,
       std::vector<uint64_t>& ranks_ull);
-#endif // NCCL_HAS_COMM_NONBLOCKING
+#endif // NCCL_HAS_COMM_SPLIT
 
 #if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
   std::unordered_map<std::string, std::string> ncclCommDump();
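
To make the version arithmetic in the new comment concrete, a small stand-alone sketch (editorial; the `FAKE_` macros are stand-ins for the real `NCCL_MAJOR`/`NCCL_MINOR` probing in nccl.h):

// Editorial sketch of the gate above, with stand-in macros instead of nccl.h.
// ncclConfig_t itself dates to NCCL 2.14, but NCCL_HAS_CONFIG is keyed to
// 2.17+ because the already-pybinded CTA/CGA fields first appear there.
#define FAKE_NCCL_MAJOR 2
#define FAKE_NCCL_MINOR 17

#if (FAKE_NCCL_MAJOR == 2 && FAKE_NCCL_MINOR >= 17) || (FAKE_NCCL_MAJOR >= 3)
#define FAKE_NCCL_HAS_CONFIG
#endif

#ifdef FAKE_NCCL_HAS_CONFIG
// Here the CTA/CGA knobs (minCTAs/maxCTAs/cgaClusterSize in the real
// ncclConfig_t) can be relied on; below 2.17 this block is compiled out.
#endif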

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 2 additions & 2 deletions
@@ -2704,12 +2704,12 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::initNCCLComm(
                << timerDeltaMs << " ms";
   }
 
-#ifdef NCCL_HAS_COMM_NONBLOCKING
+#ifdef NCCL_HAS_CONFIG
   ncclComm =
       NCCLComm::create(numRanks, rank, ncclID, deviceIndex, options_->config);
 #else
   ncclComm = NCCLComm::create(numRanks, rank, ncclID, deviceIndex);
-#endif // NCCL_HAS_COMM_NONBLOCKING
+#endif // NCCL_HAS_CONFIG
 }
 
 // Creates the NCCL streams

torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

Lines changed: 1 addition & 1 deletion
@@ -492,7 +492,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     // Schedule NCCL operations on high priority CUDA streams
     bool is_high_priority_stream;
 
-#ifdef NCCL_HAS_COMM_NONBLOCKING
+#ifdef NCCL_HAS_CONFIG
     // Configure ranks
     ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
 #endif
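
With `Options::config` now gated by `NCCL_HAS_CONFIG`, caller-side code that touches the field compiles exactly when `ncclConfig_t` is available. A hedged usage sketch (editorial, not from the patch; assumes a source build where this header and the macro are visible):

#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>

void makeOptions() {
  // Options is intrusive-refcounted, hence c10::make_intrusive.
  auto opts = c10::make_intrusive<c10d::ProcessGroupNCCL::Options>();
#ifdef NCCL_HAS_CONFIG
  // `blocking` is an ncclConfig_t field (NCCL >= 2.14); 0 requests
  // nonblocking communicator initialization.
  opts->config.blocking = 0;
#endif
  (void)opts;
}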

torch/csrc/distributed/c10d/init.cpp

Lines changed: 3 additions & 3 deletions
@@ -3038,7 +3038,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
       "_get_intra_node_comm_usage_counter",
       &::c10d::intra_node_comm::getIntraNodeCommUsageCounter);
 
-#ifdef NCCL_HAS_COMM_CTA_CGA
+#ifdef NCCL_HAS_CONFIG
   py::class_<ncclConfig_t>(
       processGroupNCCL,
       "NCCLConfig",
@@ -3064,7 +3064,7 @@ for details.
           [](ncclConfig_t& self, const char* tmp) {
             self.netName = strdup(tmp);
           });
-#endif
+#endif // NCCL_HAS_CONFIG
 
   intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>(
       processGroupNCCL,
@@ -3100,7 +3100,7 @@ Example::
       >>> dist.init_process_group("nccl", pg_options=nccl_options)
     )")
       .def(py::init<bool>(), py::arg("is_high_priority_stream") = false)
-#ifdef NCCL_HAS_COMM_CTA_CGA
+#ifdef NCCL_HAS_CONFIG
       .def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config)
 #endif
       .def_readwrite(
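
The same discipline applies at the binding layer: a class binding for a config type belongs behind the macro that defines the type. A minimal, self-contained pybind11 sketch (editorial; `Config` and `HAS_CONFIG` are stand-ins, not PyTorch's actual `NCCLConfig` binding):

#include <pybind11/pybind11.h>

namespace py = pybind11;

#define HAS_CONFIG 1 // stand-in for NCCL_HAS_CONFIG

#if HAS_CONFIG
struct Config {
  int blocking = 1;
};
#endif

PYBIND11_MODULE(demo, m) {
#if HAS_CONFIG
  // Bind Config only when the type exists, mirroring how init.cpp gates
  // NCCLConfig and Options.config behind NCCL_HAS_CONFIG.
  py::class_<Config>(m, "Config")
      .def(py::init<>())
      .def_readwrite("blocking", &Config::blocking);
#endif
}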
