Skip to content

Commit d411df0

Browse files
[Misc] Further refine type annotations in parallel state (#22499)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 010e0e3 commit d411df0

File tree

2 files changed: +19 additions, -20 deletions

2 files changed: +19 additions, -20 deletions

vllm/distributed/eplb/eplb_state.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -259,7 +259,6 @@ def build(

         if global_expert_load is not None:
             ep_group = get_ep_group().device_group
-            assert ep_group is not None
             assert global_expert_load.shape == (model.num_moe_layers,
                                                 model.num_logical_experts)
             assert global_expert_load.dtype == torch.int64
@@ -366,7 +365,6 @@ def step(self,

         # Collect load metrics from all ranks
         ep_group = get_ep_group().device_group
-        assert ep_group is not None
         all_reduce(total_expert_load_pass, group=ep_group)

         # num_tokens_per_rank: (num_moe_layers, num_ranks)
@@ -422,7 +420,6 @@ def rearrange(self,
         """

         ep_group = get_ep_group().device_group
-        assert ep_group is not None
         ep_rank = ep_group.rank()

         time_start = None

vllm/distributed/parallel_state.py

Lines changed: 19 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -197,19 +197,18 @@ class GroupCoordinator:
     # 3 | 1 | 3 | 1 | 3
     local_rank: int # local rank used to assign devices
     rank_in_group: int # rank inside the group
-    cpu_group: Optional[ProcessGroup] # group for CPU communication
-    device_group: Optional[ProcessGroup] # group for device communication
-    use_device_communicator: bool # whether to use device communicator
-    device_communicator: Optional[
-        DeviceCommunicatorBase] # device communicator
+    cpu_group: ProcessGroup # group for CPU communication
+    device_group: ProcessGroup # group for device communication
+    # device communicator (if use_device_communicator=True)
+    device_communicator: Optional[DeviceCommunicatorBase]
     mq_broadcaster: Optional[Any] # shared memory broadcaster

     def __init__(
         self,
         group_ranks: list[list[int]],
         local_rank: int,
         torch_distributed_backend: Union[str, Backend],
-        use_device_communicator: bool,
+        use_device_communicator: bool, # whether to use device communicator
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
@@ -219,8 +218,9 @@ def __init__(

         self.rank = torch.distributed.get_rank()
         self.local_rank = local_rank
-        self.device_group = None
-        self.cpu_group = None
+
+        self_device_group = None
+        self_cpu_group = None

         for ranks in group_ranks:
             device_group = torch.distributed.new_group(
@@ -232,11 +232,14 @@ def __init__(
                 self.ranks = ranks
                 self.world_size = len(ranks)
                 self.rank_in_group = ranks.index(self.rank)
-                self.device_group = device_group
-                self.cpu_group = cpu_group
+                self_device_group = device_group
+                self_cpu_group = cpu_group
+
+        assert self_cpu_group is not None
+        assert self_device_group is not None

-        assert self.cpu_group is not None
-        assert self.device_group is not None
+        self.cpu_group = self_cpu_group
+        self.device_group = self_device_group

         from vllm.platforms import current_platform


@@ -251,7 +254,6 @@ def __init__(
             self.device = torch.device("cpu")

         self.use_device_communicator = use_device_communicator
-
         self.device_communicator = None
         if use_device_communicator and self.world_size > 1:
             device_comm_cls = resolve_obj_by_qualname(
@@ -817,12 +819,12 @@ def recv(self,
         return self.device_communicator.recv(size, dtype, src)

     def destroy(self):
-        if self.device_group is not None:
+        if hasattr(self, "device_group"):
             torch.distributed.destroy_process_group(self.device_group)
-            self.device_group = None
-        if self.cpu_group is not None:
+            del self.device_group
+        if hasattr(self, "cpu_group"):
             torch.distributed.destroy_process_group(self.cpu_group)
-            self.cpu_group = None
+            del self.cpu_group
         if self.device_communicator is not None:
             self.device_communicator.destroy()
         if self.mq_broadcaster is not None:

0 commit comments

Comments
 (0)