
Commit 2a85f93

[Core][Distributed] enable multiple tp group (#4512)
Co-authored-by: Zhuohan Li <[email protected]>
1 parent cf8cac8 commit 2a85f93

4 files changed: +43 −4 lines changed

.buildkite/test-pipeline.yaml

Lines changed: 8 additions & 3 deletions
@@ -25,19 +25,24 @@ steps:
 - label: Distributed Comm Ops Test
   command: pytest -v -s test_comm_ops.py
   working_dir: "/vllm-workspace/tests/distributed"
-  num_gpus: 2 # only support 1 or 2 for now.
+  num_gpus: 2

 - label: Distributed Tests
   working_dir: "/vllm-workspace/tests/distributed"
-  num_gpus: 2 # only support 1 or 2 for now.
+  num_gpus: 2
   commands:
-  - pytest -v -s test_pynccl.py
   - pytest -v -s test_pynccl_library.py
   - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py

+- label: Distributed Tests (Multiple Groups)
+  working_dir: "/vllm-workspace/tests/distributed"
+  num_gpus: 4
+  commands:
+  - pytest -v -s test_pynccl.py
+
 - label: Engine Test
   command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py

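As a side note, a quick way to see which CI steps request how many GPUs after this change is to read num_gpus back out of the pipeline file. A minimal helper sketch, not part of this commit, assuming PyYAML is available and the script is run from the repository root (the default of 1 GPU for steps without an explicit num_gpus is an assumption, not something the pipeline declares):

# Hypothetical helper: list each Buildkite step with its declared GPU
# count so the new 4-GPU "Distributed Tests (Multiple Groups)" step
# stands out.
import yaml

with open(".buildkite/test-pipeline.yaml") as f:
    pipeline = yaml.safe_load(f)

for step in pipeline["steps"]:
    label = step.get("label", "<unnamed step>")
    # Steps without an explicit num_gpus are assumed here to use 1 GPU.
    print(f'{label}: num_gpus={step.get("num_gpus", 1)}')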

.buildkite/test-template.j2

Lines changed: 3 additions & 0 deletions
@@ -45,6 +45,9 @@ steps:
     plugins:
     - kubernetes:
         podSpec:
+          {% if step.num_gpus %}
+          priorityClassName: gpu-priority-cls-{{ step.num_gpus }}
+          {% endif %}
           volumes:
           - name: dshm
             emptyDir:
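The template change only takes effect for steps that declare num_gpus. A minimal sketch of how the new conditional renders with Jinja2 (a standalone snippet rather than the full template; the step dicts are made up for illustration):

# Hypothetical rendering check, not part of this commit: steps with
# num_gpus get a matching priority class, other steps emit nothing.
from jinja2 import Template

snippet = Template(
    "{% if step.num_gpus %}"
    "priorityClassName: gpu-priority-cls-{{ step.num_gpus }}"
    "{% endif %}")

print(snippet.render(step={"num_gpus": 4}))  # priorityClassName: gpu-priority-cls-4
print(repr(snippet.render(step={})))         # '' -- no priority class emitted

So the 4-GPU step added to the pipeline above would be scheduled under a gpu-priority-cls-4 priority class, assuming such a class exists in the cluster.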

tests/distributed/test_pynccl.py

Lines changed: 28 additions & 0 deletions
@@ -58,6 +58,34 @@ def test_pynccl():
     distributed_run(worker_fn, 2)


+@worker_fn_wrapper
+def multiple_tp_worker_fn():
+    device = torch.device(f"cuda:{torch.distributed.get_rank()}")
+    groups = [
+        torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
+        torch.distributed.new_group(ranks=[2, 3], backend="gloo")
+    ]
+    group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
+    comm = NCCLCommunicator(group=group, device=device)
+    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
+    # two groups can communicate independently
+    if torch.distributed.get_rank() in [0, 1]:
+        comm.all_reduce(tensor)
+        comm.all_reduce(tensor)
+        result = tensor.mean().cpu().item()
+        assert result == 4
+    else:
+        comm.all_reduce(tensor)
+        result = tensor.mean().cpu().item()
+        assert result == 2
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+def test_pynccl_multiple_tp():
+    distributed_run(multiple_tp_worker_fn, 4)
+
+
 @worker_fn_wrapper
 def worker_fn_with_cudagraph():
     with torch.no_grad():
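The asserted values follow from all_reduce summing over a 2-rank group, so each reduce doubles a tensor of ones. A trivial sanity check of that arithmetic (plain Python, no GPUs, just restating the test's expectations):

# Ranks 0 and 1 all-reduce twice within their 2-rank group: 1 * 2 * 2 = 4.
# Ranks 2 and 3 all-reduce once within theirs:              1 * 2     = 2.
group_size = 2
assert 1.0 * group_size * group_size == 4.0
assert 1.0 * group_size == 2.0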

vllm/distributed/device_communicators/pynccl.py

Lines changed: 4 additions & 1 deletion
@@ -232,14 +232,17 @@ def __init__(
         assert dist.get_backend(group) != dist.Backend.NCCL, (
             "NCCLCommunicator should be attached to a non-NCCL group.")
         self.group = group
+        # note: this rank is the rank in the group
         self.rank = dist.get_rank(group)
         self.world_size = dist.get_world_size(group)
         if self.rank == 0:
             self.unique_id = ncclGetUniqueId()
         else:
             self.unique_id = NcclUniqueId()
         tensor = torch.ByteTensor(list(self.unique_id.internal))
-        dist.broadcast(tensor, src=0, group=group)
+        ranks = dist.get_process_group_ranks(group)
+        # arg `src` in `broadcast` is the global rank
+        dist.broadcast(tensor, src=ranks[0], group=group)
         byte_list = tensor.tolist()
         for i, byte in enumerate(byte_list):
             self.unique_id.internal[i] = byte
