Skip to content

Commit 74333ae

Browse files
authored
[Misc] correct static type check for GroupCoordinator (#21946)
Signed-off-by: Andy Xie <[email protected]>
1 parent 83156c7 commit 74333ae

File tree

3 files changed

+29
-4
lines changed

3 files changed

+29
-4
lines changed

vllm/distributed/device_communicators/ray_communicator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def __init__(
7070
assert ray.get_gpu_ids(), "RayPPCommunicator has no GPUs assigned"
7171

7272
self._comm = get_pp_group().device_communicator
73+
assert self._comm is not None
7374

7475
# Since we wrap around the vLLM _PP communicator, we use
7576
# the rank from the vLLM communicator, and ignore the rank

vllm/distributed/eplb/eplb_state.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def build(
251251

252252
if global_expert_load is not None:
253253
ep_group = get_ep_group().device_group
254+
assert ep_group is not None
254255
assert global_expert_load.shape == (model.num_moe_layers,
255256
model.num_logical_experts)
256257
assert global_expert_load.dtype == torch.int64
@@ -357,6 +358,7 @@ def step(self,
357358

358359
# Collect load metrics from all ranks
359360
ep_group = get_ep_group().device_group
361+
assert ep_group is not None
360362
num_tokens_list = [
361363
torch.empty_like(num_tokens) for _ in range(ep_group.size())
362364
]
@@ -412,6 +414,7 @@ def rearrange(self,
412414
"""
413415

414416
ep_group = get_ep_group().device_group
417+
assert ep_group is not None
415418
ep_rank = ep_group.rank()
416419

417420
time_start = None

vllm/distributed/parallel_state.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,11 @@ class GroupCoordinator:
196196
# 3 | 1 | 3 | 1 | 3
197197
local_rank: int # local rank used to assign devices
198198
rank_in_group: int # rank inside the group
199-
cpu_group: ProcessGroup # group for CPU communication
200-
device_group: ProcessGroup # group for device communication
199+
cpu_group: Optional[ProcessGroup] # group for CPU communication
200+
device_group: Optional[ProcessGroup] # group for device communication
201201
use_device_communicator: bool # whether to use device communicator
202-
device_communicator: DeviceCommunicatorBase # device communicator
202+
device_communicator: Optional[
203+
DeviceCommunicatorBase] # device communicator
203204
mq_broadcaster: Optional[Any] # shared memory broadcaster
204205

205206
def __init__(
@@ -250,7 +251,7 @@ def __init__(
250251

251252
self.use_device_communicator = use_device_communicator
252253

253-
self.device_communicator: DeviceCommunicatorBase = None # type: ignore
254+
self.device_communicator = None
254255
if use_device_communicator and self.world_size > 1:
255256
device_comm_cls = resolve_obj_by_qualname(
256257
current_platform.get_device_communicator_cls())
@@ -364,6 +365,8 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
364365
return self._all_reduce_out_place(input_)
365366

366367
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
368+
if self.device_communicator is None:
369+
raise ValueError("No device communicator found")
367370
return self.device_communicator.all_reduce(input_)
368371

369372
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
@@ -384,12 +387,16 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
384387

385388
def _all_gather_out_place(self, input_: torch.Tensor,
386389
dim: int) -> torch.Tensor:
390+
if self.device_communicator is None:
391+
raise ValueError("No device communicator found")
387392
return self.device_communicator.all_gather(input_, dim)
388393

389394
def all_gatherv(self,
390395
input_: Union[torch.Tensor, list[torch.Tensor]],
391396
dim: int = 0,
392397
sizes: Optional[list[int]] = None):
398+
if self.device_communicator is None:
399+
raise ValueError("No device communicator found")
393400
return self.device_communicator.all_gatherv(input_, dim, sizes)
394401

395402
def reduce_scatter(self,
@@ -414,10 +421,14 @@ def reduce_scatterv(self,
414421
input_: torch.Tensor,
415422
dim: int = -1,
416423
sizes: Optional[list[int]] = None) -> torch.Tensor:
424+
if self.device_communicator is None:
425+
raise ValueError("No device communicator found")
417426
return self.device_communicator.reduce_scatterv(input_, dim, sizes)
418427

419428
def _reduce_scatter_out_place(self, input_: torch.Tensor,
420429
dim: int) -> torch.Tensor:
430+
if self.device_communicator is None:
431+
raise ValueError("No device communicator found")
421432
return self.device_communicator.reduce_scatter(input_, dim)
422433

423434
def gather(self,
@@ -433,6 +444,8 @@ def gather(self,
433444
# Bypass the function if we are using only 1 GPU.
434445
if world_size == 1:
435446
return input_
447+
if self.device_communicator is None:
448+
raise ValueError("No device communicator found")
436449
return self.device_communicator.gather(input_, dst, dim)
437450

438451
def broadcast(self, input_: torch.Tensor, src: int = 0):
@@ -667,6 +680,8 @@ def send_tensor_dict(
667680
assert dst < self.world_size, f"Invalid dst rank ({dst})"
668681

669682
if self.use_cpu_custom_send_recv:
683+
if self.device_communicator is None:
684+
raise ValueError("No device communicator found")
670685
self.device_communicator.send_tensor_dict( # type: ignore
671686
tensor_dict, dst)
672687
return None
@@ -727,6 +742,8 @@ def recv_tensor_dict(
727742
assert src < self.world_size, f"Invalid src rank ({src})"
728743

729744
if self.use_cpu_custom_send_recv:
745+
if self.device_communicator is None:
746+
raise ValueError("No device communicator found")
730747
return self.device_communicator.recv_tensor_dict( # type: ignore
731748
src)
732749

@@ -784,6 +801,8 @@ def barrier(self):
784801
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
785802
"""Sends a tensor to the destination rank in a blocking way"""
786803
"""NOTE: `dst` is the local rank of the destination rank."""
804+
if self.device_communicator is None:
805+
raise ValueError("No device communicator found")
787806
self.device_communicator.send(tensor, dst)
788807

789808
def recv(self,
@@ -792,6 +811,8 @@ def recv(self,
792811
src: Optional[int] = None) -> torch.Tensor:
793812
"""Receives a tensor from the source rank."""
794813
"""NOTE: `src` is the local rank of the source rank."""
814+
if self.device_communicator is None:
815+
raise ValueError("No device communicator found")
795816
return self.device_communicator.recv(size, dtype, src)
796817

797818
def destroy(self):

0 commit comments

Comments (0)