11/*
2- * Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
2+ * Copyright (c) 2024 NVIDIA Corporation. All rights reserved.
33 * Copyright (c) 2004-2023 The University of Tennessee and The University
44 * of Tennessee Research Foundation. All rights
55 * reserved.
@@ -36,7 +36,7 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count,
3636 mca_coll_base_module_t * module )
3737{
3838 mca_coll_accelerator_module_t * s = (mca_coll_accelerator_module_t * ) module ;
39- int rank = ( comm == NULL ) ? -1 : ompi_comm_rank (comm );
39+ int rank = ompi_comm_rank (comm );
4040 ptrdiff_t gap ;
4141 char * rbuf1 = NULL , * sbuf1 = NULL , * rbuf2 = NULL ;
4242 size_t bufsize ;
@@ -71,15 +71,9 @@ mca_coll_accelerator_reduce(const void *sbuf, void *rbuf, size_t count,
7171 rbuf2 = rbuf ; /* save away original buffer */
7272 rbuf = rbuf1 - gap ;
7373 }
74-
75- if ((comm == NULL ) && (root == -1 )) {
76- ompi_op_reduce (op , (void * )sbuf , rbuf , count , dtype );
77- rc = OMPI_SUCCESS ;
78- } else {
79- rc = s -> c_coll .coll_reduce ((void * ) sbuf , rbuf , count ,
80- dtype , op , root , comm ,
81- s -> c_coll .coll_reduce_module );
82- }
74+ rc = s -> c_coll .coll_reduce ((void * ) sbuf , rbuf , count ,
75+ dtype , op , root , comm ,
76+ s -> c_coll .coll_reduce_module );
8377
8478 if (NULL != sbuf1 ) {
8579 free (sbuf1 );
@@ -98,6 +92,53 @@ mca_coll_accelerator_reduce_local(const void *sbuf, void *rbuf, size_t count,
9892 struct ompi_op_t * op ,
9993 mca_coll_base_module_t * module )
10094{
101- return mca_coll_accelerator_reduce (sbuf , rbuf , count , dtype , op , -1 , NULL ,
102- module );
95+ ptrdiff_t gap ;
96+ char * rbuf1 = NULL , * sbuf1 = NULL , * rbuf2 = NULL ;
97+ size_t bufsize ;
98+ int rc ;
99+
100+ bufsize = opal_datatype_span (& dtype -> super , count , & gap );
101+
102+ rc = mca_coll_accelerator_check_buf ((void * )sbuf );
103+ if (rc < 0 ) {
104+ return rc ;
105+ }
106+
107+ if ((MPI_IN_PLACE != sbuf ) && (rc > 0 )) {
108+ sbuf1 = (char * )malloc (bufsize );
109+ if (NULL == sbuf1 ) {
110+ return OMPI_ERR_OUT_OF_RESOURCE ;
111+ }
112+ mca_coll_accelerator_memcpy (sbuf1 , sbuf , bufsize );
113+ sbuf = sbuf1 - gap ;
114+ }
115+
116+ rc = mca_coll_accelerator_check_buf (rbuf );
117+ if (rc < 0 ) {
118+ return rc ;
119+ }
120+
121+ if (rc > 0 ) {
122+ rbuf1 = (char * )malloc (bufsize );
123+ if (NULL == rbuf1 ) {
124+ if (NULL != sbuf1 ) free (sbuf1 );
125+ return OMPI_ERR_OUT_OF_RESOURCE ;
126+ }
127+ mca_coll_accelerator_memcpy (rbuf1 , rbuf , bufsize );
128+ rbuf2 = rbuf ; /* save away original buffer */
129+ rbuf = rbuf1 - gap ;
130+ }
131+
132+ ompi_op_reduce (op , (void * )sbuf , rbuf , count , dtype );
133+ rc = OMPI_SUCCESS ;
134+
135+ if (NULL != sbuf1 ) {
136+ free (sbuf1 );
137+ }
138+ if (NULL != rbuf1 ) {
139+ rbuf = rbuf2 ;
140+ mca_coll_accelerator_memcpy (rbuf , rbuf1 , bufsize );
141+ free (rbuf1 );
142+ }
143+ return rc ;
103144}
0 commit comments