Commit 06a4133

abmfy, gemini-code-assist[bot], and tlrmchlsmth authored
[EPLB] Reduce EPLB Inference Overhead (#24573)
Signed-off-by: Bowen Wang <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <[email protected]>
1 parent 175811e commit 06a4133

File tree

2 files changed: +92 -50 lines


vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 73 additions & 0 deletions
@@ -1017,6 +1017,79 @@ def grouped_topk(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
+def eplb_map_to_physical_and_record(
+        topk_ids: torch.Tensor,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+        indices_type: Optional[torch.dtype] = None) -> torch.Tensor:
+    '''
+    Map the logical expert ids to physical expert ids
+    and record the expert load metrics.
+
+    This will select a pseudo-random replica for each logical expert.
+    Only used for EPLB.
+
+    Args:
+        topk_ids: The logical expert ids.
+        expert_load_view: The expert load view.
+        logical_to_physical_map: The logical to physical map.
+        logical_replica_count: The logical replica count.
+        indices_type: The indices type.
+
+    Returns:
+        The physical expert ids.
+    '''
+
+    # 1. Convert the logical expert ids to physical expert ids
+    # Directly select a random replica for each logical expert
+
+    # In case `indices_type` is not `torch.long` or `torch.int`,
+    # e.g. `torch.uint32` as required by dispatch/combine kernels
+    topk_ids_long = topk_ids.long()
+    # Use (token position) modulo (replica count)
+    # to deterministically choose a replica
+    replica_count = logical_replica_count[topk_ids_long]
+    # Flatten-position based index, reshaped back to `topk_ids` shape
+    pos_indices = torch.arange(topk_ids.numel(),
+                               device=topk_ids.device,
+                               dtype=torch.long).reshape_as(topk_ids)
+    # Compute pseudo-random indices by modulo
+    replica_indices = (pos_indices % replica_count).unsqueeze(-1)
+    physical_ids = logical_to_physical_map[topk_ids_long].gather(
+        -1, replica_indices).squeeze(-1)
+
+    topk_ids = physical_ids
+
+    # 2. Record expert load metrics.
+
+    # TODO(bowen): When using `FusedMoEModularKernel`, this
+    # can be done in a more unified way, since
+    # `FusedMoEPrepareAndFinalize` will return the expert
+    # token count, in some cases directly from the kernel.
+    # However, now there are many code paths not using
+    # the modular kernel, e.g. calling `fused_experts`,
+    # so we decide to keep the logic here.
+    #
+    # If later refactor moved all the MoE kernel calls
+    # to the modular kernel, we can move this logic there
+    # to achieve better efficiency.
+
+    # `expert_load_view`: (num_physical_experts,)
+
+    # `torch.bincount` is not compilable, so use `scatter_add_` instead.
+    topk_ids_flatten = topk_ids.flatten()
+    expert_load_view.scatter_add_(
+        dim=0,
+        index=topk_ids_flatten.long(),
+        src=torch.ones_like(topk_ids_flatten).to(expert_load_view))
+
+    if indices_type is not None:
+        topk_ids = topk_ids.to(dtype=indices_type)
+    return topk_ids
+
+
 def fused_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
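The heart of the change is the replica-selection step: rather than drawing a random number per token, each (token, slot) position picks a replica by taking its flattened position modulo the replica count, so the whole function is pure tensor arithmetic and can be wrapped in torch.compile (note the comment above: torch.bincount is not compilable, hence scatter_add_). Below is a minimal sketch, not part of the commit, that replays both steps on toy tensors; all shapes and values are invented for illustration.

import torch

# 2 logical experts over 3 physical experts:
# logical expert 0 has replicas {0, 2}; logical expert 1 has only {1}.
logical_to_physical_map = torch.tensor([[0, 2], [1, 1]])  # rows padded to max replica count
logical_replica_count = torch.tensor([2, 1])
expert_load_view = torch.zeros(3)  # one load counter per physical expert

topk_ids = torch.tensor([[0, 1], [0, 0]])  # logical ids, shape (num_tokens, top_k)

# Step 1: deterministic position-modulo replica selection (no RNG kernel)
pos_indices = torch.arange(topk_ids.numel()).reshape_as(topk_ids)
replica_indices = (pos_indices % logical_replica_count[topk_ids]).unsqueeze(-1)
physical_ids = logical_to_physical_map[topk_ids].gather(-1, replica_indices).squeeze(-1)
print(physical_ids)  # tensor([[0, 1], [0, 2]])

# Step 2: compile-friendly load recording via scatter_add_
flat = physical_ids.flatten()
expert_load_view.scatter_add_(0, flat.long(),
                              torch.ones_like(flat, dtype=expert_load_view.dtype))
print(expert_load_view)  # tensor([2., 1., 1.])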

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 19 additions & 50 deletions
@@ -43,7 +43,8 @@
 
 if current_platform.is_cuda_alike():
     from .fused_batched_moe import BatchedTritonExperts
-    from .fused_moe import TritonExperts, fused_experts
+    from .fused_moe import (TritonExperts, eplb_map_to_physical_and_record,
+                            fused_experts)
     if has_pplx():
         from .pplx_prepare_finalize import (PplxPrepareAndFinalize,
                                             pplx_hidden_dim_scale_bytes)
@@ -55,6 +56,16 @@
     fused_experts = None  # type: ignore
     FusedMoEPermuteExpertsUnpermute = None  # type: ignore
     FusedMoEPrepareAndFinalize = None  # type: ignore
+
+    def eplb_map_to_physical_and_record(
+            topk_ids: torch.Tensor, expert_load_view: torch.Tensor,
+            logical_to_physical_map: torch.Tensor,
+            logical_replica_count: torch.Tensor,
+            indices_type: Optional[torch.dtype]) -> torch.Tensor:
+        # CPU fallback: no EPLB so just return as is
+        return topk_ids
+
+
 if is_rocm_aiter_moe_enabled():
     from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
         rocm_aiter_grouped_topk as grouped_topk)
@@ -1616,55 +1627,13 @@ def select_experts(
             assert logical_to_physical_map is not None
             assert logical_replica_count is not None
 
-            # 1. Convert the logical expert ids to physical expert ids
-            # Directly select a random replica for each logical expert
-
-            # TODO: maybe optimize this by using specified kernels,
-            # or compute pseudo-random indices by modulo
-
-            # In case `indices_type` is not `torch.long` or `torch.int`,
-            # e.g. `torch.uint32` as required by dispatch/combine kernels
-            topk_ids_long = topk_ids.long()
-            replica_indices = (
-                torch.rand_like(topk_ids, dtype=torch.float) *
-                logical_replica_count[topk_ids_long]).long().unsqueeze(-1)
-            physical_ids = logical_to_physical_map[topk_ids_long].gather(
-                -1, replica_indices).squeeze(-1)
-
-            topk_ids = physical_ids
-
-            # 2. Record expert load metrics.
-
-            # TODO(bowen): When using `FusedMoEModularKernel`, this
-            # can be done in a more unified way, since
-            # `FusedMoEPrepareAndFinalize` will return the expert
-            # token count, in some cases directly from the kernel.
-            # However, now there are many code paths not using
-            # the modular kernel, e.g. calling `fused_experts`,
-            # so we decide to keep the logic here.
-            #
-            # If later refactor moved all the MoE kernel calls
-            # to the modular kernel, we can move this logic there
-            # to achieve better efficiency.
-
-            # `expert_load_view`: (num_physical_experts,)
-
-            topk_ids_flatten = topk_ids.flatten()
-
-            # Performance optimization:
-            # `masked_fill` is significantly faster than `masked_select`
-            invalid_mask = topk_ids_flatten < 0
-            # Replace invalid expert ids with 0 (just a dummy position)
-            # to avoid out-of-bounds errors in scatter_add_
-            index = topk_ids_flatten.masked_fill_(invalid_mask, 0)
-            # `src` is the valid mask, which is 1 for valid and 0 for invalid
-            src = ~invalid_mask
-
-            expert_load_view.scatter_add_(dim=0,
-                                          index=index.long(),
-                                          src=src.to(expert_load_view))
-
-            topk_ids = topk_ids.to(dtype=indices_type)
+            topk_ids = eplb_map_to_physical_and_record(
+                topk_ids=topk_ids,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+                indices_type=indices_type,
+            )
 
             assert topk_ids.dtype == indices_type or indices_type is None
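Why this reduces overhead: the removed call-site logic sampled a replica with torch.rand_like, which launches an extra floating-point RNG kernel on every forward pass, while the new helper uses integer arithmetic that torch.compile can fuse. A small sketch, again illustrative only with made-up tensors, contrasting the two selection schemes:

import torch

topk_ids = torch.tensor([[0, 1], [0, 0]])
logical_replica_count = torch.tensor([2, 1])
counts = logical_replica_count[topk_ids.long()]  # replica count per selected expert

# Old path (removed above): float RNG, scale by count, truncate to int.
old_replica_indices = (torch.rand_like(topk_ids, dtype=torch.float) * counts).long()

# New path: flattened position modulo count, deterministic and compilable.
pos = torch.arange(topk_ids.numel()).reshape_as(topk_ids)
new_replica_indices = pos % counts

# Both schemes yield a valid replica index in [0, replica_count) per element.
assert (old_replica_indices >= 0).all() and (old_replica_indices < counts).all()
assert (new_replica_indices >= 0).all() and (new_replica_indices < counts).all()

One behavioral note visible in the diff: the old load-recording code masked out negative (invalid) expert ids with masked_fill_ before scatter_add_, while the compiled helper carries no such mask, so it assumes the ids are valid by the time it runs.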
