
Commit 0812d8d

[Hardware][Gaudi][BugFix] fix arguments of hpu fused moe (#15945)
Signed-off-by: zhenwei <[email protected]>
1 parent bf7e3c5 commit 0812d8d

File tree

  • vllm/model_executor/layers/fused_moe

1 file changed: +5 additions, -2 deletions

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 5 additions & 2 deletions
@@ -254,9 +254,12 @@ def forward_hpu(
         renormalize: bool,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         assert not use_grouped_topk
         assert num_expert_group is None
@@ -472,7 +475,7 @@ def __init__(
                 "non-grouped topk.")
         if current_platform.is_hpu():
             from vllm_hpu_extension.ops import DynamicFusedMOE
-            self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
+            self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
 
         # Note: get_quant_method will look at the layer's local_num_experts
         # for heuristic purposes, so it must be initialized first.
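
Context for the fix: the added parameters line up forward_hpu with the keyword arguments that the shared FusedMoE path passes to every backend-specific implementation; a signature that lacks them rejects the common call. The sketch below illustrates that failure mode in isolation; the function names and the kwargs dict are illustrative assumptions, not vLLM's actual call site.

# Minimal sketch, assuming a common MoE entry point that forwards
# backend-agnostic keyword arguments to a platform-specific function.
# Names here are illustrative, not vLLM's real dispatch code.
from typing import Optional

import torch


def forward_hpu_before(hidden_states: torch.Tensor, renormalize: bool,
                       e_score_correction_bias: Optional[torch.Tensor] = None):
    # Old signature: unaware of global_num_experts / expert_map / activation.
    return hidden_states


def forward_hpu_after(hidden_states: torch.Tensor, renormalize: bool,
                      global_num_experts: int = -1,
                      expert_map: Optional[torch.Tensor] = None,
                      e_score_correction_bias: Optional[torch.Tensor] = None,
                      activation: str = "silu"):
    # New signature: accepts the extra keyword arguments the caller supplies.
    return hidden_states


# Keyword arguments the shared path hands to whichever backend is active.
common_kwargs = dict(renormalize=True,
                     global_num_experts=8,
                     expert_map=None,
                     activation="silu")
x = torch.zeros(2, 4)

forward_hpu_after(x, **common_kwargs)        # works with the updated signature
try:
    forward_hpu_before(x, **common_kwargs)   # unexpected keyword argument
except TypeError as err:
    print(f"old signature rejects the shared call: {err}")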
