1111#include " comm.h"
1212#include " comms/testinfra/TestUtils.h"
1313#include " comms/testinfra/TestsDistUtils.h"
14+ #include " comms/testinfra/tests_common.cuh"
1415#include " nccl.h"
1516
1617void printCommStateX (const ncclComm& comm) {
@@ -22,12 +23,31 @@ void printCommStateX(const ncclComm& comm) {
2223 VLOG (1 ) << " =================" ;
2324}
2425
26+ void validateCtranInitialization (
27+ ncclComm_t comm,
28+ int expectedRank,
29+ int expectedNRanks,
30+ int expectedCudaDev) {
31+ EXPECT_EQ (comm->rank , expectedRank);
32+ EXPECT_EQ (comm->nRanks , expectedNRanks);
33+ EXPECT_EQ (comm->cudaDev , expectedCudaDev);
34+
35+ ASSERT_NE (nullptr , comm->ctranComm_ );
36+ ASSERT_NE (nullptr , comm->ctranComm_ ->statex_ );
37+ ASSERT_NE (nullptr , comm->ctranComm_ ->bootstrap_ );
38+ ASSERT_NE (nullptr , comm->ctranComm_ ->collTrace_ );
39+ ASSERT_NE (nullptr , comm->ctranComm_ ->ctran_ );
40+ EXPECT_TRUE (ctranInitialized (comm->ctranComm_ .get ()));
41+ EXPECT_EQ (comm->commHash , comm->ctranComm_ ->statex_ ->commHash ());
42+ }
43+
2544TEST_P (NcclxBaseTestFixture, NcclCommInitWorldAndDestroy) {
26- ncclComm_t rootComm;
45+ ncclComm_t rootComm = nullptr ;
2746 ncclUniqueId commId;
2847 NCCLCHECK_TEST (
2948 ncclCommInitRankConfig (&rootComm, numRanks, commId, globalRank, nullptr ));
3049 ASSERT_NE (nullptr , rootComm);
50+ validateCtranInitialization (rootComm, globalRank, numRanks, localRank);
3151
3252 const auto statex = rootComm->ctranComm_ ->statex_ .get ();
3353 if (statex->nNodes () == 1 ) {
@@ -45,11 +65,12 @@ TEST_P(NcclxBaseTestFixture, NcclCommInitWorldAndDestroy) {
4565}
4666
4767TEST_P (NcclxBaseTestFixture, NcclCommInitWorldAndAbort) {
48- ncclComm_t rootComm;
68+ ncclComm_t rootComm = nullptr ;
4969 ncclUniqueId commId;
5070 NCCLCHECK_TEST (
5171 ncclCommInitRankConfig (&rootComm, numRanks, commId, globalRank, nullptr ));
5272 ASSERT_NE (nullptr , rootComm);
73+ validateCtranInitialization (rootComm, globalRank, numRanks, localRank);
5374
5475 void * sendBuf;
5576 void * recvBuf;
@@ -87,13 +108,14 @@ void compareComm(const ncclComm& comm1, const ncclComm& comm2) {
87108}
88109
89110TEST_P (NcclxBaseTestFixture, NcclCommSplit) {
90- ncclComm_t rootComm;
111+ ncclComm_t rootComm = nullptr ;
91112 ncclUniqueId commId;
92113 NCCLCHECK_TEST (
93114 ncclCommInitRankConfig (&rootComm, numRanks, commId, globalRank, nullptr ));
94115 ASSERT_NE (nullptr , rootComm);
116+ validateCtranInitialization (rootComm, globalRank, numRanks, localRank);
95117
96- ncclComm_t childComm;
118+ ncclComm_t childComm = nullptr ;
97119 ncclConfig_t childCommConfig = NCCL_CONFIG_INITIALIZER;
98120 childCommConfig.commDesc = " child_communicator" ;
99121 int groupSize = rootComm->ctranComm_ ->statex_ .get ()->nRanks () / 2 ;
@@ -143,11 +165,12 @@ TEST_P(NcclxBaseTestFixture, NcclCommSplit) {
143165// we can split the same group rank multiple times
144166// we should expect unique hash for each communicator
145167TEST_P (NcclxBaseTestFixture, NcclCommSplitDuplicateGroups) {
146- ncclComm_t rootComm;
168+ ncclComm_t rootComm = nullptr ;
147169 ncclUniqueId commId;
148170 NCCLCHECK_TEST (
149171 ncclCommInitRankConfig (&rootComm, numRanks, commId, globalRank, nullptr ));
150172 ASSERT_NE (nullptr , rootComm);
173+ validateCtranInitialization (rootComm, globalRank, numRanks, localRank);
151174 const auto statex = rootComm->ctranComm_ ->statex_ .get ();
152175
153176 if (statex->nNodes () == 1 ) {
@@ -166,13 +189,13 @@ TEST_P(NcclxBaseTestFixture, NcclCommSplitDuplicateGroups) {
166189 childCommConfig.splitGroupRanks = groupRanks;
167190 childCommConfig.splitGroupSize = groupSize;
168191
169- ncclComm_t childComm1;
192+ ncclComm_t childComm1 = nullptr ;
170193 NCCLCHECK_TEST (ncclCommSplit (
171194 rootComm, globalRank % 2 , globalRank / 2 , &childComm1, &childCommConfig));
172195 ASSERT_NE (nullptr , childComm1);
173196
174197 // split again with same config
175- ncclComm_t childComm2;
198+ ncclComm_t childComm2 = nullptr ;
176199 NCCLCHECK_TEST (ncclCommSplit (
177200 rootComm, globalRank % 2 , globalRank / 2 , &childComm2, &childCommConfig));
178201 ASSERT_NE (nullptr , childComm2);
@@ -189,11 +212,12 @@ TEST_P(NcclxBaseTestFixture, NcclCommSplitDuplicateGroups) {
189212}
190213
191214TEST_P (NcclxBaseTestFixture, WorldCommAllGather) {
192- ncclComm_t rootComm;
215+ ncclComm_t rootComm = nullptr ;
193216 ncclUniqueId commId;
194217 NCCLCHECK_TEST (
195218 ncclCommInitRankConfig (&rootComm, numRanks, commId, globalRank, nullptr ));
196219 ASSERT_NE (nullptr , rootComm);
220+ validateCtranInitialization (rootComm, globalRank, numRanks, localRank);
197221 const auto statex = rootComm->ctranComm_ ->statex_ .get ();
198222
199223 if (statex->nNodes () == 1 ) {
@@ -238,18 +262,19 @@ TEST_P(NcclxBaseTestFixture, WorldCommAllGather) {
238262}
239263
240264TEST_P (NcclxBaseTestFixture, ChildCommAllGather) {
241- ncclComm_t rootComm;
265+ ncclComm_t rootComm = nullptr ;
242266 ncclUniqueId commId;
243267 NCCLCHECK_TEST (
244268 ncclCommInitRankConfig (&rootComm, numRanks, commId, globalRank, nullptr ));
245269 ASSERT_NE (nullptr , rootComm);
270+ validateCtranInitialization (rootComm, globalRank, numRanks, localRank);
246271 const auto statex = rootComm->ctranComm_ ->statex_ .get ();
247272 if (statex->nNodes () == 1 ) {
248273 NCCLCHECK_TEST (ncclCommDestroy (rootComm));
249274 GTEST_SKIP () << " Skip test since only one node provided" ;
250275 }
251276
252- ncclComm_t childComm;
277+ ncclComm_t childComm = nullptr ;
253278 ncclConfig_t childCommConfig = NCCL_CONFIG_INITIALIZER;
254279 childCommConfig.commDesc = " child_communicator" ;
255280 int groupSize = rootComm->ctranComm_ ->statex_ .get ()->nRanks () / 2 ;
@@ -303,7 +328,7 @@ TEST_P(NcclxBaseTestFixture, ChildCommAllGather) {
303328}
304329
305330TEST_P (NcclxBaseTestFixture, NcclCommSplitNoColor) {
306- ncclComm_t rootComm;
331+ ncclComm_t rootComm = nullptr ;
307332 ncclComm_t childComm = NCCL_COMM_NULL;
308333 ncclUniqueId commId;
309334 ncclConfig_t rootConfig = NCCL_CONFIG_INITIALIZER;
@@ -314,6 +339,7 @@ TEST_P(NcclxBaseTestFixture, NcclCommSplitNoColor) {
314339 NCCLCHECK_TEST (ncclCommInitRankConfig (
315340 &rootComm, numRanks, commId, globalRank, &rootConfig));
316341 ASSERT_NE (nullptr , rootComm);
342+ validateCtranInitialization (rootComm, globalRank, numRanks, localRank);
317343
318344 const auto statex = rootComm->ctranComm_ ->statex_ .get ();
319345 EXPECT_NE (statex, nullptr );
@@ -360,7 +386,7 @@ TEST_P(NcclxBaseTestFixture, NcclCommSplitNoColor) {
360386}
361387
362388TEST_P (NcclxBaseTestFixture, NcclCommInitWithDifferentCommDesc) {
363- ncclComm_t comm1, comm2;
389+ ncclComm_t comm1 = nullptr , comm2 = nullptr ;
364390 ncclUniqueId commId1, commId2;
365391
366392 // Create first comm with commDesc "comm_desc_1"
@@ -369,13 +395,15 @@ TEST_P(NcclxBaseTestFixture, NcclCommInitWithDifferentCommDesc) {
369395 NCCLCHECK_TEST (
370396 ncclCommInitRankConfig (&comm1, numRanks, commId1, globalRank, &config1));
371397 ASSERT_NE (nullptr , comm1);
398+ validateCtranInitialization (comm1, globalRank, numRanks, localRank);
372399
373400 // Create second comm with commDesc "comm_desc_2"
374401 ncclConfig_t config2 = NCCL_CONFIG_INITIALIZER;
375402 config2.commDesc = " comm_desc_2" ;
376403 NCCLCHECK_TEST (
377404 ncclCommInitRankConfig (&comm2, numRanks, commId2, globalRank, &config2));
378405 ASSERT_NE (nullptr , comm2);
406+ validateCtranInitialization (comm2, globalRank, numRanks, localRank);
379407
380408 // Verify both comms are valid and have correct properties
381409 const auto statex1 = comm1->ctranComm_ ->statex_ .get ();
@@ -401,6 +429,8 @@ INSTANTIATE_TEST_SUITE_P(
401429 NcclxBaseTestFixture,
402430 testing::Values (NcclxEnvs({
403431 {" NCCL_FASTINIT_MODE" , " ring_hybrid" },
432+ {" NCCL_CTRAN_ENABLE" , " 1" },
433+ {" NCCL_COLLTRACE" , " trace" },
404434 })),
405435 [](const testing::TestParamInfo<NcclxBaseTestFixture::ParamType>& info) {
406436 // generate test-name for a given NcclxEnvs
0 commit comments