@@ -203,13 +203,14 @@ def apply(
203
203
x : torch .Tensor ,
204
204
router_logits : torch .Tensor ,
205
205
top_k : int ,
206
- renormalize : bool = True ,
206
+ renormalize : bool ,
207
207
use_grouped_topk : bool = False ,
208
- num_expert_group : Optional [int ] = None ,
209
208
topk_group : Optional [int ] = None ,
209
+ num_expert_group : Optional [int ] = None ,
210
210
custom_routing_function : Optional [Callable ] = None ,
211
+ scoring_func : str = "softmax" ,
212
+ e_score_correction_bias : Optional [torch .Tensor ] = None ,
211
213
) -> torch .Tensor :
212
-
213
214
from vllm .model_executor .layers .fused_moe import fused_experts
214
215
215
216
topk_weights , topk_ids = FusedMoE .select_experts (
@@ -220,7 +221,9 @@ def apply(
220
221
renormalize = renormalize ,
221
222
topk_group = topk_group ,
222
223
num_expert_group = num_expert_group ,
223
- custom_routing_function = custom_routing_function )
224
+ custom_routing_function = custom_routing_function ,
225
+ scoring_func = scoring_func ,
226
+ e_score_correction_bias = e_score_correction_bias )
224
227
225
228
return fused_experts (x ,
226
229
layer .w13_weight ,
@@ -476,12 +479,15 @@ def apply(
476
479
x : torch .Tensor ,
477
480
router_logits : torch .Tensor ,
478
481
top_k : int ,
479
- renormalize : bool = True ,
482
+ renormalize : bool ,
480
483
use_grouped_topk : bool = False ,
481
- num_expert_group : Optional [int ] = None ,
482
484
topk_group : Optional [int ] = None ,
485
+ num_expert_group : Optional [int ] = None ,
483
486
custom_routing_function : Optional [Callable ] = None ,
487
+ scoring_func : str = "softmax" ,
488
+ e_score_correction_bias : Optional [torch .Tensor ] = None ,
484
489
) -> torch .Tensor :
490
+
485
491
topk_weights , topk_ids = FusedMoE .select_experts (
486
492
hidden_states = x ,
487
493
router_logits = router_logits ,
@@ -490,7 +496,9 @@ def apply(
490
496
renormalize = renormalize ,
491
497
topk_group = topk_group ,
492
498
num_expert_group = num_expert_group ,
493
- custom_routing_function = custom_routing_function )
499
+ custom_routing_function = custom_routing_function ,
500
+ scoring_func = scoring_func ,
501
+ e_score_correction_bias = e_score_correction_bias )
494
502
495
503
return torch .ops .vllm .fused_marlin_moe (
496
504
x ,
0 commit comments