|
43 | 43 |
|
44 | 44 | if current_platform.is_cuda_alike():
|
45 | 45 | from .fused_batched_moe import BatchedTritonExperts
|
46 |
| - from .fused_moe import TritonExperts, fused_experts |
| 46 | + from .fused_moe import (TritonExperts, eplb_map_to_physical_and_record, |
| 47 | + fused_experts) |
47 | 48 | if has_pplx():
|
48 | 49 | from .pplx_prepare_finalize import (PplxPrepareAndFinalize,
|
49 | 50 | pplx_hidden_dim_scale_bytes)
|
|
55 | 56 | fused_experts = None # type: ignore
|
56 | 57 | FusedMoEPermuteExpertsUnpermute = None # type: ignore
|
57 | 58 | FusedMoEPrepareAndFinalize = None # type: ignore
|
| 59 | + |
def eplb_map_to_physical_and_record(
        topk_ids: torch.Tensor, expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
        indices_type: Optional[torch.dtype]) -> torch.Tensor:
    """Fallback for platforms without EPLB support (non-CUDA-alike).

    The CUDA implementation maps logical expert ids to physical replica
    ids and records per-expert load into ``expert_load_view``. On this
    platform EPLB is unavailable, so ``topk_ids`` is returned unchanged.

    ``expert_load_view``, ``logical_to_physical_map``,
    ``logical_replica_count`` and ``indices_type`` are accepted only to
    keep the signature identical to the CUDA-side implementation
    imported from ``.fused_moe``; they are ignored here.
    """
    # CPU fallback: no EPLB so just return as is
    return topk_ids
| 67 | + |
| 68 | + |
58 | 69 | if is_rocm_aiter_moe_enabled():
|
59 | 70 | from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
60 | 71 | rocm_aiter_grouped_topk as grouped_topk)
|
@@ -1616,55 +1627,13 @@ def select_experts(
|
1616 | 1627 | assert logical_to_physical_map is not None
|
1617 | 1628 | assert logical_replica_count is not None
|
1618 | 1629 |
|
1619 |
| - # 1. Convert the logical expert ids to physical expert ids |
1620 |
| - # Directly select a random replica for each logical expert |
1621 |
| - |
1622 |
| - # TODO: maybe optimize this by using specified kernels, |
1623 |
| - # or compute pseudo-random indices by modulo |
1624 |
| - |
1625 |
| - # In case `indices_type` is not `torch.long` or `torch.int`, |
1626 |
| - # e.g. `torch.uint32` as required by dispatch/combine kernels |
1627 |
| - topk_ids_long = topk_ids.long() |
1628 |
| - replica_indices = ( |
1629 |
| - torch.rand_like(topk_ids, dtype=torch.float) * |
1630 |
| - logical_replica_count[topk_ids_long]).long().unsqueeze(-1) |
1631 |
| - physical_ids = logical_to_physical_map[topk_ids_long].gather( |
1632 |
| - -1, replica_indices).squeeze(-1) |
1633 |
| - |
1634 |
| - topk_ids = physical_ids |
1635 |
| - |
1636 |
| - # 2. Record expert load metrics. |
1637 |
| - |
1638 |
| - # TODO(bowen): When using `FusedMoEModularKernel`, this |
1639 |
| - # can be done in a more unified way, since |
1640 |
| - # `FusedMoEPrepareAndFinalize` will return the expert |
1641 |
| - # token count, in some cases directly from the kernel. |
1642 |
| - # However, now there are many code paths not using |
1643 |
| - # the modular kernel, e.g. calling `fused_experts`, |
1644 |
| - # so we decide to keep the logic here. |
1645 |
| - # |
1646 |
| - # If later refactor moved all the MoE kernel calls |
1647 |
| - # to the modular kernel, we can move this logic there |
1648 |
| - # to achieve better efficiency. |
1649 |
| - |
1650 |
| - # `expert_load_view`: (num_physical_experts,) |
1651 |
| - |
1652 |
| - topk_ids_flatten = topk_ids.flatten() |
1653 |
| - |
1654 |
| - # Performance optimization: |
1655 |
| - # `masked_fill` is significantly faster than `masked_select` |
1656 |
| - invalid_mask = topk_ids_flatten < 0 |
1657 |
| - # Replace invalid expert ids with 0 (just a dummy position) |
1658 |
| - # to avoid out-of-bounds errors in scatter_add_ |
1659 |
| - index = topk_ids_flatten.masked_fill_(invalid_mask, 0) |
1660 |
| - # `src` is the valid mask, which is 1 for valid and 0 for invalid |
1661 |
| - src = ~invalid_mask |
1662 |
| - |
1663 |
| - expert_load_view.scatter_add_(dim=0, |
1664 |
| - index=index.long(), |
1665 |
| - src=src.to(expert_load_view)) |
1666 |
| - |
1667 |
| - topk_ids = topk_ids.to(dtype=indices_type) |
| 1630 | + topk_ids = eplb_map_to_physical_and_record( |
| 1631 | + topk_ids=topk_ids, |
| 1632 | + expert_load_view=expert_load_view, |
| 1633 | + logical_to_physical_map=logical_to_physical_map, |
| 1634 | + logical_replica_count=logical_replica_count, |
| 1635 | + indices_type=indices_type, |
| 1636 | + ) |
1668 | 1637 |
|
1669 | 1638 | assert topk_ids.dtype == indices_type or indices_type is None
|
1670 | 1639 |
|
|
0 commit comments