Skip to content

Commit 7a707a5

Browse files
hershys-aws authored and bwbarrett committed
api: Update and increment send/recv functions
1 parent 92fb93f commit 7a707a5

File tree

5 files changed

+48
-61
lines changed

5 files changed

+48
-61
lines changed

include/nccl_ofi.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ struct nccl_net_ofi_send_comm {
558558
*/
559559
int (*deregMr)(nccl_net_ofi_send_comm_t *send_comm, nccl_net_ofi_mr_handle_t *mhandle);
560560

561-
int (*send)(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int tag,
561+
int (*send)(nccl_net_ofi_send_comm_t *send_comm, void *data, size_t size, int tag,
562562
nccl_net_ofi_mr_handle_t *mhandle, nccl_net_ofi_req_t **req);
563563

564564
int (*close)(nccl_net_ofi_send_comm_t *send_comm);
@@ -591,7 +591,7 @@ struct nccl_net_ofi_recv_comm {
591591
*/
592592
int (*deregMr)(nccl_net_ofi_recv_comm_t *recv_comm, nccl_net_ofi_mr_handle_t *mhandle);
593593

594-
int (*recv)(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **data, int *sizes, int *tags,
594+
int (*recv)(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **data, size_t *sizes, int *tags,
595595
nccl_net_ofi_mr_handle_t **mhandles, nccl_net_ofi_req_t **req);
596596

597597
int (*flush)(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **data, int *sizes,

include/tracing_impl/lttng.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,15 @@ LTTNG_UST_TRACEPOINT_EVENT(
5252
Send,
5353
LTTNG_UST_TP_ARGS(
5454
int, dev,
55-
int, size,
55+
size_t, size,
5656
void *, comm,
5757
uint16_t, msg_seq_num,
5858
void *, request,
5959
void *, nccl_req
6060
),
6161
LTTNG_UST_TP_FIELDS(
6262
lttng_ust_field_integer(int, dev, dev)
63-
lttng_ust_field_integer(int, size, size)
63+
lttng_ust_field_integer(size_t, size, size)
6464
lttng_ust_field_integer_hex(uint64_t, comm, (uint64_t)comm)
6565
lttng_ust_field_integer(uint16_t, msg_seq_num, msg_seq_num)
6666
lttng_ust_field_integer_hex(uint64_t, request, (uint64_t)request)
@@ -238,14 +238,14 @@ LTTNG_UST_TRACEPOINT_EVENT(
238238
LTTNG_UST_TP_ARGS(
239239
int, dev,
240240
void *, comm,
241-
int, size,
241+
size_t, size,
242242
void *, request,
243243
void *, nccl_req
244244
),
245245
LTTNG_UST_TP_FIELDS(
246246
lttng_ust_field_integer(int, dev, dev)
247247
lttng_ust_field_integer_hex(uint64_t, comm, (uint64_t)comm)
248-
lttng_ust_field_integer(int, size, size)
248+
lttng_ust_field_integer(size_t, size, size)
249249
lttng_ust_field_integer_hex(uint64_t, request, (uint64_t)request)
250250
lttng_ust_field_integer_hex(uint64_t, nccl_req, (uint64_t)nccl_req)
251251
)

src/nccl_ofi_api.cpp

Lines changed: 38 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -575,13 +575,19 @@ ncclResult_t nccl_net_ofi_isend_v2(void* sendComm, void* data, int size,
575575
}
576576

577577

578-
ncclResult_t nccl_net_ofi_isend_v5(void *sComm, void* data, int size,
579-
int tag, void *mhandle, void** req)
578+
ncclResult_t nccl_net_ofi_isend_v5(void *sendComm, void* data, int size,
579+
int tag, void *mhandle, void** request)
580+
{
581+
return nccl_net_ofi_isend_v9(sendComm, data, static_cast<size_t>(size), tag, mhandle, request);
582+
}
583+
584+
ncclResult_t nccl_net_ofi_isend_v9(void* sendComm, void* data, size_t size,
585+
int tag, void* mhandle, void** request)
580586
{
581587
nccl_net_ofi_send_comm_t *send_comm =
582-
(nccl_net_ofi_send_comm_t *)sComm;
588+
(nccl_net_ofi_send_comm_t *)sendComm;
583589
nccl_net_ofi_mr_handle_t *handle = (nccl_net_ofi_mr_handle_t *)mhandle;
584-
nccl_net_ofi_req_t **base_req = (nccl_net_ofi_req_t **)req;
590+
nccl_net_ofi_req_t **base_req = (nccl_net_ofi_req_t **)request;
585591

586592
/* Validate send_comm */
587593
if (OFI_UNLIKELY(send_comm == NULL)) {
@@ -604,35 +610,46 @@ ncclResult_t nccl_net_ofi_isend_v5(void *sComm, void* data, int size,
604610
return nccl_net_ofi_retval_translate(ret);
605611
}
606612

607-
608-
ncclResult_t nccl_net_ofi_isend_v9(void* sendComm, void* data, size_t size,
609-
int tag, void* mhandle, void** request)
613+
ncclResult_t nccl_net_ofi_irecv_v2(void* recvComm, void* data, int size,
614+
void* mhandle, void** request)
610615
{
611-
ncclResult_t validation_result = msg_length_verify_max_size(&size, 1);
612-
if (validation_result != ncclSuccess) {
613-
return check_return(validation_result);
614-
}
616+
int tag = 0;
615617

616-
return nccl_net_ofi_isend_v5(sendComm, data, (int)size, tag, mhandle, request);
618+
return nccl_net_ofi_irecv_v5(recvComm, 1, &data, &size, &tag, &mhandle, request);
617619
}
618620

619621

620-
ncclResult_t nccl_net_ofi_irecv_v2(void* recvComm, void* data, int size,
621-
void* mhandle, void** request)
622+
ncclResult_t nccl_net_ofi_irecv_v5(void* recvComm, int n, void** data, int* sizes,
623+
int *tags, void** mhandles, void** request)
622624
{
623-
int tag = 0;
625+
size_t castedSizes[NCCL_OFI_MAX_RECVS] = {0};
626+
for (int i = 0; i < n; i++) {
627+
castedSizes[i] = static_cast<size_t>(sizes[i]);
628+
}
624629

625-
return nccl_net_ofi_irecv_v5(recvComm, 1, &data, &size, &tag, &mhandle, request);
630+
return nccl_net_ofi_irecv_v9(recvComm, n, data, castedSizes, tags, mhandles, request);
626631
}
627632

628633

629-
ncclResult_t nccl_net_ofi_irecv_v5(void* rComm, int n, void** buffers, int* sizes,
630-
int *tags, void** mhandles, void** req)
634+
ncclResult_t nccl_net_ofi_irecv_v9(void* recvComm, int n, void** data,
635+
size_t* sizes, int* tags, void** mhandles, void** request)
631636
{
637+
if (OFI_UNLIKELY(recvComm == NULL || data == NULL ||
638+
sizes == NULL || tags == NULL ||
639+
mhandles == NULL || request == NULL)) {
640+
NCCL_OFI_WARN("Invalid argument: NULL pointer detected");
641+
return check_return(ncclInvalidArgument);
642+
}
643+
644+
if (OFI_UNLIKELY(n <= 0 || n > NCCL_OFI_MAX_RECVS)) {
645+
NCCL_OFI_WARN("Invalid number of receives: %d (max: %d)", n, NCCL_OFI_MAX_RECVS);
646+
return check_return(ncclInvalidArgument);
647+
}
648+
632649
nccl_net_ofi_recv_comm_t *recv_comm =
633-
(nccl_net_ofi_recv_comm_t *)rComm;
650+
(nccl_net_ofi_recv_comm_t *)recvComm;
634651
nccl_net_ofi_mr_handle_t **handles = (nccl_net_ofi_mr_handle_t **)mhandles;
635-
nccl_net_ofi_req_t **base_req = (nccl_net_ofi_req_t **)req;
652+
nccl_net_ofi_req_t **base_req = (nccl_net_ofi_req_t **)request;
636653

637654
if (OFI_UNLIKELY(recv_comm == NULL)) {
638655
NCCL_OFI_WARN("Invalid communicator object provided");
@@ -661,40 +678,10 @@ ncclResult_t nccl_net_ofi_irecv_v5(void* rComm, int n, void** buffers, int* size
661678
return check_return(ncclInternalError);
662679
}
663680

664-
int ret = recv_comm->recv(recv_comm, n, buffers, sizes, tags, handles, base_req);
681+
int ret = recv_comm->recv(recv_comm, n, data, sizes, tags, handles, base_req);
665682
return nccl_net_ofi_retval_translate(ret);
666683
}
667684

668-
669-
ncclResult_t nccl_net_ofi_irecv_v9(void* recvComm, int n, void** data,
670-
size_t* sizes, int* tags, void** mhandles, void** request)
671-
{
672-
if (OFI_UNLIKELY(recvComm == NULL || data == NULL ||
673-
sizes == NULL || tags == NULL ||
674-
mhandles == NULL || request == NULL)) {
675-
NCCL_OFI_WARN("Invalid argument: NULL pointer detected");
676-
return check_return(ncclInvalidArgument);
677-
}
678-
679-
if (OFI_UNLIKELY(n <= 0 || n > NCCL_OFI_MAX_RECVS)) {
680-
NCCL_OFI_WARN("Invalid number of receives: %d (max: %d)", n, NCCL_OFI_MAX_RECVS);
681-
return check_return(ncclInvalidArgument);
682-
}
683-
684-
ncclResult_t validation_result = msg_length_verify_max_size(sizes, n);
685-
if (validation_result != ncclSuccess) {
686-
return check_return(validation_result);
687-
}
688-
689-
int sizesInt[NCCL_OFI_MAX_RECVS] = {0};
690-
for (int i = 0; i < n; i++) {
691-
sizesInt[i] = (int)sizes[i];
692-
}
693-
694-
return nccl_net_ofi_irecv_v5(recvComm, n, data, sizesInt, tags, mhandles, request);
695-
}
696-
697-
698685
ncclResult_t nccl_net_ofi_test_v2(void* req, int* done, int* size)
699686
{
700687
/* Validate request */

src/nccl_ofi_rdma.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3524,7 +3524,7 @@ static int process_cq_if_pending(nccl_net_ofi_rdma_ep_t *ep)
35243524
}
35253525

35263526
static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
3527-
int *sizes, int *tags, nccl_net_ofi_mr_handle_t **mhandles,
3527+
size_t *sizes, int *tags, nccl_net_ofi_mr_handle_t **mhandles,
35283528
nccl_net_ofi_req_t **base_req)
35293529
{
35303530
int ret = 0;
@@ -5854,7 +5854,7 @@ static inline int check_post_rx_buff_req(nccl_net_ofi_rdma_req_t *rx_buff_req)
58545854
* @brief Send a message. This "interface function" is called, indirectly, from
58555855
* the application
58565856
*/
5857-
static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int tag,
5857+
static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, size_t size, int tag,
58585858
nccl_net_ofi_mr_handle_t *mhandle, nccl_net_ofi_req_t **base_req)
58595859
{
58605860
int ret = 0;

src/nccl_ofi_sendrecv.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,7 +1068,7 @@ static inline nccl_net_ofi_sendrecv_req_t *sendrecv_allocate_req(nccl_ofi_freeli
10681068
}
10691069

10701070
static int sendrecv_recv_comm_recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
1071-
int *sizes, int *tags, nccl_net_ofi_mr_handle_t **mhandles,
1071+
size_t *sizes, int *tags, nccl_net_ofi_mr_handle_t **mhandles,
10721072
nccl_net_ofi_req_t **base_req)
10731073
{
10741074
int ret = 0;
@@ -1846,7 +1846,7 @@ static int sendrecv_send_comm_dereg_mr(nccl_net_ofi_send_comm_t *send_comm,
18461846
domain->base.mr_cache);
18471847
}
18481848

1849-
static int sendrecv_send_comm_send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int tag,
1849+
static int sendrecv_send_comm_send(nccl_net_ofi_send_comm_t *send_comm, void *data, size_t size, int tag,
18501850
nccl_net_ofi_mr_handle_t *mhandle, nccl_net_ofi_req_t **base_req)
18511851
{
18521852
int ret = 0;

0 commit comments

Comments
 (0)