@@ -32,7 +32,7 @@
 from typing import Optional, Union
 
 import torch
-from torch.distributed import ProcessGroup, all_gather, all_reduce
+from torch.distributed import ProcessGroup, all_reduce
 
 from vllm.config import ParallelConfig
 from vllm.distributed.parallel_state import (get_ep_group, get_node_count,
@@ -112,13 +112,21 @@ class EplbState:
     Expert load during this forward pass.
     We use the token count each expert processes as the load.
 
-    Shape: (num_moe_layers, num_local_physical_experts)
+    Shape: (num_moe_layers, num_physical_experts)
     """
     expert_load_window: torch.Tensor
     """
     A sliding window of expert load.
 
-    Shape: (window_size, num_moe_layers, num_local_physical_experts)
+    Shape: (window_size, num_moe_layers, num_physical_experts)
+
+    NOTE: The expert_load_view now records load for all physical experts
+    rather than just local experts. This ensures consistent load statistics
+    across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels).
+    The recorded load will be multiplied by dp_size when using naive all-to-all
+    due to each DP rank contributing the same token set to the calculation.
+    See:
+    https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
     """
     expert_load_window_step: int = 0
     """
@@ -232,14 +240,14 @@ def build(
         ).contiguous()
 
         expert_load_pass = torch.zeros(
-            (model.num_moe_layers, model.num_local_physical_experts),
+            (model.num_moe_layers, model.num_physical_experts),
             dtype=torch.int32,
             device=device,
         )
         expert_load_window_size = parallel_config.eplb_window_size
         expert_load_window = torch.zeros(
             (expert_load_window_size, model.num_moe_layers,
-             model.num_local_physical_experts),
+             model.num_physical_experts),
             dtype=torch.int32,
             device=device,
         )
@@ -353,18 +361,18 @@ def step(self,
         self.expert_load_pass.zero_()
 
         if log_stats:
-            # `num_tokens`: (num_moe_layers,)
-            num_tokens = self.expert_load_pass.sum(dim=-1)
+            # total_expert_load_pass: (num_moe_layers, num_physical_experts)
+            total_expert_load_pass = self.expert_load_pass.clone()
 
             # Collect load metrics from all ranks
             ep_group = get_ep_group().device_group
             assert ep_group is not None
-            num_tokens_list = [
-                torch.empty_like(num_tokens) for _ in range(ep_group.size())
-            ]
-            all_gather(num_tokens_list, num_tokens, group=ep_group)
-            # Stack to get (num_ranks, num_moe_layers)
-            num_tokens_per_rank = torch.stack(num_tokens_list).float()
+            all_reduce(total_expert_load_pass, group=ep_group)
+
+            # num_tokens_per_rank: (num_moe_layers, num_ranks)
+            num_tokens_per_rank = total_expert_load_pass.reshape(
+                total_expert_load_pass.shape[0], ep_group.size(),
+                -1).sum(dim=-1).float()
 
             # Compute balancedness ratio:
             # for each layer:
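A standalone sketch of the reshape in the new `step()` path, assuming physical experts are laid out contiguously per EP rank: folding the expert dimension of the all-reduced load into (num_ranks, experts_per_rank) and summing recovers the per-rank token counts used for the balancedness ratio. The values below are invented for illustration.

import torch

num_moe_layers, ep_size = 2, 2
experts_per_rank = 3
num_physical_experts = ep_size * experts_per_rank

# Stand-in for the all-reduced per-pass load, already summed over EP ranks.
total_expert_load_pass = torch.tensor([
    [4, 0, 2, 1, 1, 0],  # layer 0
    [3, 3, 0, 0, 1, 5],  # layer 1
], dtype=torch.int32)

# Rank r owns physical experts [r * experts_per_rank, (r + 1) * experts_per_rank),
# so summing within each chunk yields the token count each rank handled per layer.
num_tokens_per_rank = total_expert_load_pass.reshape(
    num_moe_layers, ep_size, -1).sum(dim=-1).float()
print(num_tokens_per_rank)  # tensor([[6., 2.], [6., 6.]])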
@@ -426,17 +434,7 @@ def rearrange(self,
                         "(profile)" if is_profile else "")
 
         if global_expert_load is None:
-            # This mapping is only used here, so we do not store it in the state
-            physical_expert_start = ep_rank * model.num_local_physical_experts
-            physical_expert_end = (physical_expert_start +
-                                   model.num_local_physical_experts)
-            # (num_moe_layers, num_local_physical_experts)
-            local_physical_to_logical_map = self.physical_to_logical_map[
-                :,
-                physical_expert_start:physical_expert_end,
-            ]
-
-            # Map the local physical expert load to global logical experts
+            # Map the physical expert load to global logical experts
             logical_expert_load_window = torch.zeros(
                 self.expert_load_window_size,
                 model.num_moe_layers,
@@ -446,7 +444,7 @@ def rearrange(self,
             )
             logical_expert_load_window.scatter_add_(
                 dim=-1,
-                index=local_physical_to_logical_map.unsqueeze(0).expand_as(
+                index=self.physical_to_logical_map.unsqueeze(0).expand_as(
                     self.expert_load_window).long(),
                 src=self.expert_load_window,
             )
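A minimal sketch of the `scatter_add_` above with made-up sizes: two physical slots replicate the same logical expert, so their window loads are summed into a single logical entry via `physical_to_logical_map`.

import torch

window_size, num_moe_layers = 1, 1
num_physical_experts, num_logical_experts = 4, 3

# Physical slots 0 and 2 both serve logical expert 0.
physical_to_logical_map = torch.tensor([[0, 1, 0, 2]])  # (layers, physical)
expert_load_window = torch.tensor([[[5, 1, 2, 7]]],
                                  dtype=torch.int32)    # (window, layers, physical)

logical_expert_load_window = torch.zeros(
    window_size, num_moe_layers, num_logical_experts, dtype=torch.int32)

# Accumulate each physical slot's load onto the logical expert it replicates.
logical_expert_load_window.scatter_add_(
    dim=-1,
    index=physical_to_logical_map.unsqueeze(0).expand_as(
        expert_load_window).long(),
    src=expert_load_window,
)
print(logical_expert_load_window)  # tensor([[[7, 1, 7]]], dtype=torch.int32)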
@@ -618,4 +616,4 @@ def _node_count_with_rank_mapping(
             if is_same_node and node_assignment[other_rank] == 0:
                 node_assignment[other_rank] = next_node_id
 
-    return next_node_id
+    return next_node_id