@@ -220,17 +220,50 @@ static ncclResult_t ptrSupport_v2(int dev_id, int *supportedTypes)
220
220
}
221
221
222
222
223
- static ncclResult_t connect_v7 (int dev, void * handle, void ** sendComm,
224
- ncclNetDeviceHandle_v7_t** sendDevComm)
223
+ // NVIDIA introduced the ability to have part of the communication driven by a
224
+ // CUDA kernel, which requires a version-specific device pointer to be passed
225
+ // through the accept/connect APIs. We don't support that interface, so we
226
+ // never need to look at that device pointer argument. Rather than pollute the API
227
+ // interface, just declare these wrappers in the NVIDIA interface.
228
+ static ncclResult_t nccl_net_ofi_connect_v7 (int dev, void * handle, void ** sendComm,
229
+ ncclNetDeviceHandle_v7_t** sendDevComm)
225
230
{
226
231
return nccl_net_ofi_connect (dev, handle, sendComm);
227
232
}
228
233
229
234
230
- static ncclResult_t accept_v7 ( void * listenComm , void ** recvComm ,
231
- ncclNetDeviceHandle_v7_t ** recvDevComm )
235
+ static ncclResult_t nccl_net_ofi_connect_v8 ( int dev, void * handle , void ** sendComm ,
236
+ ncclNetDeviceHandle_v8_t ** sendDevComm )
232
237
{
233
- return nccl_net_ofi_accept (listenComm, recvComm);
238
+ return nccl_net_ofi_connect (dev, handle, sendComm);
239
+ }
240
+
241
+
242
+ static ncclResult_t nccl_net_ofi_connect_v9 (int dev, void * handle, void ** sendComm,
243
+ ncclNetDeviceHandle_v9_t** sendDevComm)
244
+ {
245
+ return nccl_net_ofi_connect (dev, handle, sendComm);
246
+ }
247
+
248
+
249
+ static ncclResult_t nccl_net_ofi_accept_v7 (void * listenComm, void ** recvComm,
250
+ ncclNetDeviceHandle_v7_t** recvDevComm)
251
+ {
252
+ return nccl_net_ofi_accept (listenComm, recvComm);
253
+ }
254
+
255
+
256
+ static ncclResult_t nccl_net_ofi_accept_v8 (void * listenComm, void ** recvComm,
257
+ ncclNetDeviceHandle_v8_t** recvDevComm)
258
+ {
259
+ return nccl_net_ofi_accept (listenComm, recvComm);
260
+ }
261
+
262
+
263
+ static ncclResult_t nccl_net_ofi_accept_v9 (void * listenComm, void ** recvComm,
264
+ ncclNetDeviceHandle_v9_t** recvDevComm)
265
+ {
266
+ return nccl_net_ofi_accept (listenComm, recvComm);
234
267
}
235
268
236
269
@@ -339,8 +372,8 @@ NCCL_OFI_EXPORT_SYMBOL ncclNet_v7_t ncclNetPlugin_v7 = {
339
372
.devices = nccl_net_ofi_devices,
340
373
.getProperties = getProperties_v7,
341
374
.listen = nccl_net_ofi_listen,
342
- .connect = connect_v7 ,
343
- .accept = accept_v7 ,
375
+ .connect = nccl_net_ofi_connect_v7 ,
376
+ .accept = nccl_net_ofi_accept_v7 ,
344
377
.regMr = nccl_net_ofi_regMr_v7,
345
378
.regMrDmaBuf = nccl_net_ofi_regMrDmaBuf,
346
379
.deregMr = nccl_net_ofi_deregMr,
@@ -361,8 +394,8 @@ NCCL_OFI_EXPORT_SYMBOL ncclNet_v8_t ncclNetPlugin_v8 = {
361
394
.devices = nccl_net_ofi_devices,
362
395
.getProperties = getProperties_v8,
363
396
.listen = nccl_net_ofi_listen,
364
- .connect = connect_v7 ,
365
- .accept = accept_v7 ,
397
+ .connect = nccl_net_ofi_connect_v8 ,
398
+ .accept = nccl_net_ofi_accept_v8 ,
366
399
.regMr = nccl_net_ofi_regMr,
367
400
.regMrDmaBuf = nccl_net_ofi_regMrDmaBuf,
368
401
.deregMr = nccl_net_ofi_deregMr,
@@ -383,8 +416,8 @@ NCCL_OFI_EXPORT_SYMBOL ncclNet_v9_t ncclNetPlugin_v9 = {
383
416
.devices = nccl_net_ofi_devices,
384
417
.getProperties = getProperties_v9,
385
418
.listen = nccl_net_ofi_listen,
386
- .connect = connect_v7 ,
387
- .accept = accept_v7 ,
419
+ .connect = nccl_net_ofi_connect_v9 ,
420
+ .accept = nccl_net_ofi_accept_v9 ,
388
421
.regMr = nccl_net_ofi_regMr,
389
422
.regMrDmaBuf = nccl_net_ofi_regMrDmaBuf,
390
423
.deregMr = nccl_net_ofi_deregMr,
0 commit comments