
Commit ea3890a

[Core][Distributed] add coordinator to reduce code duplication in tp and pp (#5293)
1 parent 2135cac commit ea3890a

12 files changed, with 622 additions and 582 deletions

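The refactor routes tensor-parallel (TP), pipeline-parallel (PP), and world-group state through per-group coordinator objects instead of a collection of free functions. A minimal sketch, assuming the distributed environment and model-parallel groups are already initialized, of the coordinator accessors that the updated tests below rely on:

from vllm.distributed.parallel_state import get_tp_group, get_world_group

tp = get_tp_group()          # coordinator for the tensor-parallel group
world = get_world_group()    # coordinator for the world (all-ranks) group

ca_comm = tp.ca_comm         # custom all-reduce communicator of the TP group
cpu_group = world.cpu_group  # CPU process group backing the world group
device = world.device        # torch device bound to this rank

Free functions such as get_tp_ca_communicator() become attributes on these coordinator objects, which accounts for most of the mechanical changes in the diffs below.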

tests/conftest.py

Lines changed: 3 additions & 1 deletion
@@ -15,7 +15,8 @@
 
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
-from vllm.distributed import destroy_model_parallel
+from vllm.distributed import (destroy_distributed_environment,
+                              destroy_model_parallel)
 from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalData
@@ -54,6 +55,7 @@ def _read_prompts(filename: str) -> List[str]:
 
 def cleanup():
     destroy_model_parallel()
+    destroy_distributed_environment()
     with contextlib.suppress(AssertionError):
         torch.distributed.destroy_process_group()
     gc.collect()
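A hedged sketch of how a test suite might wire this teardown into pytest; the fixture name and autouse choice are illustrative, not part of the diff:

import pytest

from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel)


@pytest.fixture(autouse=True)
def distributed_cleanup():
    yield
    # Same order as cleanup() above: model-parallel groups first,
    # then the whole distributed environment.
    destroy_model_parallel()
    destroy_distributed_environment()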

tests/distributed/test_custom_all_reduce.py

Lines changed: 3 additions & 3 deletions
@@ -7,9 +7,9 @@
 import torch.distributed as dist
 
 from vllm.distributed.communication_op import (  # noqa
-    graph_capture, tensor_model_parallel_all_reduce)
+    tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
-                                             get_tp_ca_communicator)
+                                             get_tp_group, graph_capture)
 
 from ..utils import (init_test_distributed_environment,
                      multi_process_tensor_parallel)
@@ -91,7 +91,7 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
     # communicate independently
     num_communication = rank // tp_size + 1
     sz = 1024
-    fa = get_tp_ca_communicator()
+    fa = get_tp_group().ca_comm
     inp = torch.ones(sz, dtype=torch.float32, device=device)
     out = inp
     for _ in range(num_communication):
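A hedged sketch of the access pattern this test now exercises: graph_capture is imported from vllm.distributed.parallel_state rather than communication_op, and the custom all-reduce handle is an attribute of the TP coordinator instead of the removed get_tp_ca_communicator() helper:

import torch

from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tp_group, graph_capture

fa = get_tp_group().ca_comm  # was: get_tp_ca_communicator()

# graph_capture() is the context under which collectives are made safe to
# record into CUDA graphs; the usage here is a sketch, not the full test.
with graph_capture():
    data = torch.ones(1024, dtype=torch.float32, device="cuda")
    data = tensor_model_parallel_all_reduce(data)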

tests/distributed/test_pynccl.py

Lines changed: 8 additions & 4 deletions
@@ -6,10 +6,11 @@
 import torch.distributed
 
 from vllm.distributed.communication_op import (  # noqa
-    graph_capture, tensor_model_parallel_all_reduce)
+    tensor_model_parallel_all_reduce)
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             get_world_group, graph_capture,
                                              init_distributed_environment)
 from vllm.utils import update_environment_variables
 
@@ -53,7 +54,8 @@ def wrapped_fn(env):
 
 @worker_fn_wrapper
 def worker_fn():
-    pynccl_comm = PyNcclCommunicator()
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
     with pynccl_comm.change_state(enable=True):
@@ -129,7 +131,8 @@ def test_pynccl_multiple_allreduce_with_vllm():
 def worker_fn_with_cudagraph():
     with torch.no_grad():
         graph = torch.cuda.CUDAGraph()
-        pynccl_comm = PyNcclCommunicator()
+        pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                         device=get_world_group().device)
         # run something in the default stream to initialize torch engine
         a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
@@ -154,7 +157,8 @@ def test_pynccl_with_cudagraph():
 
 @worker_fn_wrapper
 def send_recv_worker_fn():
-    pynccl_comm = PyNcclCommunicator()
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
     if pynccl_comm.rank == 0:
         tensor = torch.ones(16, 1024, 1024,
                             dtype=torch.float32).cuda(pynccl_comm.rank)
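A hedged sketch of the constructor change above: PyNcclCommunicator is now built from an explicit CPU process group and device, both taken from the world-group coordinator, instead of reading implicit global state:

import torch

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import get_world_group

pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                 device=get_world_group().device)
tensor = torch.ones(16, dtype=torch.float32,
                    device=f"cuda:{pynccl_comm.rank}")
with pynccl_comm.change_state(enable=True):
    ...  # collectives on `tensor` are issued here in the tests above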

tests/lora/conftest.py

Lines changed: 13 additions & 10 deletions
@@ -12,7 +12,10 @@
 
 import vllm
 from vllm.config import LoRAConfig
-from vllm.distributed import destroy_model_parallel, initialize_model_parallel
+from vllm.distributed import (destroy_distributed_environment,
+                              destroy_model_parallel,
+                              init_distributed_environment,
+                              initialize_model_parallel)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                RowParallelLinear)
@@ -35,6 +38,7 @@
 
 def cleanup():
     destroy_model_parallel()
+    destroy_distributed_environment()
     with contextlib.suppress(AssertionError):
         torch.distributed.destroy_process_group()
     gc.collect()
@@ -64,15 +68,14 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
 
 @pytest.fixture
 def dist_init():
-    if not torch.distributed.is_initialized():
-        temp_file = tempfile.mkstemp()[1]
-        torch.distributed.init_process_group(
-            backend="nccl",
-            world_size=1,
-            rank=0,
-            init_method=f"file://{temp_file}",
-        )
-        torch.distributed.all_reduce(torch.zeros(1).cuda())
+    temp_file = tempfile.mkstemp()[1]
+    init_distributed_environment(
+        world_size=1,
+        rank=0,
+        distributed_init_method=f"file://{temp_file}",
+        local_rank=0,
+        backend="nccl",
+    )
     initialize_model_parallel(1, 1)
     yield
     cleanup()
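Flattened out of the fixture, the new setup/teardown pairing looks roughly like the sketch below; it goes through vLLM's init_distributed_environment (which also creates the coordinator state) rather than calling torch.distributed.init_process_group directly:

import tempfile

from vllm.distributed import (destroy_distributed_environment,
                              destroy_model_parallel,
                              init_distributed_environment,
                              initialize_model_parallel)

temp_file = tempfile.mkstemp()[1]
init_distributed_environment(
    world_size=1,
    rank=0,
    distributed_init_method=f"file://{temp_file}",
    local_rank=0,
    backend="nccl",
)
initialize_model_parallel(1, 1)  # TP=1, PP=1 for single-GPU LoRA tests
# ... test body runs here ...
destroy_model_parallel()
destroy_distributed_environment()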

tests/worker/test_model_runner.py

Lines changed: 3 additions & 1 deletion
@@ -1,7 +1,8 @@
 import pytest
 import torch
 
-from vllm.distributed.parallel_state import init_distributed_environment
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
@@ -292,6 +293,7 @@ def distributed_init():
         rank=0,
         distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}",
         local_rank=0)
+    ensure_model_parallel_initialized(1, 1)
 
 
 @pytest.mark.parametrize("batch_size", list(range(2, 128)))
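The extra ensure_model_parallel_initialized(1, 1) call is needed because, after the refactor, code paths that consult the TP/PP coordinators expect model-parallel groups to exist even in single-GPU tests. A minimal sketch of its behavior as used here:

from vllm.distributed.parallel_state import ensure_model_parallel_initialized

# Idempotent helper: creates the TP=1 / PP=1 groups if they do not exist yet,
# and otherwise only checks that the existing groups match the requested sizes.
ensure_model_parallel_initialized(1, 1)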

vllm/attention/backends/pallas.py

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@ def __init__(
             raise NotImplementedError("TPU version must be 4 or higher.")
 
         self.megacore_mode = None
-        tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
+        tpu_type = torch_xla.tpu.get_tp_groupu_env()["TYPE"].lower()
         if not tpu_type.endswith("lite"):
             if self.num_kv_heads % 2 == 0:
                 self.megacore_mode = "kv_head"
