@@ -14,7 +14,8 @@
 logger = init_logger(__name__)
 
 # Tensor model parallel group that the current rank belongs to.
-_TENSOR_MODEL_PARALLEL_GROUP = None
+_TP_DEVICE_GROUP = None
+_TP_CPU_GROUP = None
 
 # Pipeline model parallel group that the current rank belongs to.
 _PIPELINE_MODEL_PARALLEL_GROUP = None
@@ -132,15 +133,17 @@ def initialize_model_parallel(
     rank = torch.distributed.get_rank()
 
     # Build the tensor model-parallel groups.
-    global _TENSOR_MODEL_PARALLEL_GROUP
-    assert _TENSOR_MODEL_PARALLEL_GROUP is None, (
+    global _TP_DEVICE_GROUP, _TP_CPU_GROUP
+    assert _TP_DEVICE_GROUP is None, (
         "tensor model parallel group is already initialized")
     for i in range(num_tensor_model_parallel_groups):
         ranks = range(i * tensor_model_parallel_size,
                       (i + 1) * tensor_model_parallel_size)
         group = torch.distributed.new_group(ranks, backend=backend)
+        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
         if rank in ranks:
-            _TENSOR_MODEL_PARALLEL_GROUP = group
+            _TP_DEVICE_GROUP = group
+            _TP_CPU_GROUP = cpu_group
 
     # Build the pipeline model-parallel groups.
     global _PIPELINE_MODEL_PARALLEL_GROUP
@@ -185,7 +188,7 @@ def ensure_model_parallel_initialized(
 
 def model_parallel_is_initialized():
     """Check if tensor and pipeline parallel groups are initialized."""
-    return (_TENSOR_MODEL_PARALLEL_GROUP is not None
+    return (_TP_DEVICE_GROUP is not None
             and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
 
 
@@ -197,9 +200,16 @@ def get_cpu_world_group():
 
 def get_tensor_model_parallel_group():
     """Get the tensor model parallel group the caller rank belongs to."""
-    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
+    assert _TP_DEVICE_GROUP is not None, (
         "tensor model parallel group is not initialized")
-    return _TENSOR_MODEL_PARALLEL_GROUP
+    return _TP_DEVICE_GROUP
+
+
+def get_tensor_model_parallel_cpu_group():
+    """Get the tensor model parallel cpu group the caller rank belongs to."""
+    assert _TP_CPU_GROUP is not None, (
+        "tensor model parallel cpu group is not initialized")
+    return _TP_CPU_GROUP
 
 
 def get_pipeline_model_parallel_group():
@@ -277,10 +287,14 @@ def get_pipeline_model_parallel_prev_rank():
 
 def destroy_model_parallel():
     """Set the groups to none and destroy them."""
-    global _TENSOR_MODEL_PARALLEL_GROUP
-    if _TENSOR_MODEL_PARALLEL_GROUP:
-        torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
-    _TENSOR_MODEL_PARALLEL_GROUP = None
+    global _TP_DEVICE_GROUP
+    if _TP_DEVICE_GROUP:
+        torch.distributed.destroy_process_group(_TP_DEVICE_GROUP)
+    _TP_DEVICE_GROUP = None
+    global _TP_CPU_GROUP
+    if _TP_CPU_GROUP:
+        torch.distributed.destroy_process_group(_TP_CPU_GROUP)
+    _TP_CPU_GROUP = None
     global _PIPELINE_MODEL_PARALLEL_GROUP
     if _PIPELINE_MODEL_PARALLEL_GROUP:
         torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
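
A minimal usage sketch of the new gloo-backed CPU group (not part of the diff): small CPU-side control data can be exchanged over `get_tensor_model_parallel_cpu_group()` without touching the device group used for tensor collectives. The helper name `broadcast_tp_metadata` and the import path are assumptions for illustration only.

```python
# Hypothetical sketch -- assumes initialize_model_parallel() has already run
# and that this module is importable as `parallel_state` (the actual import
# path depends on the vLLM version).
import torch.distributed as dist

from parallel_state import get_tensor_model_parallel_cpu_group  # assumed path


def broadcast_tp_metadata(metadata: dict, src: int = 0) -> dict:
    """Broadcast a small Python object across TP ranks over the gloo group.

    Routing this through the CPU group keeps lightweight control-plane
    traffic off the device group used for tensor collectives.
    """
    cpu_group = get_tensor_model_parallel_cpu_group()
    obj_list = [metadata]
    # broadcast_object_list pickles the object on `src` (a global rank) and
    # unpickles it on the other ranks in the group; it works with gloo.
    dist.broadcast_object_list(obj_list, src=src, group=cpu_group)
    return obj_list[0]
```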