Skip to content

Commit 1c24ac7

Browse files
pavanbalajimeta-codesync[bot]
authored andcommitted
Fix to failing invalid device test
Summary: The InitializationFailsWithInvalidDeviceId test was failing because of an error in the cuda mock. Reviewed By: MittalMakwana Differential Revision: D85512707 fbshipit-source-id: 48123587312f8d9838e948844b0f74ab26c2c56b
1 parent dcfca09 commit 1c24ac7

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

comms/torchcomms/ncclx/tests/unit/cpp/TorchCommNCCLXTest.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,23 @@ TEST_F(TorchCommNCCLXTest, InitializationFailsWithInvalidDeviceId) {
141141
comm->setCudaApi(cuda_mock_);
142142
comm->setNcclApi(nccl_mock_);
143143

144+
// Mock getDeviceCount to return a valid device count (needed for rank %
145+
// device_count)
146+
EXPECT_CALL(*cuda_mock_, getDeviceCount(_))
147+
.WillOnce(DoAll(SetArgPointee<0>(1), Return(cudaSuccess)));
148+
149+
// Mock malloc for barrier buffer allocation in bootstrap constructor
150+
EXPECT_CALL(*cuda_mock_, malloc(_, sizeof(float)))
151+
.Times(2)
152+
.WillRepeatedly(DoAll(
153+
SetArgPointee<0>(reinterpret_cast<void*>(0x1000)),
154+
Return(cudaSuccess)));
155+
156+
// Mock free for barrier buffer deallocation in bootstrap destructor
157+
EXPECT_CALL(*cuda_mock_, free(reinterpret_cast<void*>(0x1000)))
158+
.Times(2)
159+
.WillRepeatedly(Return(cudaSuccess));
160+
144161
// Mock CUDA API to be called with device ID 0, since the boostrap
145162
// logic will assign a device ID in this case based on the rank.
146163
EXPECT_CALL(*cuda_mock_, setDevice(0))

comms/torchcomms/ncclx/tests/unit/cpp/mocks/CudaMock.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ void CudaMock::setupDefaultBehaviors() {
2626
SetArgPointee<1>(std::numeric_limits<int>::max()),
2727
Return(cudaSuccess)));
2828

29+
ON_CALL(*this, getDeviceCount(_))
30+
.WillByDefault(DoAll(SetArgPointee<0>(1), Return(cudaSuccess)));
31+
2932
// Stream management - return success by default
3033
ON_CALL(*this, getStreamPriorityRange(_, _))
3134
.WillByDefault(DoAll(

0 commit comments

Comments
 (0)