Commit 52aff9e
[main] [bugfix] Fix misjudging quantized/unquantized scenarios (#2627)
### What this PR does / why we need it?
In a mixed-precision scenario, quant_config is not None, but MoE needs to perform unquantized computation; quantized computation was being used instead. We therefore move the with_quant logic into forward to avoid misjudging in mixed-precision scenarios.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
e2e & ut

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@98ac0cb

Signed-off-by: Pr0Wh1teGivee <[email protected]>
1 parent aadc75c commit 52aff9e
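
For context, the decision this commit moves can be sketched as follows. This is a hedged, simplified illustration rather than the actual vllm_ascend kernels: quant_apply_mlp and unquant_apply_mlp below are placeholders, and only the way with_quant is decided mirrors the change.

```python
# Simplified sketch of the control-flow change (placeholder kernels, not the
# real vllm_ascend implementations).
import torch


def quant_apply_mlp(hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
    return hidden_states  # placeholder for the quantized grouped-matmul path


def unquant_apply_mlp(hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
    return hidden_states  # placeholder for the unquantized grouped-matmul path


# Before: the branch read `get_forward_context().with_quant`, which was set
# once per step from `vllm_config.quant_config is not None`, so every MoE
# layer looked quantized as soon as any quant config existed.
# After: the caller decides per layer and passes the flag explicitly.
def unified_apply_mlp(hidden_states: torch.Tensor,
                      with_quant: bool = False,
                      **kwargs) -> torch.Tensor:
    if with_quant:
        return quant_apply_mlp(hidden_states, **kwargs)
    return unquant_apply_mlp(hidden_states, **kwargs)
```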

File tree: 7 files changed, +62 −65 lines changed


tests/ut/ops/test_fused_ops.py

Lines changed: 14 additions & 28 deletions
@@ -543,7 +543,6 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
                                                      mock_get_forward_context):

         mock_forward_context = MagicMock()
-        mock_forward_context.with_quant = True
         mock_forward_context.fused_moe_state = FusedMoEState.MC2
         mock_get_forward_context.return_value = mock_forward_context

@@ -587,10 +586,10 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
             group_list_type=1,
             w1_scale_bias=None,
             w2_scale_bias=None,
-            topk_scales=None)
+            topk_scales=None,
+            with_quant=True)

         mock_get_forward_context.assert_called()
-        self.assertTrue(mock_forward_context.with_quant)
         self.assertEqual(mock_forward_context.fused_moe_state,
                          FusedMoEState.MC2)

@@ -602,19 +601,15 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,

         self.assertEqual(result.dtype, torch.bfloat16)

-    @patch('vllm_ascend.ops.fused_moe.get_forward_context')
     @patch('vllm_ascend.ops.fused_moe.is_310p')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')
-    def test_unified_apply_mlp_without_quantization(
-            self, mock_npu_dynamic_quant, mock_npu_swiglu,
-            mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context):
-
-        mock_forward_context = MagicMock()
-        mock_forward_context.with_quant = False
-        mock_get_forward_context.return_value = mock_forward_context
-
+    def test_unified_apply_mlp_without_quantization(self,
+                                                    mock_npu_dynamic_quant,
+                                                    mock_npu_swiglu,
+                                                    mock_npu_grouped_matmul,
+                                                    mock_is_310p):
         mock_is_310p.return_value = False

         mock_npu_grouped_matmul.side_effect = [[
@@ -639,10 +634,8 @@ def test_unified_apply_mlp_without_quantization(
             group_list_type=1,
             w1_scale_bias=None,
             w2_scale_bias=None,
-            topk_scales=topk_scales)
-
-        mock_get_forward_context.assert_called()
-        self.assertFalse(mock_forward_context.with_quant)
+            topk_scales=topk_scales,
+            with_quant=False)

         self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
         mock_npu_swiglu.assert_called_once()
@@ -698,10 +691,10 @@ def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
             group_list_type=1,
             w1_scale_bias=w1_scale_bias,
             w2_scale_bias=w2_scale_bias,
-            topk_scales=None)
+            topk_scales=None,
+            with_quant=True)

         mock_get_forward_context.assert_called()
-        self.assertTrue(mock_forward_context.with_quant)

         self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
         mock_npu_swiglu.assert_called_once()
@@ -710,19 +703,13 @@ def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.bfloat16)

-    @patch('vllm_ascend.ops.fused_moe.get_forward_context')
     @patch('vllm_ascend.ops.fused_moe.is_310p')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')
     def test_unified_apply_mlp_without_quantization_310p(
             self, mock_npu_dynamic_quant, mock_npu_swiglu,
-            mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context):
-
-        mock_forward_context = MagicMock()
-        mock_forward_context.with_quant = False
-        mock_get_forward_context.return_value = mock_forward_context
-
+            mock_npu_grouped_matmul, mock_is_310p):
         mock_is_310p.return_value = True

         mock_gmm1_out = torch.randn(10, 40, dtype=torch.float16)
@@ -750,10 +737,9 @@ def test_unified_apply_mlp_without_quantization_310p(
             group_list_type=1,
             w1_scale_bias=None,
             w2_scale_bias=None,
-            topk_scales=topk_scales)
+            topk_scales=topk_scales,
+            with_quant=False)

-        mock_get_forward_context.assert_called()
-        self.assertFalse(mock_forward_context.with_quant)
         mock_is_310p.assert_called_once()

         self.assertEqual(mock_npu_grouped_matmul.call_count, 2)

tests/ut/ops/test_token_dispatcher.py

Lines changed: 6 additions & 7 deletions
@@ -263,7 +263,6 @@ def test_token_dispatch_with_quant(self):
             "max_num_tokens": 100,
             "ep_size": 2,
             "num_experts": 128,
-            "with_quant": True,
         }
         self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)

@@ -460,8 +459,7 @@ def test_token_combine(self):
     def test_token_dispatch_with_quant(self):
         self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                       num_experts=4,
-                                                      num_local_experts=2,
-                                                      with_quant=True)
+                                                      num_local_experts=2)

         hidden_states = torch.randn(8, 16)
         topk_weights = torch.rand(8, 4)
@@ -476,7 +474,8 @@ def test_token_dispatch_with_quant(self):
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            row_idx=self.row_idx,
-           expert_map=expert_map)
+           expert_map=expert_map,
+           with_quant=True)

         self.assertIsNotNone(result["hidden_states"])
         self.assertIsNotNone(result["group_list"])
@@ -486,8 +485,7 @@ def test_token_dispatch_with_quant_no_active_tokens(self):
     def test_token_dispatch_with_quant_no_active_tokens(self):
         self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                       num_experts=4,
-                                                      num_local_experts=2,
-                                                      with_quant=True)
+                                                      num_local_experts=2)

         self.mock_repeat_interleave.return_value = torch.tensor(
             [], dtype=torch.long)
@@ -505,7 +503,8 @@ def test_token_dispatch_with_quant_no_active_tokens(self):
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            row_idx=self.row_idx,
-           expert_map=expert_map)
+           expert_map=expert_map,
+           with_quant=True)

         self.assertIsNotNone(result["hidden_states"])
         self.assertIsNotNone(result["group_list"])

vllm_ascend/ascend_forward_context.py

Lines changed: 0 additions & 2 deletions
@@ -99,8 +99,6 @@ def set_ascend_forward_context(
         forward_context.fused_moe_state = fused_moe_state
         forward_context.in_profile_run = in_profile_run

-        with_quant = vllm_config.quant_config is not None
-        forward_context.with_quant = with_quant
         from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
             get_token_dispatcher
         dispatcher_name = get_dispatcher_name(ep_size, with_prefill)

vllm_ascend/ops/fused_moe.py

Lines changed: 22 additions & 20 deletions
@@ -408,19 +408,19 @@ def unquant_apply_mlp(
     return hidden_states


-def unified_apply_mlp(
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w1_scale: torch.Tensor,
-        w2: torch.Tensor,
-        w2_scale: torch.Tensor,
-        group_list: torch.Tensor,
-        dynamic_scale: torch.Tensor = None,
-        group_list_type: int = 1,
-        w1_scale_bias: torch.Tensor = None,
-        w2_scale_bias: torch.Tensor = None,
-        topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
-    if get_forward_context().with_quant:
+def unified_apply_mlp(hidden_states: torch.Tensor,
+                      w1: torch.Tensor,
+                      w1_scale: torch.Tensor,
+                      w2: torch.Tensor,
+                      w2_scale: torch.Tensor,
+                      group_list: torch.Tensor,
+                      dynamic_scale: torch.Tensor = None,
+                      group_list_type: int = 1,
+                      w1_scale_bias: torch.Tensor = None,
+                      w2_scale_bias: torch.Tensor = None,
+                      topk_scales: Optional[torch.Tensor] = None,
+                      with_quant: bool = False) -> torch.Tensor:
+    if with_quant:
         return quant_apply_mlp(hidden_states=hidden_states,
                                w1=w1,
                                w1_scale=w1_scale,
@@ -457,7 +457,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
                                 shared_gate_up: Optional[Any] = None,
                                 shared_dequant_scale: Optional[Any] = None,
                                 mc2_mask: Optional[torch.Tensor] = None,
-                                apply_router_weight_on_input: bool = False):
+                                apply_router_weight_on_input: bool = False,
+                                with_quant: bool = False):
     token_dispatcher = get_forward_context().token_dispatcher

     results = token_dispatcher.token_dispatch(
@@ -472,7 +473,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
         shared_gate_up=shared_gate_up,
         shared_dequant_scale=shared_dequant_scale,
         mc2_mask=mc2_mask,
-        apply_router_weight_on_input=apply_router_weight_on_input)
+        apply_router_weight_on_input=apply_router_weight_on_input,
+        with_quant=with_quant)

     expert_output = unified_apply_mlp(
         hidden_states=results["hidden_states"],
@@ -485,7 +487,8 @@ def unified_fused_experts_eager(hidden_states: torch.Tensor,
         group_list_type=results.get("group_list_type"),
         w1_scale_bias=w1_scale_bias,
         w2_scale_bias=w2_scale_bias,
-        topk_scales=results.get("topk_scales"))
+        topk_scales=results.get("topk_scales"),
+        with_quant=with_quant)
     final_hidden_states = token_dispatcher.token_combine(expert_output)
     return final_hidden_states

@@ -577,7 +580,8 @@ def apply(
             expert_map=expert_map,
             shared_experts=shared_experts,
             mc2_mask=kwargs.get(
-                "mc2_mask", None))
+                "mc2_mask", None),
+            with_quant=False)


 class AscendFusedMoE(FusedMoE):
@@ -761,16 +765,14 @@ def __init__(

         ep_size = (get_ep_group().world_size if
                    vllm_config.parallel_config.enable_expert_parallel else 1)
-        with_quant = quant_config is not None
         from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
             setup_token_dispatchers
         setup_token_dispatchers(
             ep_size,
             top_k=self.top_k,
             num_experts=self.global_num_experts,
             num_global_redundant_experts=self.global_redundant_expert_num,
-            num_local_experts=self.local_num_experts,
-            with_quant=with_quant)
+            num_local_experts=self.local_num_experts)

     def naive_multicast(self, x: torch.Tensor,
                         cu_tokens_across_dp_cpu: torch.Tensor):
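
To see how the flag travels at runtime, here is a hedged, self-contained sketch of the new plumbing in unified_fused_experts_eager: signatures are trimmed, the dispatcher and MLP bodies are stand-ins (_StubDispatcher is not a real vllm_ascend class), and only the explicit with_quant pass-through reflects the diff above. Per the hunks above and the quantization files below, the unquantized fused-MoE apply() passes with_quant=False, while the w4a8_dynamic / w8a8_dynamic apply() pass with_quant=True.

```python
# Hedged sketch of the flag plumbing; _StubDispatcher and the MLP body are
# stand-ins, and argument lists are trimmed to the relevant pieces.
import torch


class _StubDispatcher:

    def __init__(self) -> None:
        self.with_quant = False

    def token_dispatch(self, hidden_states: torch.Tensor, *,
                       with_quant: bool = False) -> dict:
        self.with_quant = with_quant  # per-call state, no longer ctor config
        return {"hidden_states": hidden_states}

    def token_combine(self, expert_output: torch.Tensor) -> torch.Tensor:
        return expert_output


def unified_apply_mlp(hidden_states: torch.Tensor, *,
                      with_quant: bool = False) -> torch.Tensor:
    return hidden_states  # stand-in for the quantized/unquantized selection


def unified_fused_experts_eager(hidden_states: torch.Tensor,
                                dispatcher: _StubDispatcher, *,
                                with_quant: bool = False) -> torch.Tensor:
    results = dispatcher.token_dispatch(hidden_states, with_quant=with_quant)
    out = unified_apply_mlp(results["hidden_states"], with_quant=with_quant)
    return dispatcher.token_combine(out)


# Quantized schemes pass True, the unquantized method passes False:
dispatcher = _StubDispatcher()
unified_fused_experts_eager(torch.randn(4, 8), dispatcher, with_quant=True)
assert dispatcher.with_quant is True
```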

vllm_ascend/ops/moe_dispatcher/token_dispatcher.py

Lines changed: 16 additions & 6 deletions
@@ -490,7 +490,6 @@ def __init__(self, **kwargs) -> None:
         """
         self.top_k = kwargs.get("top_k", 0)
         self.num_experts = kwargs.get("num_experts", 0)
-        self.with_quant = kwargs.get("with_quant", False)

     @property
     def ep_group(self):
@@ -518,7 +517,8 @@ def token_dispatch(self,
                        shared_gate_up: Optional[torch.Tensor] = None,
                        shared_dequant_scale: Optional[torch.Tensor] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False):
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False):
         raise NotImplementedError("Dispatch function not implemented.")

     @abstractmethod
@@ -555,6 +555,7 @@ def __init__(self, **kwargs):
         self.topk_weights = None
         self.shared_experts = None
         self.mc2_mask = None
+        self.with_quant = False

     def get_dispatch_mc2_kwargs(
         self,
@@ -615,7 +616,9 @@ def token_dispatch(self,
                        shared_gate_up: Optional[torch.Tensor] = None,
                        shared_dequant_scale: Optional[torch.Tensor] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False):
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False):
+        self.with_quant = with_quant
         self.expert_map = expert_map
         self.topk_ids = topk_ids
         self.topk_weights = topk_weights
@@ -738,6 +741,7 @@ def __init__(self, **kwargs):
         self.expert_map = None
         self.topk_weights = None
         self.topk_ids = None
+        self.with_quant = False

     def token_dispatch(self,
                        hidden_states: torch.Tensor,
@@ -751,7 +755,9 @@ def token_dispatch(self,
                        shared_gate_up: Optional[torch.Tensor] = None,
                        shared_dequant_scale: Optional[torch.Tensor] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False):
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False):
+        self.with_quant = with_quant
         self.original_shape = hidden_states.shape

         num_tokens = hidden_states.shape[:-1].numel()
@@ -922,7 +928,8 @@ def token_dispatch(self,
                        shared_gate_up: Optional[torch.Tensor] = None,
                        shared_dequant_scale: Optional[torch.Tensor] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False):
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False):
         self.apply_router_weight_on_input = apply_router_weight_on_input
         if self.apply_router_weight_on_input:
             assert (topk_weights.dim() == 2
@@ -980,6 +987,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):

     def __init__(self, **kwargs):
         super().__init__(**kwargs)
+        self.with_quant = False
         self.num_local_experts = kwargs.get("num_local_experts", 0)
         self.num_global_redundant_experts = kwargs.get(
             "num_global_redundant_experts", 0)
@@ -1032,7 +1040,9 @@ def token_dispatch(self,
                        shared_gate_up: Optional[torch.Tensor] = None,
                        shared_dequant_scale: Optional[torch.Tensor] = None,
                        mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False):
+                       apply_router_weight_on_input: bool = False,
+                       with_quant: bool = False):
+        self.with_quant = with_quant
         self.hidden_shape = hidden_states.shape
         self.topk_weights = topk_weights
         assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"

vllm_ascend/quantization/w4a8_dynamic.py

Lines changed: 2 additions & 1 deletion
@@ -308,7 +308,8 @@ def apply(
             shared_experts=shared_experts,
             shared_gate_up=shared_gate_up,
             shared_dequant_scale=shared_dequant_scale,
-            mc2_mask=kwargs.get("mc2_mask", None))
+            mc2_mask=kwargs.get("mc2_mask", None),
+            with_quant=True)

     def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
         group_num, k, n = weight.shape

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 2 additions & 1 deletion
@@ -406,7 +406,8 @@ def apply(
             shared_experts=shared_experts,
             shared_gate_up=shared_gate_up,
             shared_dequant_scale=shared_dequant_scale,
-            mc2_mask=kwargs.get("mc2_mask", None))
+            mc2_mask=kwargs.get("mc2_mask", None),
+            with_quant=True)

     def process_weights_after_loading(self, layer):
         if self.transpose_weight:
