4242#include "mpi.h"
4343#include "ompi/mca/mca.h"
4444#include "opal/util/output.h"
45+ #include "opal/mca/smsc/smsc.h"
4546#include "ompi/mca/coll/base/coll_base_functions.h"
4647#include "coll_han_trigger.h"
4748#include "ompi/mca/coll/han/coll_han_dynamic.h"
@@ -197,6 +198,7 @@ typedef struct mca_coll_han_op_module_name_t {
197198 mca_coll_han_op_up_low_module_name_t gatherv ;
198199 mca_coll_han_op_up_low_module_name_t scatter ;
199200 mca_coll_han_op_up_low_module_name_t scatterv ;
201+ mca_coll_han_op_up_low_module_name_t alltoall ;
200202} mca_coll_han_op_module_name_t ;
201203
202204/**
@@ -252,6 +254,13 @@ typedef struct mca_coll_han_component_t {
252254 uint32_t han_scatterv_up_module ;
253255 /* low level module for scatterv */
254256 uint32_t han_scatterv_low_module ;
257+
258+ /* low level module for alltoall */
259+ uint32_t han_alltoall_low_module ;
260+ /* alltoall: parallel stages */
261+ int32_t han_alltoall_pstages ;
262+
263+
255264 /* name of the modules */
256265 mca_coll_han_op_module_name_t han_op_module_name ;
257266 /* whether we need reproducible results
@@ -287,6 +296,7 @@ typedef struct mca_coll_han_single_collective_fallback_s
287296{
288297 union
289298 {
299+ mca_coll_base_module_alltoall_fn_t alltoall ;
290300 mca_coll_base_module_allgather_fn_t allgather ;
291301 mca_coll_base_module_allgatherv_fn_t allgatherv ;
292302 mca_coll_base_module_allreduce_fn_t allreduce ;
@@ -308,6 +318,7 @@ typedef struct mca_coll_han_single_collective_fallback_s
308318 */
309319typedef struct mca_coll_han_collectives_fallback_s
310320{
321+ mca_coll_han_single_collective_fallback_t alltoall ;
311322 mca_coll_han_single_collective_fallback_t allgather ;
312323 mca_coll_han_single_collective_fallback_t allgatherv ;
313324 mca_coll_han_single_collective_fallback_t allreduce ;
@@ -370,6 +381,9 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t);
370381 * Some defines to stick to the naming used in the other components in terms of
371382 * fallback routines
372383 */
384+ #define previous_alltoall fallback.alltoall.alltoall
385+ #define previous_alltoall_module fallback.alltoall.module
386+
373387#define previous_allgather fallback.allgather.allgather
374388#define previous_allgather_module fallback.allgather.module
375389
@@ -425,6 +439,7 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t);
425439 HAN_UNINSTALL_COLL_API(COMM, HANM, allreduce); \
426440 HAN_UNINSTALL_COLL_API(COMM, HANM, allgather); \
427441 HAN_UNINSTALL_COLL_API(COMM, HANM, allgatherv); \
442+ HAN_UNINSTALL_COLL_API(COMM, HANM, alltoall); \
428443 han_module->enabled = false; /* entire module set to pass-through from now on */ \
429444 } while (0 )
430445
@@ -485,6 +500,9 @@ mca_coll_han_get_all_coll_modules(struct ompi_communicator_t *comm,
485500 mca_coll_han_module_t * han_module );
486501
487502int
503+ mca_coll_han_alltoall_intra_dynamic (ALLTOALL_BASE_ARGS ,
504+ mca_coll_base_module_t * module );
505+ int
488506mca_coll_han_allgather_intra_dynamic (ALLGATHER_BASE_ARGS ,
489507 mca_coll_base_module_t * module );
490508int
@@ -532,4 +550,20 @@ coll_han_utils_gcd(const uint64_t *numerators, const size_t size);
532550int
533551coll_han_utils_create_contiguous_datatype (size_t count , const ompi_datatype_t * oldType ,
534552 ompi_datatype_t * * newType );
553+
554+ static inline struct mca_smsc_endpoint_t * mca_coll_han_get_smsc_endpoint (struct ompi_proc_t * proc ) {
555+ extern opal_mutex_t mca_coll_han_lock ;
556+ if (NULL == proc -> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_SMSC ]) {
557+ if (NULL == proc -> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_SMSC ]) {
558+ OPAL_THREAD_LOCK (& mca_coll_han_lock );
559+ if (NULL == proc -> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_SMSC ]) {
560+ proc -> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_SMSC ] = mca_smsc -> get_endpoint (& proc -> super );
561+ }
562+ OPAL_THREAD_UNLOCK (& mca_coll_han_lock );
563+ }
564+ }
565+
566+ return (struct mca_smsc_endpoint_t * ) proc -> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_SMSC ];
567+ }
568+
535569#endif /* MCA_COLL_HAN_EXPORT_H */
0 commit comments