Skip to content

Commit 39e7ac8

Browse files
YulunW authored and meta-codesync[bot] committed
Switch to use new colltrace in Ctran Dist tests
Summary: The ctran distributed tests implicitly used the old colltrace, which caused them to fail when switching to the new colltrace in D84476753. Change the tests to use the new colltrace explicitly first.
Reviewed By: tanquer
Differential Revision: D85124894
fbshipit-source-id: 09900bb4606da4694187ab9e88d70ae91253e7c2
1 parent 77669c3 commit 39e7ac8

13 files changed

+345
-155
lines changed

comms/ctran/tests/CtranDistAllReduceTest.cc

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Meta Platforms, Inc. and affiliates.
22

33
#include <folly/init/Init.h>
4+
#include <folly/json/json.h>
45
#include <gmock/gmock.h>
56
#include <gtest/gtest.h>
67
#include <stdlib.h>
@@ -12,10 +13,11 @@
1213

1314
#include "comms/ctran/Ctran.h"
1415
#include "comms/ctran/algos/AllReduce/AllReduceImpl.h"
16+
#include "comms/ctran/tracing/CollTraceWrapper.h"
1517
#include "comms/testinfra/TestUtils.h"
1618
#include "comms/testinfra/TestsDistUtils.h"
1719
#include "comms/utils/cvars/nccl_cvars.h"
18-
#include "meta/colltrace/CollTrace.h"
20+
#include "meta/commDump.h"
1921

