Skip to content

Commit 9698e1c

Browse files
mlunar-metameta-codesync[bot]
authored andcommitted
Make FastInitTest test Ctran Init
Summary: These diff stack, implements the required updates to rebase NCCL 2.28 for NCCLX with CTRAN integration. It incorporates all changes introduced in version 2.27, applying them on top of NCCL’s latest stable release (2.28). The primary objective of this diff is to enable CTRAN support for NCCLX under NCCL 2.28, ensuring compatibility and leveraging the latest enhancements from the upstream release. In addition, it includes necessary checks inside FastInitTest so that it can ensure the ctran enabling as expected. In this specific diff, I include fast init test with the required codes so that it can test Ctran init. Reviewed By: dboyda Differential Revision: D85065685 fbshipit-source-id: aee990de51be0ff1f9eb28640da1f19f323b4498
1 parent 92e1b12 commit 9698e1c

File tree

1 file changed

+42
-12
lines changed

1 file changed

+42
-12
lines changed

comms/ncclx/v2_27/meta/tests/FastInitTest.cc

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "comm.h"
1212
#include "comms/testinfra/TestUtils.h"
1313
#include "comms/testinfra/TestsDistUtils.h"
14+
#include "comms/testinfra/tests_common.cuh"
1415
#include "nccl.h"
1516

