Skip to content

Commit 7ddf53b

Browse files
committed
nvidia: Add v8/v9 connect interfaces
We shouldn't rely on the the devComm type being the same across api versions, so properly plumb out the interface. Also add a note about why the code is in the nvidia file rather than the more expected api file. Signed-off-by: Brian Barrett <[email protected]>
1 parent c59f53d commit 7ddf53b

File tree

2 files changed

+48
-11
lines changed

2 files changed

+48
-11
lines changed

include/nccl_ofi_api.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ ncclResult_t nccl_net_ofi_devices(int *ndev);
1616
ncclResult_t nccl_net_ofi_get_properties(int dev, struct nccl_ofi_properties *ofi_properties);
1717
ncclResult_t nccl_net_ofi_listen(int dev, void *handle, void **listenComm);
1818
ncclResult_t nccl_net_ofi_listen_v4(int dev, void* handle, void** listenComm);
19+
// Nvidia introduced the ability to have part of the communication driven by a
20+
// cuda kernel, which requires a version-specific device pointer be passed
21+
// through the accept/connect APIs. Rather than list all those connect calls
22+
// here, we just declare them in the nvidia interface file to keep this list sane.
1923
ncclResult_t nccl_net_ofi_connect(int dev, void* handle, void** sendComm);
2024
ncclResult_t nccl_net_ofi_connect_v4(int dev, void* handle, void** sendComm);
2125
ncclResult_t nccl_net_ofi_accept(void *listenComm, void **recvComm);

src/nccl_ofi_interface_nvidia.cpp

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -220,17 +220,50 @@ static ncclResult_t ptrSupport_v2(int dev_id, int *supportedTypes)
220220
}
221221

222222

223-
static ncclResult_t connect_v7(int dev, void* handle, void** sendComm,
224-
ncclNetDeviceHandle_v7_t** sendDevComm)
223+
// Nvidia introduced the ability to have part of the communication driven by a
224+
// cuda kernel, which requires a version-specific device pointer be passed
225+
// through the accept/connect APIs. We don't support that interface, so we
226+
// never need to look at the third argument. Rather than pollute the api
227+
// interface, just declare these wrappers in the nvidia interface.
228+
static ncclResult_t nccl_net_ofi_connect_v7(int dev, void* handle, void** sendComm,
229+
ncclNetDeviceHandle_v7_t** sendDevComm)
225230
{
226231
return nccl_net_ofi_connect(dev, handle, sendComm);
227232
}
228233

229234

230-
static ncclResult_t accept_v7(void* listenComm, void** recvComm,
231-
ncclNetDeviceHandle_v7_t** recvDevComm)
235+
static ncclResult_t nccl_net_ofi_connect_v8(int dev, void* handle, void** sendComm,
236+
ncclNetDeviceHandle_v8_t** sendDevComm)
232237
{
233-
return nccl_net_ofi_accept(listenComm, recvComm);
238+
return nccl_net_ofi_connect(dev, handle, sendComm);
239+
}
240+
241+
242+
static ncclResult_t nccl_net_ofi_connect_v9(int dev, void* handle, void** sendComm,
243+
ncclNetDeviceHandle_v9_t** sendDevComm)
244+
{
245+
return nccl_net_ofi_connect(dev, handle, sendComm);
246+
}
247+
248+
249+
static ncclResult_t nccl_net_ofi_accept_v7(void* listenComm, void** recvComm,
250+
ncclNetDeviceHandle_v7_t** recvDevComm)
251+
{
252+
return nccl_net_ofi_accept(listenComm, recvComm);
253+
}
254+
255+
256+
static ncclResult_t nccl_net_ofi_accept_v8(void* listenComm, void** recvComm,
257+
ncclNetDeviceHandle_v8_t** recvDevComm)
258+
{
259+
return nccl_net_ofi_accept(listenComm, recvComm);
260+
}
261+
262+
263+
static ncclResult_t nccl_net_ofi_accept_v9(void* listenComm, void** recvComm,
264+
ncclNetDeviceHandle_v9_t** recvDevComm)
265+
{
266+
return nccl_net_ofi_accept(listenComm, recvComm);
234267
}
235268

236269

@@ -339,8 +372,8 @@ NCCL_OFI_EXPORT_SYMBOL ncclNet_v7_t ncclNetPlugin_v7 = {
339372
.devices = nccl_net_ofi_devices,
340373
.getProperties = getProperties_v7,
341374
.listen = nccl_net_ofi_listen,
342-
.connect = connect_v7,
343-
.accept = accept_v7,
375+
.connect = nccl_net_ofi_connect_v7,
376+
.accept = nccl_net_ofi_accept_v7,
344377
.regMr = nccl_net_ofi_regMr_v7,
345378
.regMrDmaBuf = nccl_net_ofi_regMrDmaBuf,
346379
.deregMr = nccl_net_ofi_deregMr,
@@ -361,8 +394,8 @@ NCCL_OFI_EXPORT_SYMBOL ncclNet_v8_t ncclNetPlugin_v8 = {
361394
.devices = nccl_net_ofi_devices,
362395
.getProperties = getProperties_v8,
363396
.listen = nccl_net_ofi_listen,
364-
.connect = connect_v7,
365-
.accept = accept_v7,
397+
.connect = nccl_net_ofi_connect_v8,
398+
.accept = nccl_net_ofi_accept_v8,
366399
.regMr = nccl_net_ofi_regMr,
367400
.regMrDmaBuf = nccl_net_ofi_regMrDmaBuf,
368401
.deregMr = nccl_net_ofi_deregMr,
@@ -383,8 +416,8 @@ NCCL_OFI_EXPORT_SYMBOL ncclNet_v9_t ncclNetPlugin_v9 = {
383416
.devices = nccl_net_ofi_devices,
384417
.getProperties = getProperties_v9,
385418
.listen = nccl_net_ofi_listen,
386-
.connect = connect_v7,
387-
.accept = accept_v7,
419+
.connect = nccl_net_ofi_connect_v9,
420+
.accept = nccl_net_ofi_accept_v9,
388421
.regMr = nccl_net_ofi_regMr,
389422
.regMrDmaBuf = nccl_net_ofi_regMrDmaBuf,
390423
.deregMr = nccl_net_ofi_deregMr,

0 commit comments

Comments
 (0)