2022
// Reduce the value range to avoid integer overflow when running large count
2123
constexpr size_t VAL_RANGE = 1024;
@@ -37,6 +39,10 @@ class CtranAllReduceTest : public CtranDistBaseTest {
3739
ncclComm_t comm;
3840

3941
void SetUp() override {
42+
setenv("NCCL_COLLTRACE", "trace", 0);
43+
setenv("NCCL_COLLTRACE_USE_NEW_COLLTRACE", "1", 0);
44+
// -1 for not limiting the number of colls to trace
45+
setenv("NCCL_COLLTRACE_RECORD_MAX", "-1", 0);
4046
#ifdef CTRAN_TEST_SOCKET_ONLY_BACKEND
4147
setenv("NCCL_CTRAN_BACKENDS", "socket, nvl", 1);
4248
#endif
@@ -165,14 +171,16 @@ class CtranAllReduceTest : public CtranDistBaseTest {
165171
}
166172

167173
memorySetUp(count, inplace, op, memType);
168-
comm->ctranComm_->collTrace_->resetPastColls();
169174

170175
for (auto& segment : segments) {
171176
void* hdl = nullptr;
172177
NCCLCHECK_TEST(ncclCommRegister(comm, segment.ptr, segment.size, &hdl));
173178
segHandles.push_back(hdl);
174179
}
175180

181+
ASSERT_TRUE(meta::comms::colltrace::testOnlyClearCollTraceRecords(
182+
comm->ctranComm_.get()));
183+
176184
if (inplace == kTestInPlace) {
177185
auto res = allreduceFunc(
178186
recvbuf,
@@ -201,17 +209,26 @@ class CtranAllReduceTest : public CtranDistBaseTest {
201209

202210
verifyResult(count, op);
203211

204-
// CollTrace is updated by a separate thread, need wait for it to finish to
205-
// avoid flaky test
206-
comm->ctranComm_->collTrace_->waitForWorkerFinishQueue();
207-
auto dump = comm->ctranComm_->collTrace_->dump();
208-
EXPECT_EQ(dump.pastColls.size(), 1);
209-
210-
auto lastColl = dump.pastColls.back();
211-
EXPECT_EQ(lastColl.opName, "AllReduce");
212-
EXPECT_EQ(lastColl.count, count);
213-
EXPECT_EQ(lastColl.dataType, dt);
214-
EXPECT_EQ(lastColl.algoName, allReduceAlgoName(algo));
212+
CUDACHECK_TEST(cudaDeviceSynchronize());
213+
// Sleep for a while to make sure all the colls are finished
214+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
215+
216+
ASSERT_TRUE(comm->newCollTrace != nullptr);
217+
auto dumpMap = meta::comms::ncclx::dumpNewCollTrace(*comm->newCollTrace);
218+
219+
EXPECT_NE(dumpMap["CT_pastColls"], "[]");
220+
EXPECT_EQ(dumpMap["CT_pendingColls"], "[]");
221+
EXPECT_EQ(dumpMap["CT_currentColl"], "null");
222+
223+
auto pastCollsJson = folly::parseJson(dumpMap["CT_pastColls"]);
224+
EXPECT_EQ(pastCollsJson.size(), 1);
225+
226+
auto lastColl = pastCollsJson[0];
227+
EXPECT_EQ(lastColl["opName"].asString(), "AllReduce");
228+
EXPECT_EQ(lastColl["count"].asInt(), count);
229+
EXPECT_THAT(
230+
lastColl["algoName"].asString(),
231+
testing::HasSubstr(allReduceAlgoName(algo)));
215232

216233
verifyBackendsUsed(
217234
comm->ctranComm_->ctran_.get(),

comms/ctran/tests/CtranDistAllgatherTests.cc

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@
1111
#include "comms/ctran/Ctran.h"
1212
#include "comms/ctran/algos/AllGather/AllGatherImpl.h"
1313
#include "comms/ctran/algos/AllReduce/AllReduceImpl.h"
14+
#include "comms/ctran/tracing/CollTraceWrapper.h"
1415
#include "comms/testinfra/TestUtils.h"
1516
#include "comms/testinfra/TestsCuUtils.h"
1617
#include "comms/testinfra/TestsDistUtils.h"
1718
#include "comms/utils/cvars/nccl_cvars.h"
18-
#include "meta/colltrace/CollTrace.h"
19+
#include "meta/commDump.h"
20+
21+
#include <folly/json/json.h>
1922

2023
class CtranAllgatherTest : public CtranDistBaseTest {
2124
public:
@@ -32,6 +35,10 @@ class CtranAllgatherTest : public CtranDistBaseTest {
3235
ncclComm_t comm;
3336

3437
void SetUp() override {
38+
setenv("NCCL_COLLTRACE", "trace", 0);
39+
setenv("NCCL_COLLTRACE_USE_NEW_COLLTRACE", "1", 0);
40+
// -1 for not limiting the number of colls to trace
41+
setenv("NCCL_COLLTRACE_RECORD_MAX", "-1", 0);
3542
CtranDistBaseTest::SetUp();
3643
comm = commWorld;
3744
segments.clear();
@@ -125,9 +132,6 @@ TEST_P(CtranAllgatherTestParam, AllgatherAlgo) {
125132

126133
// CollTrace will help check whether the specified algo is used
127134
EnvRAII env(NCCL_ALLGATHER_ALGO, algo);
128-
// Ensure CollTrace won't drop any record
129-
EnvRAII envCollTrace(
130-
NCCL_COLLTRACE_RECORD_MAX, iter * (1 + (pairColl != kTestPairNone)));
131135

132136
if (memType == kCuMemAllocDisjoint &&
133137
(!comm->dmaBufSupport || !NCCL_CTRAN_IB_DMABUF_ENABLE)) {
@@ -163,7 +167,8 @@ TEST_P(CtranAllgatherTestParam, AllgatherAlgo) {
163167
std::vector<std::string> expOpNames;
164168
std::vector<std::string> expAlgoNames;
165169

166-
comm->ctranComm_->collTrace_->resetPastColls();
170+
ASSERT_TRUE(meta::comms::colltrace::testOnlyClearCollTraceRecords(
171+
comm->ctranComm_.get()));
167172

168173
for (int x = 0; x < iter; x++) {
169174
expOpNames.push_back("AllGather");
@@ -209,18 +214,25 @@ TEST_P(CtranAllgatherTestParam, AllgatherAlgo) {
209214
{CtranMapperBackend::NVL});
210215
verifyGpeLeak(comm->ctranComm_->ctran_.get());
211216

212-
// CollTrace is updated by a separate thread, need wait for it to finish to
213-
// avoid flaky test
214-
comm->ctranComm_->collTrace_->waitForWorkerFinishQueue();
217+
CUDACHECK_TEST(cudaDeviceSynchronize());
218+
// Sleep for a while to make sure all the colls are finished
219+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
220+
221+
ASSERT_TRUE(comm->newCollTrace != nullptr);
222+
auto dumpMap = meta::comms::ncclx::dumpNewCollTrace(*comm->newCollTrace);
215223

216-
auto dump = comm->ctranComm_->collTrace_->dump();
217-
EXPECT_EQ(dump.pastColls.size(), expOpNames.size());
224+
EXPECT_NE(dumpMap["CT_pastColls"], "[]");
225+
EXPECT_EQ(dumpMap["CT_pendingColls"], "[]");
226+
EXPECT_EQ(dumpMap["CT_currentColl"], "null");
227+
228+
auto pastCollsJson = folly::parseJson(dumpMap["CT_pastColls"]);
229+
EXPECT_EQ(pastCollsJson.size(), expOpNames.size());
218230
int idx = 0;
219-
for (auto& coll : dump.pastColls) {
220-
EXPECT_EQ(coll.opName, expOpNames.at(idx));
221-
EXPECT_EQ(coll.count, count);
222-
EXPECT_EQ(coll.dataType, dt);
223-
EXPECT_EQ(coll.algoName, expAlgoNames.at(idx));
231+
for (const auto& coll : pastCollsJson) {
232+
EXPECT_EQ(coll["opName"].asString(), expOpNames.at(idx));
233+
EXPECT_EQ(coll["count"].asInt(), count);
234+
EXPECT_THAT(
235+
coll["algoName"].asString(), testing::HasSubstr(expAlgoNames.at(idx)));
224236
idx++;
225237
}
226238

@@ -384,8 +396,6 @@ TEST_P(CtranSocketAllgatherTestParam, AllgatherAlgo) {
384396
// CollTrace will help check whether the specified algo is used
385397
EnvRAII env(NCCL_ALLGATHER_ALGO, algo);
386398
// Ensure CollTrace won't drop any record
387-
EnvRAII envCollTrace(
388-
NCCL_COLLTRACE_RECORD_MAX, iter * (1 + (pairColl != kTestPairNone)));
389399

390400
memorySetUp(kMemNcclMemAlloc, offset, count, inplace, pairColl);
391401

@@ -399,7 +409,8 @@ TEST_P(CtranSocketAllgatherTestParam, AllgatherAlgo) {
399409
std::vector<std::string> expOpNames;
400410
std::vector<std::string> expAlgoNames;
401411

402-
comm->ctranComm_->collTrace_->resetPastColls();
412+
ASSERT_TRUE(meta::comms::colltrace::testOnlyClearCollTraceRecords(
413+
comm->ctranComm_.get()));
403414

404415
for (int x = 0; x < iter; x++) {
405416
expOpNames.emplace_back("AllGather");
@@ -445,18 +456,25 @@ TEST_P(CtranSocketAllgatherTestParam, AllgatherAlgo) {
445456
{CtranMapperBackend::NVL});
446457
verifyGpeLeak(comm->ctranComm_->ctran_.get());
447458

448-
// CollTrace is updated by a separate thread, need wait for it to finish to
449-
// avoid flaky test
450-
comm->ctranComm_->collTrace_->waitForWorkerFinishQueue();
459+
CUDACHECK_TEST(cudaDeviceSynchronize());
460+
// Sleep for a while to make sure all the colls are finished
461+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
462+
463+
ASSERT_TRUE(comm->newCollTrace != nullptr);
464+
auto dumpMap = meta::comms::ncclx::dumpNewCollTrace(*comm->newCollTrace);
465+
466+
EXPECT_NE(dumpMap["CT_pastColls"], "[]");
467+
EXPECT_EQ(dumpMap["CT_pendingColls"], "[]");
468+
EXPECT_EQ(dumpMap["CT_currentColl"], "null");
451469

452-
auto dump = comm->ctranComm_->collTrace_->dump();
453-
EXPECT_EQ(dump.pastColls.size(), expOpNames.size());
470+
auto pastCollsJson = folly::parseJson(dumpMap["CT_pastColls"]);
471+
EXPECT_EQ(pastCollsJson.size(), expOpNames.size());
454472
int idx = 0;
455-
for (auto& coll : dump.pastColls) {
456-
EXPECT_EQ(coll.opName, expOpNames.at(idx));
457-
EXPECT_EQ(coll.count, count);
458-
EXPECT_EQ(coll.dataType, dt);
459-
EXPECT_EQ(coll.algoName, expAlgoNames.at(idx));
473+
for (const auto& coll : pastCollsJson) {
474+
EXPECT_EQ(coll["opName"].asString(), expOpNames.at(idx));
475+
EXPECT_EQ(coll["count"].asInt(), count);
476+
EXPECT_THAT(
477+
coll["algoName"].asString(), testing::HasSubstr(expAlgoNames.at(idx)));
460478
idx++;
461479
}
462480

comms/ctran/tests/CtranDistAlltoAllDedupTest.cc

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Meta Platforms, Inc. and affiliates.
22

33
#include <folly/init/Init.h>
4+
#include <folly/json/json.h>
45
#include <gmock/gmock.h>
56
#include <gtest/gtest.h>
67
#include <nccl.h>
@@ -10,9 +11,10 @@
1011
#include "CtranUtUtils.h"
1112
#include "comms/ctran/Ctran.h"
1213
#include "comms/ctran/algos/AllToAll/AllToAllDedupImpl.h"
14+
#include "comms/ctran/tracing/CollTraceWrapper.h"
1315
#include "comms/testinfra/TestUtils.h"
1416
#include "comms/testinfra/TestsDistUtils.h"
15-
#include "meta/colltrace/CollTrace.h"
17+
#include "meta/commDump.h"
1618

1719
class ctranAllToAllDedupTest : public CtranDistBaseTest {
1820
public:
@@ -133,6 +135,10 @@ class ctranAllToAllDedupTest : public CtranDistBaseTest {
133135
}
134136

135137
void SetUp() override {
138+
setenv("NCCL_COLLTRACE", "trace", 0);
139+
setenv("NCCL_COLLTRACE_USE_NEW_COLLTRACE", "1", 0);
140+
// -1 for not limiting the number of colls to trace
141+
setenv("NCCL_COLLTRACE_RECORD_MAX", "-1", 0);
136142
CtranDistBaseTest::SetUp();
137143
comm = commWorld;
138144
if (!ctranAllToAllDedupSupport(comm->ctranComm_.get())) {
@@ -164,7 +170,8 @@ class ctranAllToAllDedupTest : public CtranDistBaseTest {
164170
int numSplits = nLocalRanks;
165171
int nNodes = comm->ctranComm_->statex_->nNodes();
166172

167-
comm->ctranComm_->collTrace_->resetPastColls();
173+
ASSERT_TRUE(meta::comms::colltrace::testOnlyClearCollTraceRecords(
174+
comm->ctranComm_.get()));
168175

169176
CtranPersistentRequest* request = nullptr;
170177
void* recvBuf = nullptr;
@@ -226,21 +233,27 @@ class ctranAllToAllDedupTest : public CtranDistBaseTest {
226233
auto destroyRes = ctranAllToAllDedupDestroy(request);
227234
ASSERT_EQ(destroyRes, commSuccess);
228235

229-
// CollTrace is updated by a separate thread, need wait for it to finish to
230-
// avoid flaky test
231-
comm->ctranComm_->collTrace_->waitForWorkerFinishQueue();
236+
CUDACHECK_TEST(cudaDeviceSynchronize());
237+
// Sleep for a while to make sure all the colls are finished
238+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
232239

233-
auto dump = comm->ctranComm_->collTrace_->dump();
240+
ASSERT_TRUE(comm->newCollTrace != nullptr);
241+
auto dumpMap = meta::comms::ncclx::dumpNewCollTrace(*comm->newCollTrace);
242+
243+
EXPECT_NE(dumpMap["CT_pastColls"], "[]");
244+
EXPECT_EQ(dumpMap["CT_pendingColls"], "[]");
245+
EXPECT_EQ(dumpMap["CT_currentColl"], "null");
246+
247+
auto pastCollsJson = folly::parseJson(dumpMap["CT_pastColls"]);
234248
// make sure we are collecting enough records
235249
EXPECT_GE(NCCL_COLLTRACE_RECORD_MAX, numTimesRunExec);
236-
EXPECT_EQ(dump.pastColls.size(), numTimesRunExec);
237-
238-
for (auto& coll : dump.pastColls) {
239-
EXPECT_EQ(coll.opName, "AllToAllDedup");
240-
// Count should be nullOpt for AllToAllV at the moment
241-
EXPECT_THAT(coll.count, ::testing::Eq(std::nullopt));
242-
EXPECT_EQ(coll.dataType, commInt);
243-
EXPECT_EQ(coll.algoName, "ctranAllToAllDedup");
250+
EXPECT_EQ(pastCollsJson.size(), numTimesRunExec);
251+
252+
for (const auto& coll : pastCollsJson) {
253+
EXPECT_EQ(coll["opName"].asString(), "AllToAll_Dedup");
254+
EXPECT_THAT(
255+
coll["algoName"].asString(),
256+
testing::HasSubstr("ctranAllToAllDedup"));
244257
}
245258

246259
size_t sendCount =

comms/ctran/tests/CtranDistAlltoAllPTest.cc

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111
#include "comms/ctran/Ctran.h"
1212
#include "comms/ctran/algos/AllToAll/AllToAllPImpl.h"
1313
#include "comms/ctran/algos/AllToAll/AllToAllvImpl.h"
14+
#include "comms/ctran/tracing/CollTraceWrapper.h"
1415
#include "comms/testinfra/TestUtils.h"
1516
#include "comms/testinfra/TestsDistUtils.h"
1617
#include "comms/utils/cvars/nccl_cvars.h"
17-
#include "meta/colltrace/CollTrace.h"
18+
#include "meta/commDump.h"
19+
20+
#include <folly/json/json.h>
1821

1922
class ctranAllToAllPTest : public CtranDistBaseTest {
2023
public:
@@ -87,6 +90,10 @@ class ctranAllToAllPTest : public CtranDistBaseTest {
8790
}
8891

8992
void SetUp() override {
93+
setenv("NCCL_COLLTRACE", "trace", 0);
94+
setenv("NCCL_COLLTRACE_USE_NEW_COLLTRACE", "1", 0);
95+
// -1 for not limiting the number of colls to trace
96+
setenv("NCCL_COLLTRACE_RECORD_MAX", "-1", 0);
9097
CtranDistBaseTest::SetUp();
9198
comm = commWorld;
9299
if (!ctran::AllToAllPSupport(comm->ctranComm_.get())) {
@@ -115,7 +122,8 @@ class ctranAllToAllPTest : public CtranDistBaseTest {
115122
}
116123

117124
void run() {
118-
comm->ctranComm_->collTrace_->resetPastColls();
125+
ASSERT_TRUE(meta::comms::colltrace::testOnlyClearCollTraceRecords(
126+
comm->ctranComm_.get()));
119127

120128
// Initialize double persistent requests using double recv buffer allocated.
121129
std::array<CtranPersistentRequest*, 2> doublePRequests;
@@ -165,31 +173,37 @@ class ctranAllToAllPTest : public CtranDistBaseTest {
165173
ASSERT_EQ(destroyRes, commSuccess);
166174
}
167175

168-
// CollTrace is updated by a separate thread, need wait for it to finish to
169-
// avoid flaky test
170-
comm->ctranComm_->collTrace_->waitForWorkerFinishQueue();
176+
CUDACHECK_TEST(cudaDeviceSynchronize());
177+
// Sleep for a while to make sure all the colls are finished
178+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
179+
180+
ASSERT_TRUE(comm->newCollTrace != nullptr);
181+
auto dumpMap = meta::comms::ncclx::dumpNewCollTrace(*comm->newCollTrace);
182+
183+
EXPECT_NE(dumpMap["CT_pastColls"], "[]");
184+
EXPECT_EQ(dumpMap["CT_pendingColls"], "[]");
185+
EXPECT_EQ(dumpMap["CT_currentColl"], "null");
171186

172-
auto dump = comm->ctranComm_->collTrace_->dump();
187+
auto pastCollsJson = folly::parseJson(dumpMap["CT_pastColls"]);
173188
auto statex = comm->ctranComm_->statex_.get();
174189
// If there are remote peers, AllToAllPInit submits gpe op and was recorded
175190
// by colltrace.
176191
int numTimesRunInit = statex->nNodes() == 1 ? 0 : 2;
177-
EXPECT_EQ(dump.pastColls.size(), numTimesRunInit + numTimesRunExec);
192+
EXPECT_EQ(pastCollsJson.size(), numTimesRunInit + numTimesRunExec);
178193

179194
// Skip the check for the AllToAllPInit (first 2) colls.
180-
for (int i = numTimesRunInit; i < dump.pastColls.size(); i++) {
181-
auto& coll = dump.pastColls[i];
195+
for (int i = numTimesRunInit; i < pastCollsJson.size(); i++) {
196+
const auto& coll = pastCollsJson[i];
182197
if (statex->nNodes() == 1) {
183198
// If only cuda kernel is launched (no IB put), AlltoAllP is essentially
184199
// alltoall because it shares cuda kernel logic with AlltoAll.
185-
EXPECT_EQ(coll.opName, "AllToAll");
200+
EXPECT_EQ(coll["opName"].asString(), "AllToAll");
186201
} else {
187-
EXPECT_EQ(coll.opName, "AllToAllP");
202+
EXPECT_EQ(coll["opName"].asString(), "AllToAllP");
188203
}
189-
EXPECT_THAT(coll.count, ::testing::Eq(counts[i - numTimesRunInit]));
190-
EXPECT_EQ(coll.dataType, commInt);
204+
EXPECT_EQ(coll["count"].asInt(), counts[i - numTimesRunInit]);
191205
EXPECT_EQ(
192-
coll.algoName,
206+
coll["algoName"].asString(),
193207
ctran::alltoallp::AlgoImpl::algoName(NCCL_ALLTOALL_ALGO::ctran));
194208
}
195209

0 commit comments

Comments
 (0)