
Commit 4be02a3

[Bugfix] EPLB load statistics problem (#22167)
Signed-off-by: ycyaw66 <[email protected]>
Signed-off-by: David Chen <[email protected]>
Co-authored-by: ycyaw66 <[email protected]>
1 parent f6278b6 commit 4be02a3

File tree: 2 files changed (+26, −41 lines)

vllm/distributed/eplb/eplb_state.py

Lines changed: 24 additions & 26 deletions
```diff
@@ -32,7 +32,7 @@
 from typing import Optional, Union
 
 import torch
-from torch.distributed import ProcessGroup, all_gather, all_reduce
+from torch.distributed import ProcessGroup, all_reduce
 
 from vllm.config import ParallelConfig
 from vllm.distributed.parallel_state import (get_ep_group, get_node_count,
@@ -112,13 +112,21 @@ class EplbState:
     Expert load during this forward pass.
     We use the token count each expert processes as the load.
 
-    Shape: (num_moe_layers, num_local_physical_experts)
+    Shape: (num_moe_layers, num_physical_experts)
     """
     expert_load_window: torch.Tensor
     """
     A sliding window of expert load.
 
-    Shape: (window_size, num_moe_layers, num_local_physical_experts)
+    Shape: (window_size, num_moe_layers, num_physical_experts)
+
+    NOTE: The expert_load_view now records load for all physical experts
+    rather than just local experts. This ensures consistent load statistics
+    across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels).
+    The recorded load will be multiplied by dp_size when using naive all-to-all
+    due to each DP rank contributing the same token set to the calculation.
+    See:
+    https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
     """
     expert_load_window_step: int = 0
     """
@@ -232,14 +240,14 @@ def build(
         ).contiguous()
 
         expert_load_pass = torch.zeros(
-            (model.num_moe_layers, model.num_local_physical_experts),
+            (model.num_moe_layers, model.num_physical_experts),
             dtype=torch.int32,
             device=device,
         )
         expert_load_window_size = parallel_config.eplb_window_size
         expert_load_window = torch.zeros(
             (expert_load_window_size, model.num_moe_layers,
-             model.num_local_physical_experts),
+             model.num_physical_experts),
             dtype=torch.int32,
             device=device,
         )
@@ -353,18 +361,18 @@ def step(self,
             self.expert_load_pass.zero_()
 
         if log_stats:
-            # `num_tokens`: (num_moe_layers,)
-            num_tokens = self.expert_load_pass.sum(dim=-1)
+            # total_expert_load_pass: (num_moe_layers, num_physical_experts)
+            total_expert_load_pass = self.expert_load_pass.clone()
 
             # Collect load metrics from all ranks
            ep_group = get_ep_group().device_group
            assert ep_group is not None
-            num_tokens_list = [
-                torch.empty_like(num_tokens) for _ in range(ep_group.size())
-            ]
-            all_gather(num_tokens_list, num_tokens, group=ep_group)
-            # Stack to get (num_ranks, num_moe_layers)
-            num_tokens_per_rank = torch.stack(num_tokens_list).float()
+            all_reduce(total_expert_load_pass, group=ep_group)
+
+            # num_tokens_per_rank: (num_moe_layers, num_ranks)
+            num_tokens_per_rank = total_expert_load_pass.reshape(
+                total_expert_load_pass.shape[0], ep_group.size(),
+                -1).sum(dim=-1).float()
 
             # Compute balancedness ratio:
             # for each layer:
```
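
To make the new aggregation concrete: every rank now records load for all physical experts, so a single `all_reduce` yields the global per-expert counts, and because physical experts are laid out rank-contiguously, a reshape recovers the per-rank token totals. Below is a minimal standalone sketch of that arithmetic; the sizes are illustrative, the `all_reduce` is simulated by summing hypothetical per-rank tensors, and the mean/max balancedness ratio at the end is an assumption about the metric the truncated comment above refers to.

```python
import torch

# Illustrative sizes (not taken from the commit).
num_moe_layers, num_ranks, num_local_experts = 2, 4, 8
num_physical_experts = num_ranks * num_local_experts

# Each rank records load for *all* physical experts during this pass.
per_rank_load = [
    torch.randint(0, 100, (num_moe_layers, num_physical_experts))
    for _ in range(num_ranks)
]

# all_reduce(total_expert_load_pass, group=ep_group) is equivalent to
# summing the per-rank tensors element-wise.
total_expert_load_pass = torch.stack(per_rank_load).sum(dim=0)

# Physical experts are laid out rank-contiguously, so reshaping to
# (layers, ranks, local_experts) and summing the last dim recovers the
# token count each rank handled per layer.
num_tokens_per_rank = total_expert_load_pass.reshape(
    num_moe_layers, num_ranks, -1).sum(dim=-1).float()

# Assumed balancedness metric: mean load / max load across ranks,
# per layer (1.0 means perfectly balanced).
balancedness = (num_tokens_per_rank.mean(dim=-1) /
                num_tokens_per_rank.max(dim=-1).values)
print(balancedness)
```

The reshape only works because each rank owns a contiguous block of `num_local_physical_experts` columns of the global load tensor.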
```diff
@@ -426,17 +434,7 @@ def rearrange(self,
             "(profile)" if is_profile else "")
 
         if global_expert_load is None:
-            # This mapping is only used here, so we do not store it in the state
-            physical_expert_start = ep_rank * model.num_local_physical_experts
-            physical_expert_end = (physical_expert_start +
-                                   model.num_local_physical_experts)
-            # (num_moe_layers, num_local_physical_experts)
-            local_physical_to_logical_map = self.physical_to_logical_map[
-                :,
-                physical_expert_start:physical_expert_end,
-            ]
-
-            # Map the local physical expert load to global logical experts
+            # Map the physical expert load to global logical experts
             logical_expert_load_window = torch.zeros(
                 self.expert_load_window_size,
                 model.num_moe_layers,
@@ -446,7 +444,7 @@
             )
             logical_expert_load_window.scatter_add_(
                 dim=-1,
-                index=local_physical_to_logical_map.unsqueeze(0).expand_as(
+                index=self.physical_to_logical_map.unsqueeze(0).expand_as(
                     self.expert_load_window).long(),
                 src=self.expert_load_window,
             )
```
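
The deletion above works because the load window now spans all physical experts, so the full `physical_to_logical_map` can serve directly as the `scatter_add_` index instead of a per-rank slice. A small self-contained sketch of the same reduction, with illustrative names and sizes rather than values from the commit:

```python
import torch

# Illustrative sizes.
window_size, num_moe_layers = 3, 2
num_physical_experts, num_logical_experts = 8, 5

# Each physical expert (replica) points at the logical expert it serves.
physical_to_logical_map = torch.randint(
    0, num_logical_experts, (num_moe_layers, num_physical_experts))

# Sliding window of per-physical-expert load, as recorded after this change.
expert_load_window = torch.randint(
    0, 50, (window_size, num_moe_layers, num_physical_experts))

# Accumulate the load of all replicas of each logical expert.
logical_expert_load_window = torch.zeros(
    window_size, num_moe_layers, num_logical_experts,
    dtype=expert_load_window.dtype)
logical_expert_load_window.scatter_add_(
    dim=-1,
    index=physical_to_logical_map.unsqueeze(0).expand_as(
        expert_load_window).long(),
    src=expert_load_window,
)
```

Replicas of the same logical expert accumulate into the same output column, which is the per-logical-expert load the rearrangement step consumes.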
```diff
@@ -618,4 +616,4 @@ def _node_count_with_rank_mapping(
             if is_same_node and node_assignment[other_rank] == 0:
                 node_assignment[other_rank] = next_node_id
 
-    return next_node_id
+    return next_node_id
```

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 2 additions & 15 deletions
```diff
@@ -1430,22 +1430,9 @@ def select_experts(
             # to the modular kernel, we can move this logic there
             # to achieve better efficiency.
 
-            # `expert_load_view`: (num_logical_experts,)
+            # `expert_load_view`: (num_physical_experts,)
 
-            # Mask out non-local experts
-            if expert_map is not None:
-                topk_ids_local = expert_map[topk_ids]
-                topk_ids_flatten = topk_ids_local.flatten()
-            else:
-                topk_ids_flatten = topk_ids.flatten()
-
-            # Should be equivalent to:
-            # ```
-            # topk_ids_masked = topk_ids_local[topk_ids_local >= 0]
-            # expert_load_view += topk_ids_masked.bincount(
-            #     minlength=expert_load_view.shape[0])
-            # ```
-            # We use `scatter_add_` since `bincount` cannot be compiled
+            topk_ids_flatten = topk_ids.flatten()
 
             # Performance optimization:
             # `masked_fill` is significantly faster than `masked_select`
```
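
With `expert_load_view` now sized for every physical expert, the per-rank `expert_map` masking removed above is no longer needed and the flattened `topk_ids` can be counted directly. A hedged sketch of that accumulation follows; it assumes `topk_ids` already index global physical experts (with `-1` for any invalid slot) and uses plain boolean indexing for clarity, whereas the production path prefers `masked_fill`, per the comment retained in the hunk.

```python
import torch

# Illustrative sizes; `topk_ids` assumed to hold global physical expert ids.
num_physical_experts = 16
num_tokens, top_k = 4, 2

topk_ids = torch.randint(0, num_physical_experts, (num_tokens, top_k))
expert_load_view = torch.zeros(num_physical_experts, dtype=torch.int32)

# With a global view there is no per-rank expert_map masking step:
# every routed id can be counted directly.
topk_ids_flatten = topk_ids.flatten()
valid = topk_ids_flatten >= 0  # drop -1 placeholders, if any

# scatter_add_-style accumulation (bincount does not compile under
# torch.compile, which is why the kernel path avoids it).
expert_load_view.scatter_add_(
    dim=0,
    index=topk_ids_flatten[valid].long(),
    src=torch.ones_like(topk_ids_flatten[valid], dtype=torch.int32))

# Equivalent check with bincount.
assert torch.equal(
    expert_load_view,
    torch.bincount(topk_ids_flatten[valid],
                   minlength=num_physical_experts).to(torch.int32))
```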