1617
void printCommStateX(const ncclComm& comm) {
@@ -22,12 +23,31 @@ void printCommStateX(const ncclComm& comm) {
2223
VLOG(1) << "=================";
2324
}
2425

26+
void validateCtranInitialization(
27+
ncclComm_t comm,
28+
int expectedRank,
29+
int expectedNRanks,
30+
int expectedCudaDev) {
31+
EXPECT_EQ(comm->rank, expectedRank);
32+
EXPECT_EQ(comm->nRanks, expectedNRanks);
33+
EXPECT_EQ(comm->cudaDev, expectedCudaDev);
34+
35+
ASSERT_NE(nullptr, comm->ctranComm_);
36+
ASSERT_NE(nullptr, comm->ctranComm_->statex_);
37+
ASSERT_NE(nullptr, comm->ctranComm_->bootstrap_);
38+
ASSERT_NE(nullptr, comm->ctranComm_->collTrace_);
39+
ASSERT_NE(nullptr, comm->ctranComm_->ctran_);
40+
EXPECT_TRUE(ctranInitialized(comm->ctranComm_.get()));
41+
EXPECT_EQ(comm->commHash, comm->ctranComm_->statex_->commHash());
42+
}
43+
2544
TEST_P(NcclxBaseTestFixture, NcclCommInitWorldAndDestroy) {
26-
ncclComm_t rootComm;
45+
ncclComm_t rootComm = nullptr;
2746
ncclUniqueId commId;
2847
NCCLCHECK_TEST(
2948
ncclCommInitRankConfig(&rootComm, numRanks, commId, globalRank, nullptr));
3049
ASSERT_NE(nullptr, rootComm);
50+
validateCtranInitialization(rootComm, globalRank, numRanks, localRank);
3151

3252
const auto statex = rootComm->ctranComm_->statex_.get();
3353
if (statex->nNodes() == 1) {
@@ -45,11 +65,12 @@ TEST_P(NcclxBaseTestFixture, NcclCommInitWorldAndDestroy) {
4565
}
4666

4767
TEST_P(NcclxBaseTestFixture, NcclCommInitWorldAndAbort) {
48-
ncclComm_t rootComm;
68+
ncclComm_t rootComm = nullptr;
4969
ncclUniqueId commId;
5070
NCCLCHECK_TEST(
5171
ncclCommInitRankConfig(&rootComm, numRanks, commId, globalRank, nullptr));
5272
ASSERT_NE(nullptr, rootComm);
73+
validateCtranInitialization(rootComm, globalRank, numRanks, localRank);
5374

5475
void* sendBuf;
5576
void* recvBuf;
@@ -87,13 +108,14 @@ void compareComm(const ncclComm& comm1, const ncclComm& comm2) {
87108
}
88109

89110
TEST_P(NcclxBaseTestFixture, NcclCommSplit) {
90-
ncclComm_t rootComm;
111+
ncclComm_t rootComm = nullptr;
91112
ncclUniqueId commId;
92113
NCCLCHECK_TEST(
93114
ncclCommInitRankConfig(&rootComm, numRanks, commId, globalRank, nullptr));
94115
ASSERT_NE(nullptr, rootComm);
116+
validateCtranInitialization(rootComm, globalRank, numRanks, localRank);
95117

96-
ncclComm_t childComm;
118+
ncclComm_t childComm = nullptr;
97119
ncclConfig_t childCommConfig = NCCL_CONFIG_INITIALIZER;
98120
childCommConfig.commDesc = "child_communicator";
99121
int groupSize = rootComm->ctranComm_->statex_.get()->nRanks() / 2;
@@ -143,11 +165,12 @@ TEST_P(NcclxBaseTestFixture, NcclCommSplit) {
143165
// we can split the same group rank multiple times
144166
// we should expect unique hash for each communicator
145167
TEST_P(NcclxBaseTestFixture, NcclCommSplitDuplicateGroups) {
146-
ncclComm_t rootComm;
168+
ncclComm_t rootComm = nullptr;
147169
ncclUniqueId commId;
148170
NCCLCHECK_TEST(
149171
ncclCommInitRankConfig(&rootComm, numRanks, commId, globalRank, nullptr));
150172
ASSERT_NE(nullptr, rootComm);
173+
validateCtranInitialization(rootComm, globalRank, numRanks, localRank);
151174
const auto statex = rootComm->ctranComm_->statex_.get();
152175

153176
if (statex->nNodes() == 1) {
@@ -166,13 +189,13 @@ TEST_P(NcclxBaseTestFixture, NcclCommSplitDuplicateGroups) {
166189
childCommConfig.splitGroupRanks = groupRanks;
167190
childCommConfig.splitGroupSize = groupSize;
168191

169-
ncclComm_t childComm1;
192+
ncclComm_t childComm1 = nullptr;
170193
NCCLCHECK_TEST(ncclCommSplit(
171194
rootComm, globalRank % 2, globalRank / 2, &childComm1, &childCommConfig));
172195
ASSERT_NE(nullptr, childComm1);
173196

174197
// split again with same config
175-
ncclComm_t childComm2;
198+
ncclComm_t childComm2 = nullptr;
176199
NCCLCHECK_TEST(ncclCommSplit(
177200
rootComm, globalRank % 2, globalRank / 2, &childComm2, &childCommConfig));
178201
ASSERT_NE(nullptr, childComm2);
@@ -189,11 +212,12 @@ TEST_P(NcclxBaseTestFixture, NcclCommSplitDuplicateGroups) {
189212
}
190213

191214
TEST_P(NcclxBaseTestFixture, WorldCommAllGather) {
192-
ncclComm_t rootComm;
215+
ncclComm_t rootComm = nullptr;
193216
ncclUniqueId commId;
194217
NCCLCHECK_TEST(
195218
ncclCommInitRankConfig(&rootComm, numRanks, commId, globalRank, nullptr));
196219
ASSERT_NE(nullptr, rootComm);
220+
validateCtranInitialization(rootComm, globalRank, numRanks, localRank);
197221
const auto statex = rootComm->ctranComm_->statex_.get();
198222

199223
if (statex->nNodes() == 1) {
@@ -238,18 +262,19 @@ TEST_P(NcclxBaseTestFixture, WorldCommAllGather) {
238262
}
239263

240264
TEST_P(NcclxBaseTestFixture, ChildCommAllGather) {
241-
ncclComm_t rootComm;
265+
ncclComm_t rootComm = nullptr;
242266
ncclUniqueId commId;
243267
NCCLCHECK_TEST(
244268
ncclCommInitRankConfig(&rootComm, numRanks, commId, globalRank, nullptr));
245269
ASSERT_NE(nullptr, rootComm);
270+
validateCtranInitialization(rootComm, globalRank, numRanks, localRank);
246271
const auto statex = rootComm->ctranComm_->statex_.get();
247272
if (statex->nNodes() == 1) {
248273
NCCLCHECK_TEST(ncclCommDestroy(rootComm));
249274
GTEST_SKIP() << "Skip test since only one node provided";
250275
}
251276

252-
ncclComm_t childComm;
277+
ncclComm_t childComm = nullptr;
253278
ncclConfig_t childCommConfig = NCCL_CONFIG_INITIALIZER;
254279
childCommConfig.commDesc = "child_communicator";
255280
int groupSize = rootComm->ctranComm_->statex_.get()->nRanks() / 2;
@@ -303,7 +328,7 @@ TEST_P(NcclxBaseTestFixture, ChildCommAllGather) {
303328
}
304329

305330
TEST_P(NcclxBaseTestFixture, NcclCommSplitNoColor) {
306-
ncclComm_t rootComm;
331+
ncclComm_t rootComm = nullptr;
307332
ncclComm_t childComm = NCCL_COMM_NULL;
308333
ncclUniqueId commId;
309334
ncclConfig_t rootConfig = NCCL_CONFIG_INITIALIZER;
@@ -314,6 +339,7 @@ TEST_P(NcclxBaseTestFixture, NcclCommSplitNoColor) {
314339
NCCLCHECK_TEST(ncclCommInitRankConfig(
315340
&rootComm, numRanks, commId, globalRank, &rootConfig));
316341
ASSERT_NE(nullptr, rootComm);
342+
validateCtranInitialization(rootComm, globalRank, numRanks, localRank);
317343

318344
const auto statex = rootComm->ctranComm_->statex_.get();
319345
EXPECT_NE(statex, nullptr);
@@ -360,7 +386,7 @@ TEST_P(NcclxBaseTestFixture, NcclCommSplitNoColor) {
360386
}
361387

362388
TEST_P(NcclxBaseTestFixture, NcclCommInitWithDifferentCommDesc) {
363-
ncclComm_t comm1, comm2;
389+
ncclComm_t comm1 = nullptr, comm2 = nullptr;
364390
ncclUniqueId commId1, commId2;
365391

366392
// Create first comm with commDesc "comm_desc_1"
@@ -369,13 +395,15 @@ TEST_P(NcclxBaseTestFixture, NcclCommInitWithDifferentCommDesc) {
369395
NCCLCHECK_TEST(
370396
ncclCommInitRankConfig(&comm1, numRanks, commId1, globalRank, &config1));
371397
ASSERT_NE(nullptr, comm1);
398+
validateCtranInitialization(comm1, globalRank, numRanks, localRank);
372399

373400
// Create second comm with commDesc "comm_desc_2"
374401
ncclConfig_t config2 = NCCL_CONFIG_INITIALIZER;
375402
config2.commDesc = "comm_desc_2";
376403
NCCLCHECK_TEST(
377404
ncclCommInitRankConfig(&comm2, numRanks, commId2, globalRank, &config2));
378405
ASSERT_NE(nullptr, comm2);
406+
validateCtranInitialization(comm2, globalRank, numRanks, localRank);
379407

380408
// Verify both comms are valid and have correct properties
381409
const auto statex1 = comm1->ctranComm_->statex_.get();
@@ -401,6 +429,8 @@ INSTANTIATE_TEST_SUITE_P(
401429
NcclxBaseTestFixture,
402430
testing::Values(NcclxEnvs({
403431
{"NCCL_FASTINIT_MODE", "ring_hybrid"},
432+
{"NCCL_CTRAN_ENABLE", "1"},
433+
{"NCCL_COLLTRACE", "trace"},
404434
})),
405435
[](const testing::TestParamInfo<NcclxBaseTestFixture::ParamType>& info) {
406436
// generate test-name for a given NcclxEnvs

0 commit comments

Comments
 (0)