
Commit 368ceb4

format fix

Signed-off-by: Yongji Wu <[email protected]>
1 parent 194bad1 commit 368ceb4

19 files changed (+668 -326 lines)

vllm/config/parallel.py

Lines changed: 17 additions & 5 deletions
@@ -298,7 +298,7 @@ def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup:
                     self.data_parallel_rank,
                     self.data_parallel_size,
                     backend="gloo",
-                    return_store=return_store
+                    return_store=return_store,
                 )
             except DistNetworkError as e:
                 # We only want to retry when the root cause is EADDRINUSE.
@@ -419,19 +419,31 @@ def __post_init__(self) -> None:
         if self.enable_elastic_ep:
             num_world_groups = 1
             num_dp_groups = max(1, self.world_size_across_dp // self.data_parallel_size)
-            num_ep_groups = max(1, self.world_size_across_dp // (self.data_parallel_size * self.tensor_parallel_size))
+            num_ep_groups = max(
+                1,
+                self.world_size_across_dp
+                // (self.data_parallel_size * self.tensor_parallel_size),
+            )
 
             total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3
 
             if not self._stateless_world_group_port_list:
                 all_ports = get_open_ports_list(total_ports_needed + 5)
                 self._data_parallel_master_port_list = all_ports[-5:]
                 all_ports = all_ports[:-5]
-                self._stateless_world_group_port_list = [all_ports[i:i+3] for i in range(0, num_world_groups * 3, 3)]
+                self._stateless_world_group_port_list = [
+                    all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
+                ]
                 start_idx = num_world_groups * 3
-                self._stateless_dp_group_port_list = [all_ports[i:i+3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3)]
+                self._stateless_dp_group_port_list = [
+                    all_ports[i : i + 3]
+                    for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
+                ]
                 start_idx += num_dp_groups * 3
-                self._stateless_ep_group_port_list = [all_ports[i:i+3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3)]
+                self._stateless_ep_group_port_list = [
+                    all_ports[i : i + 3]
+                    for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
+                ]
 
         if self.data_parallel_size_local > self.data_parallel_size:
             raise ValueError(
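Note on the hunk above: the reshaped comprehensions change layout only; the port bookkeeping still hands each stateless group (world, DP, EP) three consecutive ports and reserves the last five ports for the data-parallel master. A standalone sketch of that partitioning, using a hypothetical pre-generated port list in place of get_open_ports_list (illustrative values, not vLLM code):

# Hypothetical group counts; in vLLM these are derived from
# world_size_across_dp, data_parallel_size and tensor_parallel_size.
num_world_groups, num_dp_groups, num_ep_groups = 1, 2, 4
total_ports_needed = (num_world_groups + num_dp_groups + num_ep_groups) * 3

# Stand-in for get_open_ports_list(total_ports_needed + 5).
all_ports = list(range(20000, 20000 + total_ports_needed + 5))

data_parallel_master_port_list = all_ports[-5:]  # last 5 go to the DP master
all_ports = all_ports[:-5]

world_group_ports = [all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)]
start_idx = num_world_groups * 3
dp_group_ports = [
    all_ports[i : i + 3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
]
start_idx += num_dp_groups * 3
ep_group_ports = [
    all_ports[i : i + 3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
]

assert len(world_group_ports) == num_world_groups
assert len(dp_group_ports) == num_dp_groups
assert len(ep_group_ports) == num_ep_groups
assert all(len(g) == 3 for g in world_group_ports + dp_group_ports + ep_group_ports)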

vllm/distributed/device_communicators/all2all.py

Lines changed: 3 additions & 3 deletions
@@ -185,7 +185,7 @@ def get_handle(self, kwargs):
             logger.debug("PPLX NVSHMEM UID = %s", uid)
             nvshmem_init(uid, self.rank, self.world_size)
             self.nvshmem_initialized = True
-
+
         import pplx_kernels as pplx
 
         return self.handle_cache.get_or_create(
@@ -381,11 +381,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
     All2All communication based on flashinfer kernels.
     """
 
-    def __init__(self, cpu_group):
+    def __init__(self, cpu_group, tcp_store_group=None):
         assert has_flashinfer_all2all(), (
             "flashinfer all2all module not found. Please install/check flashinfer"
         )  # noqa
-        super().__init__(cpu_group)
+        super().__init__(cpu_group, tcp_store_group)
         logger.debug(
             "Initialize for flashinfer All2All rank=%d, world size=%d",
             self.rank,
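Note on the second hunk: the only signature change is that FlashInferAllToAllManager now accepts an optional tcp_store_group and forwards it to All2AllManagerBase, matching the other managers. A minimal sketch of that pattern with simplified stand-in classes (assumed names, not the real vLLM classes):

class ManagerBase:
    # Stand-in for All2AllManagerBase.
    def __init__(self, cpu_group, tcp_store_group=None):
        # When a TCP-store backed (stateless) group is supplied, the base
        # class can use it instead of cpu_group for locality checks.
        self.cpu_group = cpu_group
        self.tcp_store_group = tcp_store_group


class FlashInferLikeManager(ManagerBase):
    # Stand-in for FlashInferAllToAllManager.
    def __init__(self, cpu_group, tcp_store_group=None):
        # The default of None keeps existing call sites working unchanged.
        super().__init__(cpu_group, tcp_store_group)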

vllm/distributed/device_communicators/base_device_communicator.py

Lines changed: 17 additions & 2 deletions
@@ -57,7 +57,9 @@ def __init__(self, cpu_group, tcp_store_group=None):
         if tcp_store_group is None:
             self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
         else:
-            self.internode = not all(in_the_same_node_as(tcp_store_group, source_rank=0))
+            self.internode = not all(
+                in_the_same_node_as(tcp_store_group, source_rank=0)
+            )
 
     def get_handle(self, kwargs):
         # get a handle for the all2all communication,
@@ -104,7 +106,7 @@ def __init__(
         device_group: Optional[ProcessGroup] = None,
         unique_name: str = "",
         global_ranks: Optional[list[int]] = None,
-        global_world_size: Optional[int] = None
+        global_world_size: Optional[int] = None,
     ):
         self.device = device or torch.device("cpu")
         self.cpu_group = cpu_group
@@ -113,12 +115,15 @@ def __init__(
 
         # Check if this is a stateless process group
         from torch.distributed.distributed_c10d import _world
+
         is_stateless = _world.pg_map.get(cpu_group, None) is None
 
         if is_stateless:
             # For stateless groups, we can't use torch.distributed methods
             self.rank = cpu_group.rank()
             self.world_size = cpu_group.size()
+            assert global_ranks is not None
+            assert global_world_size is not None
             self.ranks = global_ranks
             self.global_rank = self.ranks[self.rank]
             self.global_world_size = global_world_size
@@ -270,6 +275,13 @@ def recv(
         torch.distributed.recv(tensor, self.ranks[src], self.device_group)
         return tensor
 
+    def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
+        """Broadcast a tensor from source rank to all ranks."""
+        if self.world_size == 1:
+            return tensor
+        torch.distributed.broadcast(tensor, self.ranks[src], self.device_group)
+        return tensor
+
     def destroy(self):
         pass
 
@@ -313,3 +325,6 @@ def combine(
         This is a no-op in the base class.
         """
         return hidden_states
+
+    def batch_isend_irecv(self, p2p_ops: list):
+        raise NotImplementedError
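Note on the added broadcast helper: it mirrors the existing send/recv wrappers by translating the group-local src rank through self.ranks into a global rank before calling torch.distributed.broadcast, and it short-circuits for single-rank groups. A hypothetical usage sketch, assuming the process group is already initialized and comm is a DeviceCommunicatorBase-style communicator:

import torch


def broadcast_example(comm) -> torch.Tensor:
    # Every rank allocates the buffer; rank 0 fills it with real data.
    tensor = torch.zeros(4)
    if comm.rank == 0:
        tensor = torch.arange(4, dtype=torch.float32)
    # src=0 is the group-local rank; broadcast() maps it to the global rank
    # via comm.ranks before handing it to torch.distributed.broadcast.
    return comm.broadcast(tensor, src=0)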

vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 28 additions & 9 deletions
@@ -17,8 +17,8 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
-from .base_device_communicator import DeviceCommunicatorBase
 from ..utils import StatelessProcessGroup
+from .base_device_communicator import DeviceCommunicatorBase
 
 logger = init_logger(__name__)
 
@@ -32,9 +32,16 @@ def __init__(
         unique_name: str = "",
         global_ranks: Optional[list[int]] = None,
         global_world_size: Optional[int] = None,
-        tcp_store_group: Optional[StatelessProcessGroup] = None
+        tcp_store_group: Optional[StatelessProcessGroup] = None,
     ):
-        super().__init__(cpu_group, device, device_group, unique_name, global_ranks, global_world_size)
+        super().__init__(
+            cpu_group,
+            device,
+            device_group,
+            unique_name,
+            global_ranks,
+            global_world_size,
+        )
         if "tp" not in unique_name:
             # custom allreduce or torch symm mem can be used only by tp
             use_custom_allreduce = False
@@ -99,32 +106,44 @@ def __init__(
             if all2all_backend == "naive":
                 from .all2all import NaiveAll2AllManager
 
-                self.all2all_manager = NaiveAll2AllManager(self.cpu_group, tcp_store_group=tcp_store_group)
+                self.all2all_manager = NaiveAll2AllManager(
+                    self.cpu_group, tcp_store_group=tcp_store_group
+                )
                 logger.info("Using naive all2all manager.")
             elif all2all_backend == "allgather_reducescatter":
                 from .all2all import AgRsAll2AllManager
 
-                self.all2all_manager = AgRsAll2AllManager(self.cpu_group, tcp_store_group=tcp_store_group)
+                self.all2all_manager = AgRsAll2AllManager(
+                    self.cpu_group, tcp_store_group=tcp_store_group
+                )
                 logger.info("Using AllGather-ReduceScatter all2all manager.")
             elif all2all_backend == "pplx":
                 from .all2all import PPLXAll2AllManager
 
-                self.all2all_manager = PPLXAll2AllManager(self.cpu_group, tcp_store_group=tcp_store_group)
+                self.all2all_manager = PPLXAll2AllManager(
+                    self.cpu_group, tcp_store_group=tcp_store_group
+                )
                 logger.info("Using PPLX all2all manager.")
             elif all2all_backend == "deepep_high_throughput":
                 from .all2all import DeepEPHTAll2AllManager
 
-                self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group, tcp_store_group=tcp_store_group)
+                self.all2all_manager = DeepEPHTAll2AllManager(
+                    self.cpu_group, tcp_store_group=tcp_store_group
+                )
                 logger.info("Using DeepEP High-Throughput all2all manager.")
             elif all2all_backend == "deepep_low_latency":
                 from .all2all import DeepEPLLAll2AllManager
 
-                self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group, tcp_store_group=tcp_store_group)
+                self.all2all_manager = DeepEPLLAll2AllManager(
+                    self.cpu_group, tcp_store_group=tcp_store_group
+                )
                 logger.info("Using DeepEP Low-Latency all2all manager.")
            elif all2all_backend == "flashinfer_all2allv":
                 from .all2all import FlashInferAllToAllManager
 
-                self.all2all_manager = FlashInferAllToAllManager(self.cpu_group, tcp_store_group=tcp_store_group)
+                self.all2all_manager = FlashInferAllToAllManager(
+                    self.cpu_group, tcp_store_group=tcp_store_group
+                )
                 logger.info("Using Flashinfer all2allv manager.")
             else:
                 raise ValueError(f"Unknown all2all backend: {all2all_backend}")

vllm/distributed/device_communicators/pynccl.py

Lines changed: 13 additions & 3 deletions
@@ -6,7 +6,7 @@
 # ===================== import region =====================
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup, ReduceOp, P2POp
+from torch.distributed import ProcessGroup, ReduceOp
 
 import vllm.envs as envs
 from vllm.distributed.device_communicators.pynccl_wrapper import (
@@ -312,7 +312,12 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None):
         )
         if stream is None:
             stream = current_stream()
-        if tensor.dtype in [torch.float8_e5m2, torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz]:
+        if tensor.dtype in [
+            torch.float8_e5m2,
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fnuz,
+            torch.float8_e5m2fnuz,
+        ]:
             nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
         else:
             nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
@@ -334,7 +339,12 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
         )
         if stream is None:
             stream = current_stream()
-        if tensor.dtype in [torch.float8_e5m2, torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz]:
+        if tensor.dtype in [
+            torch.float8_e5m2,
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fnuz,
+            torch.float8_e5m2fnuz,
+        ]:
             nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
         else:
             nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
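Note on the two dtype hunks: both wrap the same check, reporting FP8 tensors to NCCL as uint8. That is safe because every FP8 variant is one byte per element, so the element count and buffer size are unchanged; presumably the NCCL dtype enum used here has no FP8 entries. A small self-contained sketch of the mapping (hypothetical helper name, pure PyTorch):

import torch

# FP8 variants that the send/recv paths reinterpret as uint8.
_FP8_DTYPES = (
    torch.float8_e5m2,
    torch.float8_e4m3fn,
    torch.float8_e4m3fnuz,
    torch.float8_e5m2fnuz,
)


def nccl_wire_dtype(dtype: torch.dtype) -> torch.dtype:
    # Dtype actually reported to NCCL for a tensor of `dtype`.
    return torch.uint8 if dtype in _FP8_DTYPES else dtype


assert nccl_wire_dtype(torch.float8_e4m3fn) == torch.uint8
assert nccl_wire_dtype(torch.float16) == torch.float16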
