Commit 344a5d0

[Core][Distributed] enable allreduce for multiple tp groups (#4566)
1 parent 0f8a914 commit 344a5d0

4 files changed, +71 -22 lines changed

tests/distributed/test_pynccl.py

Lines changed: 39 additions & 4 deletions
@@ -3,9 +3,13 @@
 import pytest
 import torch
 
+import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils
+from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
                                                           ncclGetUniqueId)
-from vllm.distributed.parallel_state import init_distributed_environment
+from vllm.distributed.parallel_state import (
+    ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group,
+    init_distributed_environment, with_pynccl_for_all_reduce)
 from vllm.utils import update_environment_variables
 
 
@@ -67,7 +71,7 @@ def multiple_tp_worker_fn():
     ]
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
     comm = NCCLCommunicator(group=group, device=device)
-    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
+    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
     # two groups can communicate independently
     if torch.distributed.get_rank() in [0, 1]:
         comm.all_reduce(tensor)
@@ -81,9 +85,40 @@ def multiple_tp_worker_fn():
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
-                    reason="Need at least 2 GPUs to run the test.")
+                    reason="Need at least 4 GPUs to run the test.")
 def test_pynccl_multiple_tp():
-    distributed_run(worker_fn, 4)
+    # this tests pynccl for multiple tp groups, in a standalone way
+    # i.e. call `comm.all_reduce` directly
+    distributed_run(multiple_tp_worker_fn, 4)
+
+
+@worker_fn_wrapper
+def multiple_tp_with_vllm_worker_fn():
+    device = torch.device(f"cuda:{torch.distributed.get_rank()}")
+    torch.cuda.set_device(torch.distributed.get_rank())
+    ensure_model_parallel_initialized(2, 2)
+    pynccl_utils.init_process_group(
+        group=get_tensor_model_parallel_cpu_group())
+    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
+    with with_pynccl_for_all_reduce():
+        # two tp groups can communicate independently
+        if torch.distributed.get_rank() in [0, 1]:
+            tensor = tensor_model_parallel_all_reduce(tensor)
+            tensor = tensor_model_parallel_all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 4
+        else:
+            tensor = tensor_model_parallel_all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 2
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+def test_pynccl_multiple_tp_with_vllm():
+    # this tests pynccl for multiple tp groups, together with vllm
+    # i.e. call `tensor_model_parallel_all_reduce`
+    distributed_run(multiple_tp_with_vllm_worker_fn, 4)
 
 
 @worker_fn_wrapper
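
The asserted values in the new test follow from a sum all-reduce over a 2-rank TP group: starting from all-ones tensors, one all-reduce yields 2 and a second yields 4, while the other group (ranks 2 and 3) reduces only once and stays at 2. Below is a minimal single-process sketch of that arithmetic, using a plain Python list to stand in for the two ranks of one hypothetical TP group; no GPUs or process groups are involved.

import torch

# Each entry plays the role of one rank's tensor inside a 2-rank TP group.
rank_tensors = [torch.ones(4), torch.ones(4)]

def toy_all_reduce(tensors):
    # Sum over "ranks" and hand every rank the same result, which is what a
    # sum all-reduce does within a group.
    total = torch.stack(tensors).sum(dim=0)
    return [total.clone() for _ in tensors]

rank_tensors = toy_all_reduce(rank_tensors)  # ones become 2s
assert all(t.mean().item() == 2 for t in rank_tensors)
rank_tensors = toy_all_reduce(rank_tensors)  # 2s become 4s
assert all(t.mean().item() == 4 for t in rank_tensors)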

vllm/distributed/communication_op.py

Lines changed: 0 additions & 1 deletion
@@ -34,7 +34,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     if out is not None:
         return out
     if is_pynccl_enabled_for_all_reduce():
-        # TODO: support multiple parallel groups.
         pynccl_utils.all_reduce(input_)
     else:
         torch.distributed.all_reduce(input_,
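
With the TODO gone, the pynccl path in `tensor_model_parallel_all_reduce` no longer assumes a single global group; it reduces within whatever TP group the calling rank belongs to. The following is a simplified sketch of the dispatch visible in the diff context, with a hypothetical `custom_all_reduce` callable and `pynccl_enabled` flag standing in for vLLM's real helpers (`is_pynccl_enabled_for_all_reduce`, `pynccl_utils.all_reduce`).

import torch
import torch.distributed as dist

def all_reduce_sketch(input_: torch.Tensor, group, pynccl_enabled: bool,
                      custom_all_reduce=None) -> torch.Tensor:
    # 1. Prefer a registered custom all-reduce kernel if it accepts the input.
    if custom_all_reduce is not None:
        out = custom_all_reduce(input_)
        if out is not None:
            return out
    # 2. Otherwise use pynccl, which after this commit operates on the
    #    caller's own TP group (stubbed out here).
    if pynccl_enabled:
        return input_  # vLLM calls pynccl_utils.all_reduce(input_) in place
    # 3. Fall back to the regular torch.distributed all-reduce on the group.
    dist.all_reduce(input_, group=group)
    return input_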

vllm/distributed/parallel_state.py

Lines changed: 25 additions & 11 deletions
@@ -14,7 +14,8 @@
 logger = init_logger(__name__)
 
 # Tensor model parallel group that the current rank belongs to.
-_TENSOR_MODEL_PARALLEL_GROUP = None
+_TP_DEVICE_GROUP = None
+_TP_CPU_GROUP = None
 # Pipeline model parallel group that the current rank belongs to.
 _PIPELINE_MODEL_PARALLEL_GROUP = None
 
@@ -132,15 +133,17 @@ def initialize_model_parallel(
     rank = torch.distributed.get_rank()
 
     # Build the tensor model-parallel groups.
-    global _TENSOR_MODEL_PARALLEL_GROUP
-    assert _TENSOR_MODEL_PARALLEL_GROUP is None, (
+    global _TP_DEVICE_GROUP, _TP_CPU_GROUP
+    assert _TP_DEVICE_GROUP is None, (
         "tensor model parallel group is already initialized")
     for i in range(num_tensor_model_parallel_groups):
         ranks = range(i * tensor_model_parallel_size,
                       (i + 1) * tensor_model_parallel_size)
         group = torch.distributed.new_group(ranks, backend=backend)
+        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
         if rank in ranks:
-            _TENSOR_MODEL_PARALLEL_GROUP = group
+            _TP_DEVICE_GROUP = group
+            _TP_CPU_GROUP = cpu_group
 
     # Build the pipeline model-parallel groups.
     global _PIPELINE_MODEL_PARALLEL_GROUP
@@ -185,7 +188,7 @@ def ensure_model_parallel_initialized(
 
 def model_parallel_is_initialized():
     """Check if tensor and pipeline parallel groups are initialized."""
-    return (_TENSOR_MODEL_PARALLEL_GROUP is not None
+    return (_TP_DEVICE_GROUP is not None
             and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
 
 
@@ -197,9 +200,16 @@ def get_cpu_world_group():
 
 def get_tensor_model_parallel_group():
     """Get the tensor model parallel group the caller rank belongs to."""
-    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
+    assert _TP_DEVICE_GROUP is not None, (
         "tensor model parallel group is not initialized")
-    return _TENSOR_MODEL_PARALLEL_GROUP
+    return _TP_DEVICE_GROUP
+
+
+def get_tensor_model_parallel_cpu_group():
+    """Get the tensor model parallel cpu group the caller rank belongs to."""
+    assert _TP_CPU_GROUP is not None, (
+        "tensor model parallel cpu group is not initialized")
+    return _TP_CPU_GROUP
 
 
 def get_pipeline_model_parallel_group():
@@ -277,10 +287,14 @@ def get_pipeline_model_parallel_prev_rank():
 
 def destroy_model_parallel():
     """Set the groups to none and destroy them."""
-    global _TENSOR_MODEL_PARALLEL_GROUP
-    if _TENSOR_MODEL_PARALLEL_GROUP:
-        torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
-    _TENSOR_MODEL_PARALLEL_GROUP = None
+    global _TP_DEVICE_GROUP
+    if _TP_DEVICE_GROUP:
+        torch.distributed.destroy_process_group(_TP_DEVICE_GROUP)
+    _TP_DEVICE_GROUP = None
+    global _TP_CPU_GROUP
+    if _TP_CPU_GROUP:
+        torch.distributed.destroy_process_group(_TP_CPU_GROUP)
+    _TP_CPU_GROUP = None
     global _PIPELINE_MODEL_PARALLEL_GROUP
     if _PIPELINE_MODEL_PARALLEL_GROUP:
         torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
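
The core change in this file is that each tensor-parallel group now comes in two flavors over the same ranks: a device group for collectives and a gloo-backed CPU group that pynccl can use for CPU-side bootstrapping without touching the device group. Below is a minimal sketch of that paired-group pattern, assuming `torch.distributed` is already initialized; the `build_tp_groups` helper is illustrative only, not part of vLLM.

import torch.distributed as dist

def build_tp_groups(world_size: int, tp_size: int, device_backend: str = "nccl"):
    """Create one device group and one matching gloo CPU group per TP group."""
    rank = dist.get_rank()
    tp_device_group = None
    tp_cpu_group = None
    for i in range(world_size // tp_size):
        ranks = list(range(i * tp_size, (i + 1) * tp_size))
        # Every rank must create every group, even ones it does not join.
        device_group = dist.new_group(ranks, backend=device_backend)
        # CPU-side twin over the same ranks, used for coordination that
        # should not issue collectives on the device group.
        cpu_group = dist.new_group(ranks, backend="gloo")
        if rank in ranks:
            tp_device_group = device_group
            tp_cpu_group = cpu_group
    return tp_device_group, tp_cpu_group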

vllm/worker/worker.py

Lines changed: 7 additions & 6 deletions
@@ -11,6 +11,7 @@
                          VisionLanguageConfig)
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
+                              get_tensor_model_parallel_cpu_group,
                               init_distributed_environment)
 from vllm.distributed.device_communicators import pynccl_utils
 from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -288,6 +289,9 @@ def init_worker_distributed_environment(
     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank)
 
+    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
+                                      parallel_config.pipeline_parallel_size)
+
     if pynccl_utils.is_initialized():
         pynccl_world_size = pynccl_utils.get_world_size()
         if pynccl_world_size != parallel_config.world_size:
@@ -298,12 +302,9 @@ def init_worker_distributed_environment(
     elif parallel_config.world_size > 1:
         # NOTE(woosuk): We don't initialize pynccl process group when world size
         # is 1.
-        # NOTE(kaichao): By default, pynccl will use information inside
-        # `parallel_state` for initialization.
-        pynccl_utils.init_process_group()
-
-        ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                          parallel_config.pipeline_parallel_size)
+        # NOTE(kaichao): By default, pynccl is initialized for tp group.
+        pynccl_utils.init_process_group(
+            group=get_tensor_model_parallel_cpu_group())
 
     # Initialize a custom fast all-reduce implementation.
     if not parallel_config.disable_custom_all_reduce:
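
The worker-side effect of this reordering is that the model-parallel groups, including the new TP CPU group, exist before pynccl is set up, so `init_process_group` can be pointed at the caller's own TP group. The following is a condensed outline of the resulting flow; it omits the pynccl world-size check, error handling, and the custom all-reduce setup from the real `init_worker_distributed_environment`.

from vllm.distributed import (ensure_model_parallel_initialized,
                              get_tensor_model_parallel_cpu_group,
                              init_distributed_environment)
from vllm.distributed.device_communicators import pynccl_utils

def init_worker_distributed_environment_outline(parallel_config, rank,
                                                distributed_init_method,
                                                local_rank):
    # Global process group first.
    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank)
    # TP/PP groups (and the new gloo TP group) before any pynccl work.
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)
    # pynccl is then scoped to the caller's TP group via its CPU twin.
    if not pynccl_utils.is_initialized() and parallel_config.world_size > 1:
        pynccl_utils.init_process_group(
            group=get_tensor_model_parallel_cpu_group())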
