
Commit 2339d59

[BugFix] Fix quantization for all other methods (#11547)
1 parent 1b875a0 commit 2339d59

File tree: 6 files changed, +52 -22 lines changed

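What this commit does: the abstract apply() method in fused_moe/layer.py grows the full set of routing parameters that the fp8 path already accepted (use_grouped_topk with a default of False, topk_group, num_expert_group, custom_routing_function, scoring_func, e_score_correction_bias), and every quantized MoE method (AWQ-Marlin, compressed-tensors, experts_int8, GPTQ-Marlin) is brought in line with that signature, forwarding scoring_func and e_score_correction_bias through FusedMoE.select_experts. The defaults (scoring_func="softmax", e_score_correction_bias=None) leave existing models' routing behavior unchanged. GPTQ-Marlin additionally stops discarding the caller's custom_routing_function, which it previously hard-coded to None.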

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 15 additions & 4 deletions
@@ -41,9 +41,20 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         raise NotImplementedError

     @abstractmethod
-    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
-              router_logits: torch.Tensor, top_k: int, renormalize: bool,
-              use_grouped_topk: bool) -> torch.Tensor:
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         raise NotImplementedError


@@ -79,7 +90,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
-        use_grouped_topk: bool,
+        use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
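To make the new contract concrete, here is a minimal sketch of a subclass after this change. It is illustrative only: the base-class name FusedMoEMethodBase and the top_k/use_grouped_topk keyword names for select_experts are not visible in the hunks above and are assumptions; the remaining keywords are copied from the calls in the diffs below.

from typing import Callable, Optional

import torch

# Assumed import path; both names live in fused_moe/layer.py in vLLM.
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                        FusedMoEMethodBase)


class SketchMoEMethod(FusedMoEMethodBase):
    """Hypothetical method showing the post-commit apply() contract."""

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       *args, **kwargs):
        raise NotImplementedError  # weight setup is out of scope here

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Forward every routing knob. Dropping one (as gptq_marlin did
        # with custom_routing_function before this commit) silently falls
        # back to default routing for models that override it.
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,  # keyword name assumed
            top_k=top_k,                        # keyword name assumed
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)
        # A real method would now run its fused expert kernel, e.g.
        # fused_experts(x, layer.w13_weight, layer.w2_weight,
        #               topk_weights, topk_ids, ...).
        raise NotImplementedError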

vllm/model_executor/layers/quantization/awq_marlin.py

Lines changed: 7 additions & 3 deletions
@@ -440,11 +440,13 @@ def apply(
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -454,7 +456,9 @@ def apply(
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)

         return torch.ops.vllm.fused_marlin_moe(
             x,
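The same three-part pattern recurs in compressed-tensors, experts_int8, and GPTQ-Marlin below: renormalize loses its True default (callers must now pass it explicitly, matching the base class where it has no default), topk_group and num_expert_group swap into the base-class order, and the two new scoring arguments are threaded through to select_experts.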

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 15 additions & 7 deletions
@@ -203,13 +203,14 @@ def apply(
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         from vllm.model_executor.layers.fused_moe import fused_experts

         topk_weights, topk_ids = FusedMoE.select_experts(
@@ -220,7 +221,9 @@ def apply(
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)

         return fused_experts(x,
                              layer.w13_weight,
@@ -476,12 +479,15 @@ def apply(
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -490,7 +496,9 @@ def apply(
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)

         return torch.ops.vllm.fused_marlin_moe(
             x,

vllm/model_executor/layers/quantization/experts_int8.py

Lines changed: 7 additions & 3 deletions
@@ -99,11 +99,13 @@ def apply(
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts

@@ -115,7 +117,9 @@ def apply(
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)

         return fused_experts(x,
                              layer.w13_weight,

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 1 addition & 2 deletions
@@ -601,14 +601,13 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
-        use_grouped_topk: bool,
+        use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         from vllm.model_executor.layers.fused_moe import fused_experts

         topk_weights, topk_ids = FusedMoE.select_experts(
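fp8.py already carried scoring_func and e_score_correction_bias, so it is the signature the other methods are being aligned with; the only change here is giving use_grouped_topk the same False default as the updated base class, plus dropping a stray blank line.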

vllm/model_executor/layers/quantization/gptq_marlin.py

Lines changed: 7 additions & 3 deletions
@@ -532,11 +532,13 @@ def apply(
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # The input must currently be float16
         orig_dtype = x.dtype
@@ -550,7 +552,9 @@ def apply(
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=None)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)

         return torch.ops.vllm.fused_marlin_moe(
             x,
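Beyond the signature alignment, the second hunk here fixes an actual dropped argument: the GPTQ-Marlin path passed custom_routing_function=None to select_experts no matter what the caller supplied, so any model providing its own router was silently routed with the default policy. For illustration, a hypothetical custom router is sketched below; the (hidden_states, gating_output, topk, renormalize) -> (topk_weights, topk_ids) contract is an assumption based on how the argument is used, not something shown in this diff.

import torch

def plain_softmax_router(hidden_states: torch.Tensor,
                         gating_output: torch.Tensor,
                         topk: int,
                         renormalize: bool):
    # Hypothetical custom_routing_function: softmax scores, then top-k.
    # Before this commit, the gptq_marlin path would never have called it.
    scores = torch.softmax(gating_output, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids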
