Commit 4b79502

Authored by 22quinn and simon-mo
[EP] Add logging for experts map (#22685)
Signed-off-by: 22quinn <[email protected]>
Co-authored-by: Simon Mo <[email protected]>
1 parent c86af22 commit 4b79502

File tree
  • vllm/model_executor/layers/fused_moe

1 file changed: +26 −0 lines changed

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 26 additions & 0 deletions
```diff
@@ -695,6 +695,26 @@ def determine_expert_map(
     return (local_num_experts, expert_map)
 
 
+def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
+    """
+    Compresses the expert map by removing any -1 entries.
+
+    Args:
+        expert_map (torch.Tensor): A tensor of shape (global_num_experts,)
+            mapping from global to local index. Contains -1 for experts not
+            assigned to the current rank.
+
+    Returns:
+        str: A string mapping from local to global index.
+            Using str to support hashing for logging once only.
+    """
+    global_indices = torch.where(expert_map != -1)[0]
+    local_indices = expert_map[global_indices]
+    return ", ".join(
+        f"{local_index.item()}->{global_index.item()}"
+        for local_index, global_index in zip(local_indices, global_indices))
+
+
 @CustomOp.register("fused_moe")
 class FusedMoE(CustomOp):
     """FusedMoE layer for MoE models.
@@ -795,6 +815,12 @@ def __init__(
                 ep_size=self.ep_size,
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts)
+            logger.info_once(
+                "[EP Rank %s/%s] Expert parallelism is enabled. Local/global"
+                " number of experts: %s/%s. Experts local to global index map:"
+                " %s.", self.ep_rank, self.ep_size, self.local_num_experts,
+                self.global_num_experts,
+                get_compressed_expert_map(self.expert_map))
         else:
             self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                        None)
```
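To make the new helper's behavior concrete, here is a small standalone sketch that reuses the function body from the diff above on a hypothetical expert map. The tensor values, rank, and expert counts below are made up for illustration and do not come from the commit:

```python
import torch


def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
    """Same body as the helper added above: keep only the experts assigned
    to this rank (entries != -1) and render them as "local->global" pairs."""
    global_indices = torch.where(expert_map != -1)[0]
    local_indices = expert_map[global_indices]
    return ", ".join(
        f"{local_index.item()}->{global_index.item()}"
        for local_index, global_index in zip(local_indices, global_indices))


# Hypothetical map for rank 1 of a 2-way EP setup with 8 global experts:
# global experts 4-7 live on this rank as local experts 0-3.
expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3])
print(get_compressed_expert_map(expert_map))
# 0->4, 1->5, 2->6, 3->7
#
# Corresponding log line (message text from the diff, values illustrative):
# [EP Rank 1/2] Expert parallelism is enabled. Local/global number of
# experts: 4/8. Experts local to global index map: 0->4, 1->5, 2->6, 3->7.
```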

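The docstring's note "Using str to support hashing for logging once only" refers to vLLM's own `logger.info_once`, whose internals are not part of this diff. The snippet below is only a loose, standalone illustration of that log-each-unique-message-once idea using the standard `logging` module; the `info_once` function and `_seen` set here are invented for this sketch and are not vLLM code:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ep_demo")

_seen: set[str] = set()


def info_once(msg: str, *args) -> None:
    # Format first, then deduplicate on the resulting string. Passing the
    # expert map as a plain str (rather than a tensor) keeps the formatted
    # message stable and value-comparable across repeated calls.
    formatted = msg % args if args else msg
    if formatted in _seen:
        return
    _seen.add(formatted)
    logger.info(formatted)


info_once("[EP Rank %s/%s] Experts local to global index map: %s.",
          0, 2, "0->0, 1->1, 2->2, 3->3")
info_once("[EP Rank %s/%s] Experts local to global index map: %s.",
          0, 2, "0->0, 1->1, 2->2, 3->3")  # duplicate, suppressed
```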