|
2 | 2 | * Copyright (c) 2014-2017 The University of Tennessee and The University |
3 | 3 | * of Tennessee Research Foundation. All rights |
4 | 4 | * reserved. |
5 | | - * Copyright (c) 2014 NVIDIA Corporation. All rights reserved. |
| 5 | + * Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved. |
6 | 6 | * Copyright (c) 2019 Research Organization for Information Science |
7 | 7 | * and Technology (RIST). All rights reserved. |
8 | 8 | * Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. |
|
32 | 32 | #include "ompi/mca/coll/base/base.h" |
33 | 33 | #include "coll_accelerator.h" |
34 | 34 |
|
| 35 | +static int |
| 36 | +mca_coll_accelerator_module_enable(mca_coll_base_module_t *module, |
| 37 | + struct ompi_communicator_t *comm); |
| 38 | +static int |
| 39 | +mca_coll_accelerator_module_disable(mca_coll_base_module_t *module, |
| 40 | + struct ompi_communicator_t *comm); |
35 | 41 |
|
36 | 42 | static void mca_coll_accelerator_module_construct(mca_coll_accelerator_module_t *module) |
37 | 43 | { |
38 | 44 | memset(&(module->c_coll), 0, sizeof(module->c_coll)); |
39 | 45 | } |
40 | 46 |
|
/*
 * Class registration for the accelerator collective module.
 *
 * Only a constructor is supplied; the destructor slot is NULL.
 * NOTE(review): this presumes nothing saved in c_coll holds a reference that
 * must be released at destruction time -- confirm that MCA_COLL_SAVE_API does
 * not retain the module it saves.
 */
OBJ_CLASS_INSTANCE(mca_coll_accelerator_module_t,
                   mca_coll_base_module_t,
                   mca_coll_accelerator_module_construct,
                   NULL);
59 | 50 |
|
60 | 51 |
|
61 | 52 | /* |
@@ -99,66 +90,82 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm, |
99 | 90 |
|
100 | 91 | /* Choose whether to use [intra|inter] */ |
101 | 92 | accelerator_module->super.coll_module_enable = mca_coll_accelerator_module_enable; |
| 93 | + accelerator_module->super.coll_module_disable = mca_coll_accelerator_module_disable; |
102 | 94 |
|
103 | | - accelerator_module->super.coll_allgather = NULL; |
104 | | - accelerator_module->super.coll_allgatherv = NULL; |
105 | 95 | accelerator_module->super.coll_allreduce = mca_coll_accelerator_allreduce; |
106 | | - accelerator_module->super.coll_alltoall = NULL; |
107 | | - accelerator_module->super.coll_alltoallv = NULL; |
108 | | - accelerator_module->super.coll_alltoallw = NULL; |
109 | | - accelerator_module->super.coll_barrier = NULL; |
110 | | - accelerator_module->super.coll_bcast = NULL; |
111 | | - accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan; |
112 | | - accelerator_module->super.coll_gather = NULL; |
113 | | - accelerator_module->super.coll_gatherv = NULL; |
114 | 96 | accelerator_module->super.coll_reduce = mca_coll_accelerator_reduce; |
115 | | - accelerator_module->super.coll_reduce_scatter = NULL; |
116 | 97 | accelerator_module->super.coll_reduce_scatter_block = mca_coll_accelerator_reduce_scatter_block; |
117 | | - accelerator_module->super.coll_scan = mca_coll_accelerator_scan; |
118 | | - accelerator_module->super.coll_scatter = NULL; |
119 | | - accelerator_module->super.coll_scatterv = NULL; |
| 98 | + if (!OMPI_COMM_IS_INTER(comm)) { |
| 99 | + accelerator_module->super.coll_scan = mca_coll_accelerator_scan; |
| 100 | + accelerator_module->super.coll_exscan = mca_coll_accelerator_exscan; |
| 101 | + } |
120 | 102 |
|
121 | 103 | return &(accelerator_module->super); |
122 | 104 | } |
123 | 105 |
|
124 | 106 |
|
/*
 * Install the accelerator implementation of collective <__api> on <__comm>.
 *
 *   __comm    communicator whose c_coll table is being modified
 *   __module  pointer to the mca_coll_accelerator_module_t being enabled
 *   __api     bare collective name (e.g. allreduce); pasted into identifiers
 *
 * When the communicator already provides coll_<__api>, the current
 * function/module pair is handed to MCA_COLL_SAVE_API with the matching
 * slots of (__module)->c_coll as destination, and then
 * mca_coll_accelerator_<__api> is installed with (__module)->super as its
 * module.  When no such function exists, nothing is installed and the
 * "comm-select:missing collective" help message is shown instead.
 *
 * The arguments are expanded more than once, so pass expressions without
 * side effects.
 */
