@@ -28,7 +28,8 @@ void mca_coll_acoll_sync(coll_acoll_data_t *data, int offset, int *group, int gp
2828int mca_coll_acoll_allreduce_small_msgs_h (const void * sbuf , void * rbuf , size_t count ,
2929 struct ompi_datatype_t * dtype , struct ompi_op_t * op ,
3030 struct ompi_communicator_t * comm ,
31- mca_coll_base_module_t * module , int intra );
31+ mca_coll_base_module_t * module ,
32+ coll_acoll_subcomms_t * subc , int intra );
3233
3334
3435static inline int coll_allreduce_decision_fixed (int comm_size , size_t msg_size )
@@ -52,16 +53,13 @@ static inline int coll_allreduce_decision_fixed(int comm_size, size_t msg_size)
5253static inline int mca_coll_acoll_reduce_xpmem_h (const void * sbuf , void * rbuf , size_t count ,
5354 struct ompi_datatype_t * dtype , struct ompi_op_t * op ,
5455 struct ompi_communicator_t * comm ,
55- mca_coll_base_module_t * module )
56+ mca_coll_base_module_t * module ,
57+ coll_acoll_subcomms_t * subc )
5658{
5759 int size ;
5860 size_t total_dsize , dsize ;
59- mca_coll_acoll_module_t * acoll_module = (mca_coll_acoll_module_t * ) module ;
6061
61- coll_acoll_subcomms_t * subc ;
62- int cid = ompi_comm_get_local_cid (comm );
63- subc = & acoll_module -> subc [cid ];
64- coll_acoll_init (module , comm , subc -> data );
62+ coll_acoll_init (module , comm , subc -> data , subc );
6563 coll_acoll_data_t * data = subc -> data ;
6664 if (NULL == data ) {
6765 return -1 ;
@@ -188,16 +186,13 @@ static inline int mca_coll_acoll_allreduce_xpmem_f(const void *sbuf, void *rbuf,
188186 struct ompi_datatype_t * dtype ,
189187 struct ompi_op_t * op ,
190188 struct ompi_communicator_t * comm ,
191- mca_coll_base_module_t * module )
189+ mca_coll_base_module_t * module ,
190+ coll_acoll_subcomms_t * subc )
192191{
193192 int size ;
194193 size_t total_dsize , dsize ;
195- mca_coll_acoll_module_t * acoll_module = (mca_coll_acoll_module_t * ) module ;
196194
197- coll_acoll_subcomms_t * subc ;
198- int cid = ompi_comm_get_local_cid (comm );
199- subc = & acoll_module -> subc [cid ];
200- coll_acoll_init (module , comm , subc -> data );
195+ coll_acoll_init (module , comm , subc -> data , subc );
201196 coll_acoll_data_t * data = subc -> data ;
202197 if (NULL == data ) {
203198 return -1 ;
@@ -361,15 +356,13 @@ void mca_coll_acoll_sync(coll_acoll_data_t *data, int offset, int *group, int gp
361356int mca_coll_acoll_allreduce_small_msgs_h (const void * sbuf , void * rbuf , size_t count ,
362357 struct ompi_datatype_t * dtype , struct ompi_op_t * op ,
363358 struct ompi_communicator_t * comm ,
364- mca_coll_base_module_t * module , int intra )
359+ mca_coll_base_module_t * module ,
360+ coll_acoll_subcomms_t * subc , int intra )
365361{
366362 size_t dsize ;
367363 int err = MPI_SUCCESS ;
368- mca_coll_acoll_module_t * acoll_module = (mca_coll_acoll_module_t * ) module ;
369- coll_acoll_subcomms_t * subc ;
370- int cid = ompi_comm_get_local_cid (comm );
371- subc = & acoll_module -> subc [cid ];
372- coll_acoll_init (module , comm , subc -> data );
364+
365+ coll_acoll_init (module , comm , subc -> data , subc );
373366 coll_acoll_data_t * data = subc -> data ;
374367 if (NULL == data ) {
375368 return -1 ;
@@ -385,7 +378,6 @@ int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, size_t c
385378
386379 int l1_local_rank = data -> l1_local_rank ;
387380 int l2_local_rank = data -> l2_local_rank ;
388- int comm_id = ompi_comm_get_local_cid (comm );
389381
390382 int offset1 = data -> offset [0 ];
391383 int offset2 = data -> offset [1 ];
@@ -441,8 +433,8 @@ int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, size_t c
441433 }
442434 }
443435
444- if (intra && (ompi_comm_size (acoll_module -> subc [ comm_id ]. numa_comm ) > 1 )) {
445- err = mca_coll_acoll_bcast (rbuf , count , dtype , 0 , acoll_module -> subc [ comm_id ]. numa_comm , module );
436+ if (intra && (ompi_comm_size (subc -> numa_comm ) > 1 )) {
437+ err = mca_coll_acoll_bcast (rbuf , count , dtype , 0 , subc -> numa_comm , module );
446438 }
447439 return err ;
448440}
@@ -466,25 +458,23 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
466458 return MPI_SUCCESS ;
467459 }
468460
469- coll_acoll_subcomms_t * subc ;
470- int cid = ompi_comm_get_local_cid (comm );
471- subc = & acoll_module -> subc [cid ];
472-
473461 /* Falling back to recursivedoubling for non-commutative operators to be safe */
474462 if (!ompi_op_is_commute (op )) {
475463 return ompi_coll_base_allreduce_intra_recursivedoubling (sbuf , rbuf , count , dtype , op , comm ,
476464 module );
477465 }
478466
479- /* Fallback to knomial if cid is beyond supported limit */
480- if (cid >= MCA_COLL_ACOLL_MAX_CID ) {
467+ /* Obtain the subcomms structure */
468+ coll_acoll_subcomms_t * subc = NULL ;
469+ err = check_and_create_subc (comm , acoll_module , & subc );
470+
471+ /* Fallback to redscat_allgather if subc is not obtained */
472+ if (NULL == subc ) {
481473 return ompi_coll_base_allreduce_intra_redscat_allgather (sbuf , rbuf , count , dtype , op , comm ,
482474 module );
483475 }
484-
485- subc = & acoll_module -> subc [cid ];
486476 if (!subc -> initialized ) {
487- err = mca_coll_acoll_comm_split_init (comm , acoll_module , 0 );
477+ err = mca_coll_acoll_comm_split_init (comm , acoll_module , subc , 0 );
488478 if (MPI_SUCCESS != err )
489479 return err ;
490480 }
@@ -499,7 +489,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
499489 comm , module );
500490 } else if (total_dsize < 512 ) {
501491 return mca_coll_acoll_allreduce_small_msgs_h (sbuf , rbuf , count , dtype , op , comm , module ,
502- 1 );
492+ subc , 1 );
503493 } else if (total_dsize <= 2048 ) {
504494 return ompi_coll_base_allreduce_intra_recursivedoubling (sbuf , rbuf , count , dtype , op ,
505495 comm , module );
@@ -517,7 +507,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
517507 } else if (total_dsize < 4194304 ) {
518508#ifdef HAVE_XPMEM_H
519509 if (((subc -> xpmem_use_sr_buf != 0 ) || (subc -> xpmem_buf_size > 2 * total_dsize )) && (subc -> without_xpmem != 1 )) {
520- return mca_coll_acoll_allreduce_xpmem_f (sbuf , rbuf , count , dtype , op , comm , module );
510+ return mca_coll_acoll_allreduce_xpmem_f (sbuf , rbuf , count , dtype , op , comm , module , subc );
521511 } else {
522512 return ompi_coll_base_allreduce_intra_redscat_allgather (sbuf , rbuf , count , dtype ,
523513 op , comm , module );
@@ -529,7 +519,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
529519 } else if (total_dsize <= 16777216 ) {
530520#ifdef HAVE_XPMEM_H
531521 if (((subc -> xpmem_use_sr_buf != 0 ) || (subc -> xpmem_buf_size > 2 * total_dsize )) && (subc -> without_xpmem != 1 )) {
532- mca_coll_acoll_reduce_xpmem_h (sbuf , rbuf , count , dtype , op , comm , module );
522+ mca_coll_acoll_reduce_xpmem_h (sbuf , rbuf , count , dtype , op , comm , module , subc );
533523 return mca_coll_acoll_bcast (rbuf , count , dtype , 0 , comm , module );
534524 } else {
535525 return ompi_coll_base_allreduce_intra_redscat_allgather (sbuf , rbuf , count , dtype ,
@@ -542,7 +532,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
542532 } else {
543533#ifdef HAVE_XPMEM_H
544534 if (((subc -> xpmem_use_sr_buf != 0 ) || (subc -> xpmem_buf_size > 2 * total_dsize )) && (subc -> without_xpmem != 1 )) {
545- return mca_coll_acoll_allreduce_xpmem_f (sbuf , rbuf , count , dtype , op , comm , module );
535+ return mca_coll_acoll_allreduce_xpmem_f (sbuf , rbuf , count , dtype , op , comm , module , subc );
546536 } else {
547537 return ompi_coll_base_allreduce_intra_redscat_allgather (sbuf , rbuf , count , dtype ,
548538 op , comm , module );
0 commit comments