@@ -575,13 +575,19 @@ ncclResult_t nccl_net_ofi_isend_v2(void* sendComm, void* data, int size,
575
575
}
576
576
577
577
578
- ncclResult_t nccl_net_ofi_isend_v5 (void *sComm , void * data, int size,
579
- int tag, void *mhandle, void ** req)
578
+ ncclResult_t nccl_net_ofi_isend_v5 (void *sendComm, void * data, int size,
579
+ int tag, void *mhandle, void ** request)
580
+ {
581
+ return nccl_net_ofi_isend_v9 (sendComm, data, static_cast <size_t >(size), tag, mhandle, request);
582
+ }
583
+
584
+ ncclResult_t nccl_net_ofi_isend_v9 (void * sendComm, void * data, size_t size,
585
+ int tag, void * mhandle, void ** request)
580
586
{
581
587
nccl_net_ofi_send_comm_t *send_comm =
582
- (nccl_net_ofi_send_comm_t *)sComm ;
588
+ (nccl_net_ofi_send_comm_t *)sendComm ;
583
589
nccl_net_ofi_mr_handle_t *handle = (nccl_net_ofi_mr_handle_t *)mhandle;
584
- nccl_net_ofi_req_t **base_req = (nccl_net_ofi_req_t **)req ;
590
+ nccl_net_ofi_req_t **base_req = (nccl_net_ofi_req_t **)request ;
585
591
586
592
/* Validate send_comm */
587
593
if (OFI_UNLIKELY (send_comm == NULL )) {
@@ -604,35 +610,46 @@ ncclResult_t nccl_net_ofi_isend_v5(void *sComm, void* data, int size,
604
610
return nccl_net_ofi_retval_translate (ret);
605
611
}
606
612
607
-
608
- ncclResult_t nccl_net_ofi_isend_v9 (void * sendComm, void * data, size_t size,
609
- int tag, void * mhandle, void ** request)
613
+ ncclResult_t nccl_net_ofi_irecv_v2 (void * recvComm, void * data, int size,
614
+ void * mhandle, void ** request)
610
615
{
611
- ncclResult_t validation_result = msg_length_verify_max_size (&size, 1 );
612
- if (validation_result != ncclSuccess) {
613
- return check_return (validation_result);
614
- }
616
+ int tag = 0 ;
615
617
616
- return nccl_net_ofi_isend_v5 (sendComm, data, ( int ) size, tag, mhandle, request);
618
+ return nccl_net_ofi_irecv_v5 (recvComm, 1 , & data, & size, & tag, & mhandle, request);
617
619
}
618
620
619
621
620
- ncclResult_t nccl_net_ofi_irecv_v2 (void * recvComm, void * data, int size ,
621
- void * mhandle , void ** request)
622
+ ncclResult_t nccl_net_ofi_irecv_v5 (void * recvComm, int n, void ** data, int * sizes ,
623
+ int *tags, void ** mhandles , void ** request)
622
624
{
623
- int tag = 0 ;
625
+ size_t castedSizes[NCCL_OFI_MAX_RECVS] = {0 };
626
+ for (int i = 0 ; i < n; i++) {
627
+ castedSizes[i] = static_cast <size_t >(sizes[i]);
628
+ }
624
629
625
- return nccl_net_ofi_irecv_v5 (recvComm, 1 , & data, &size, &tag, &mhandle , request);
630
+ return nccl_net_ofi_irecv_v9 (recvComm, n, data, castedSizes, tags, mhandles , request);
626
631
}
627
632
628
633
629
- ncclResult_t nccl_net_ofi_irecv_v5 (void * rComm , int n, void ** buffers, int * sizes ,
630
- int * tags, void ** mhandles, void ** req )
634
+ ncclResult_t nccl_net_ofi_irecv_v9 (void * recvComm , int n, void ** data ,
635
+ size_t * sizes, int * tags, void ** mhandles, void ** request )
631
636
{
637
+ if (OFI_UNLIKELY (recvComm == NULL || data == NULL ||
638
+ sizes == NULL || tags == NULL ||
639
+ mhandles == NULL || request == NULL )) {
640
+ NCCL_OFI_WARN (" Invalid argument: NULL pointer detected" );
641
+ return check_return (ncclInvalidArgument);
642
+ }
643
+
644
+ if (OFI_UNLIKELY (n <= 0 || n > NCCL_OFI_MAX_RECVS)) {
645
+ NCCL_OFI_WARN (" Invalid number of receives: %d (max: %d)" , n, NCCL_OFI_MAX_RECVS);
646
+ return check_return (ncclInvalidArgument);
647
+ }
648
+
632
649
nccl_net_ofi_recv_comm_t *recv_comm =
633
- (nccl_net_ofi_recv_comm_t *)rComm ;
650
+ (nccl_net_ofi_recv_comm_t *)recvComm ;
634
651
nccl_net_ofi_mr_handle_t **handles = (nccl_net_ofi_mr_handle_t **)mhandles;
635
- nccl_net_ofi_req_t **base_req = (nccl_net_ofi_req_t **)req ;
652
+ nccl_net_ofi_req_t **base_req = (nccl_net_ofi_req_t **)request ;
636
653
637
654
if (OFI_UNLIKELY (recv_comm == NULL )) {
638
655
NCCL_OFI_WARN (" Invalid communicator object provided" );
@@ -661,40 +678,10 @@ ncclResult_t nccl_net_ofi_irecv_v5(void* rComm, int n, void** buffers, int* size
661
678
return check_return (ncclInternalError);
662
679
}
663
680
664
- int ret = recv_comm->recv (recv_comm, n, buffers , sizes, tags, handles, base_req);
681
+ int ret = recv_comm->recv (recv_comm, n, data , sizes, tags, handles, base_req);
665
682
return nccl_net_ofi_retval_translate (ret);
666
683
}
667
684
668
-
669
- ncclResult_t nccl_net_ofi_irecv_v9 (void * recvComm, int n, void ** data,
670
- size_t * sizes, int * tags, void ** mhandles, void ** request)
671
- {
672
- if (OFI_UNLIKELY (recvComm == NULL || data == NULL ||
673
- sizes == NULL || tags == NULL ||
674
- mhandles == NULL || request == NULL )) {
675
- NCCL_OFI_WARN (" Invalid argument: NULL pointer detected" );
676
- return check_return (ncclInvalidArgument);
677
- }
678
-
679
- if (OFI_UNLIKELY (n <= 0 || n > NCCL_OFI_MAX_RECVS)) {
680
- NCCL_OFI_WARN (" Invalid number of receives: %d (max: %d)" , n, NCCL_OFI_MAX_RECVS);
681
- return check_return (ncclInvalidArgument);
682
- }
683
-
684
- ncclResult_t validation_result = msg_length_verify_max_size (sizes, n);
685
- if (validation_result != ncclSuccess) {
686
- return check_return (validation_result);
687
- }
688
-
689
- int sizesInt[NCCL_OFI_MAX_RECVS] = {0 };
690
- for (int i = 0 ; i < n; i++) {
691
- sizesInt[i] = (int )sizes[i];
692
- }
693
-
694
- return nccl_net_ofi_irecv_v5 (recvComm, n, data, sizesInt, tags, mhandles, request);
695
- }
696
-
697
-
698
685
ncclResult_t nccl_net_ofi_test_v2 (void * req, int * done, int * size)
699
686
{
700
687
/* Validate request */
0 commit comments