#define ACCELERATOR_INSTALL_COLL_API(__comm, __module, __api) \
    do { \
        if (NULL != (__comm)->c_coll->coll_##__api) { \
            MCA_COLL_SAVE_API(__comm, __api, (__module)->c_coll.coll_##__api, \
                              (__module)->c_coll.coll_##__api##_module, "accelerator"); \
            /* Parenthesise the argument so any pointer expression binds correctly. */ \
            MCA_COLL_INSTALL_API(__comm, __api, mca_coll_accelerator_##__api, \
                                 &(__module)->super, "accelerator"); \
        } else { \
            /* Report this component by its own name, matching the strings above. */ \
            opal_show_help("help-mca-coll-base.txt", "comm-select:missing collective", true, \
                           "accelerator", #__api, ompi_process_info.nodename, \
                           mca_coll_accelerator_component.priority); \
        } \
    } while (0)
| 122 | + |
/*
 * Counterpart of ACCELERATOR_INSTALL_COLL_API for collective <__api>.
 *
 * Acts only when (__module)->super is the module currently registered on
 * <__comm> for this collective.  In that case the function/module pair held
 * in (__module)->c_coll is installed back on the communicator and both saved
 * slots are reset to NULL.  Otherwise the macro does nothing.
 *
 * The arguments are expanded more than once, so pass expressions without
 * side effects.
 */
#define ACCELERATOR_UNINSTALL_COLL_API(__comm, __module, __api) \
    do { \
        if ((__comm)->c_coll->coll_##__api##_module == &(__module)->super) { \
            MCA_COLL_INSTALL_API(__comm, __api, (__module)->c_coll.coll_##__api, \
                                 (__module)->c_coll.coll_##__api##_module, "accelerator"); \
            (__module)->c_coll.coll_##__api = NULL; \
            (__module)->c_coll.coll_##__api##_module = NULL; \
        } \
    } while (0)
| 132 | + |
125 | 133 | /* |
126 | | - * Init module on the communicator |
| 134 | + * Init/Fini module on the communicator |
127 | 135 | */ |
128 | | -int mca_coll_accelerator_module_enable(mca_coll_base_module_t *module, |
129 | | - struct ompi_communicator_t *comm) |
| 136 | +static int |
| 137 | +mca_coll_accelerator_module_enable(mca_coll_base_module_t *module, |
| 138 | + struct ompi_communicator_t *comm) |
130 | 139 | { |
131 | | - bool good = true; |
132 | | - char *msg = NULL; |
133 | 140 | mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module; |
134 | 141 |
|
135 | | -#define CHECK_AND_RETAIN(src, dst, name) \ |
136 | | - if (NULL == (src)->c_coll->coll_ ## name ## _module) { \ |
137 | | - good = false; \ |
138 | | - msg = #name; \ |
139 | | - } else if (good) { \ |
140 | | - (dst)->c_coll.coll_ ## name ## _module = (src)->c_coll->coll_ ## name ## _module; \ |
141 | | - (dst)->c_coll.coll_ ## name = (src)->c_coll->coll_ ## name; \ |
142 | | - OBJ_RETAIN((src)->c_coll->coll_ ## name ## _module); \ |
143 | | - } |
144 | | - |
145 | | - CHECK_AND_RETAIN(comm, s, allreduce); |
146 | | - CHECK_AND_RETAIN(comm, s, reduce); |
147 | | - CHECK_AND_RETAIN(comm, s, reduce_scatter_block); |
148 | | - CHECK_AND_RETAIN(comm, s, scatter); |
| 142 | + ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce); |
| 143 | + ACCELERATOR_INSTALL_COLL_API(comm, s, reduce); |
| 144 | + ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block); |
149 | 145 | if (!OMPI_COMM_IS_INTER(comm)) { |
150 | 146 | /* MPI does not define scan/exscan on intercommunicators */ |
151 | | - CHECK_AND_RETAIN(comm, s, exscan); |
152 | | - CHECK_AND_RETAIN(comm, s, scan); |
| 147 | + ACCELERATOR_INSTALL_COLL_API(comm, s, exscan); |
| 148 | + ACCELERATOR_INSTALL_COLL_API(comm, s, scan); |
153 | 149 | } |
154 | 150 |
|
155 | | - /* All done */ |
156 | | - if (good) { |
157 | | - return OMPI_SUCCESS; |
158 | | - } |
159 | | - opal_show_help("help-mpi-coll-accelerator.txt", "missing collective", true, |
160 | | - ompi_process_info.nodename, |
161 | | - mca_coll_accelerator_component.priority, msg); |
162 | | - return OMPI_ERR_NOT_FOUND; |
| 151 | + return OMPI_SUCCESS; |
163 | 152 | } |
164 | 153 |
|
| 154 | +static int |
| 155 | +mca_coll_accelerator_module_disable(mca_coll_base_module_t *module, |
| 156 | + struct ompi_communicator_t *comm) |
| 157 | +{ |
| 158 | + mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module; |
| 159 | + |
| 160 | + ACCELERATOR_UNINSTALL_COLL_API(comm, s, allreduce); |
| 161 | + ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce); |
| 162 | + ACCELERATOR_UNINSTALL_COLL_API(comm, s, reduce_scatter_block); |
| 163 | + if (!OMPI_COMM_IS_INTER(comm)) |
| 164 | + { |
| 165 | + /* MPI does not define scan/exscan on intercommunicators */ |
| 166 | + ACCELERATOR_UNINSTALL_COLL_API(comm, s, exscan); |
| 167 | + ACCELERATOR_UNINSTALL_COLL_API(comm, s, scan); |
| 168 | + } |
| 169 | + |
| 170 | + return OMPI_SUCCESS; |
| 171 | +} |
0 commit comments