Commit 4ff422c

[CI][Bugfix] Quickfix for DPMetaData (#3234)
### What this PR does / why we need it?

Fix the `DPMetaData` and `Qwen3MoeSparseMoeBlock` breakage introduced by vllm-project/vllm@26a7a33#diff-c1550d0a38469d039370567d8981969530cbfffc7302cd1778e7c2c8a9322dea

NOTE: vllm-ascend maintains its own sequence-parallel (SP) implementation, distinct from vLLM's, so we can simply use `cu_tokens_across_sp(1)` as `cu_tokens_across_dp_cpu`.

Closes #3236, #3239.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.10.2
- vLLM main: vllm-project/vllm@releases/v0.11.0

---------

Signed-off-by: MengqingCao <[email protected]>
1 parent: f2d8493

7 files changed: +59 −23 lines

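The NOTE above hinges on one identity; here is a minimal sketch of it (our reading of the new `DPMetadata.cu_tokens_across_sp` helper, not upstream's exact code): with `sp_size == 1`, the SP-aware cumulative counts collapse to the plain per-DP-rank cumulative tensor that the removed `cu_tokens_across_dp_cpu` attribute used to hold.

```python
import torch

# Sketch only: split each rank's token count over the sp ranks (rounding
# up), then take the running total. Assumed semantics, not a copy of
# upstream vLLM's implementation.
def cu_tokens_across_sp_sketch(num_tokens_across_dp_cpu: torch.Tensor,
                               sp_size: int) -> torch.Tensor:
    per_sp_rank = (num_tokens_across_dp_cpu + sp_size - 1) // sp_size
    return torch.cumsum(per_sp_rank.repeat_interleave(sp_size), dim=0)

num_tokens = torch.tensor([2, 3, 2])
# sp_size == 1 reproduces the old cumulative-per-DP-rank tensor.
assert cu_tokens_across_sp_sketch(num_tokens, 1).tolist() == [2, 5, 7]
```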

tests/ut/ops/test_fused_moe_prepare_and_finalize.py

Lines changed: 7 additions & 2 deletions

@@ -8,6 +8,7 @@
     FusedMoEPrepareAndFinalizeWithAll2All,
     FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
     FusedMoEPrepareAndFinalizeWithNaiveMulticast)
+from vllm_ascend.utils import vllm_version_is


 class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
@@ -230,8 +231,12 @@ def test_naive_multicast_prepare_finalize(self, mock_get_forward_context,
                                               mock_get_dp_group):
         # Mock forward context with DP metadata
         mock_context = MagicMock()
-        mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor(
-            [2, 5, 7])
+        if vllm_version_is("0.10.2"):
+            mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor(
+                [2, 5, 7])
+        else:
+            mock_context.dp_metadata.cu_tokens_across_sp.return_value = torch.tensor(
+                [2, 5, 7])
         mock_get_forward_context.return_value = mock_context

         # Setup DP group mock
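The switch from attribute assignment to `.return_value` tracks the API change: the old `cu_tokens_across_dp_cpu` was a plain attribute, while `cu_tokens_across_sp` is a method, so the mock must stub what the call returns. A minimal standalone illustration:

```python
import torch
from unittest.mock import MagicMock

ctx = MagicMock()
# Old API: a plain attribute, assigned directly on the mock.
ctx.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor([2, 5, 7])
# New API: a method, so the mock stubs the value its call returns.
ctx.dp_metadata.cu_tokens_across_sp.return_value = torch.tensor([2, 5, 7])
assert torch.equal(ctx.dp_metadata.cu_tokens_across_sp(1),
                   torch.tensor([2, 5, 7]))
```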

tests/ut/ops/test_fused_ops.py

Lines changed: 13 additions & 9 deletions

@@ -28,7 +28,7 @@
     AscendUnquantizedFusedMoEMethod)
 from vllm_ascend.ops.moe.experts_selector import select_experts
 from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
-from vllm_ascend.utils import AscendSocVersion, adapt_patch
+from vllm_ascend.utils import AscendSocVersion, adapt_patch, vllm_version_is

 adapt_patch(True)

@@ -93,14 +93,18 @@ def mock_finalize(hidden_states, **kwargs):

     mock_moe_comm_method.finalize.side_effect = mock_finalize

-    mock_forward_context_obj = MagicMock(
-        moe_comm_method=mock_moe_comm_method,
-        moe_comm_type=MoECommType.MC2,
-        max_tokens_across_dp=10,
-        dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
-        mc2_mask=torch.zeros(16, dtype=torch.bool),
-        padded_num_tokens=16,
-        with_quant=False)
+    if vllm_version_is("0.10.2"):
+        dp_metadata = MagicMock(cu_tokens_across_dp_cpu=[5, 10])
+    else:
+        dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
+    mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method,
+                                         moe_comm_type=MoECommType.MC2,
+                                         max_tokens_across_dp=10,
+                                         dp_metadata=dp_metadata,
+                                         mc2_mask=torch.zeros(
+                                             16, dtype=torch.bool),
+                                         padded_num_tokens=16,
+                                         with_quant=False)

     with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
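Note that the two mock payloads describe the same token layout (assuming the mock values mirror the real field semantics): `[5, 5]` are per-rank counts for the new `num_tokens_across_dp_cpu` field, and `[5, 10]` is their running total, the form the old `cu_tokens_across_dp_cpu` attribute carried.

```python
import torch

# Per-rank counts (new field) and their cumulative form (old field).
num_tokens_across_dp = torch.tensor([5, 5])
cu_tokens_across_dp = torch.cumsum(num_tokens_across_dp, dim=0)
assert cu_tokens_across_dp.tolist() == [5, 10]
```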

tests/ut/torchair/ops/test_torchair_fused_moe.py

Lines changed: 7 additions & 2 deletions

@@ -26,7 +26,8 @@
 from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
 from vllm_ascend.torchair.ops.torchair_fused_moe import (
     TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod)
-from vllm_ascend.utils import AscendSocVersion, adapt_patch  # noqa E402
+from vllm_ascend.utils import adapt_patch  # noqa E402
+from vllm_ascend.utils import AscendSocVersion, vllm_version_is

 adapt_patch(True)

@@ -53,6 +54,10 @@ def mock_dp_and_tp_group(mocker):
 @pytest.fixture
 def mock_dist_env(mocker: MockerFixture):
     # init dist env patch
+    if vllm_version_is("0.10.2"):
+        dp_metadata = MagicMock(cu_tokens_across_dp_cpu=[5, 10])
+    else:
+        dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])

     with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
@@ -80,7 +85,7 @@ def mock_dist_env(mocker: MockerFixture):
         patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context',
               return_value=MagicMock(
                   max_tokens_across_dp=10,
-                  dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10])
+                  dp_metadata=dp_metadata,
               )), \
         patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config',
               return_value=MagicMock(

vllm_ascend/models/qwen3_moe.py

Lines changed: 9 additions & 3 deletions

@@ -47,6 +47,7 @@
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)

 from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.utils import vllm_version_is


 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -169,9 +170,14 @@ def __init__(
                     quant_config=quant_config,
                     prefix=f"{prefix}.mlp")
             else:
-                self.mlp = Qwen3MoeSparseMoeBlock(config=config,
-                                                  quant_config=quant_config,
-                                                  prefix=f"{prefix}.mlp")
+                if vllm_version_is("0.10.2"):
+                    self.mlp = Qwen3MoeSparseMoeBlock(
+                        config=config,
+                        quant_config=quant_config,
+                        prefix=f"{prefix}.mlp")
+                else:
+                    self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
+                                                      prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
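Both qwen3_moe.py variants (this file and the torchair copy below) apply the same gate around the upstream constructor change; a hypothetical helper condensing it (not part of the PR) would read:

```python
from vllm.model_executor.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from vllm_ascend.utils import vllm_version_is

# Hypothetical helper, not in the PR: older vLLM built the block from the
# model config plus quant config; newer vLLM derives both from VllmConfig.
def build_sparse_moe_block(vllm_config, config, quant_config, prefix: str):
    if vllm_version_is("0.10.2"):
        return Qwen3MoeSparseMoeBlock(config=config,
                                      quant_config=quant_config,
                                      prefix=prefix)
    return Qwen3MoeSparseMoeBlock(vllm_config=vllm_config, prefix=prefix)
```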

vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py

Lines changed: 8 additions & 2 deletions

@@ -26,6 +26,8 @@
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig

+from vllm_ascend.utils import vllm_version_is
+

 class FusedMoEPrepareAndFinalize(ABC):
     """
@@ -414,8 +416,12 @@ def prepare(self,
         self.enable_shared_expert_dp = enable_shared_expert_dp

         if self.moe_config.dp_size > 1:
-            self.cu_tokens_across_dp_cpu = get_forward_context(
-            ).dp_metadata.cu_tokens_across_dp_cpu
+            if vllm_version_is("0.10.2"):
+                self.cu_tokens_across_dp_cpu = get_forward_context(
+                ).dp_metadata.cu_tokens_across_dp_cpu
+            else:
+                self.cu_tokens_across_dp_cpu = get_forward_context(
+                ).dp_metadata.cu_tokens_across_sp(1)
             hidden_states = self._naive_multicast(hidden_states,
                                                   self.cu_tokens_across_dp_cpu)
             if rm_router_logits:
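Whichever branch runs, the resulting tensor feeds `_naive_multicast` as per-rank end offsets. A minimal sketch of that consumption, assuming the usual slice-by-boundary pattern (the helper name below is ours, not the file's):

```python
import torch

# cu_tokens[i] is assumed to be the end offset of DP rank i's rows in the
# gathered buffer, so consecutive entries delimit each rank's slice.
def slice_for_rank(buffer: torch.Tensor, cu_tokens: torch.Tensor,
                   rank: int) -> torch.Tensor:
    start = 0 if rank == 0 else int(cu_tokens[rank - 1])
    return buffer[start:int(cu_tokens[rank])]

buffer = torch.arange(7.0).unsqueeze(-1)  # 7 tokens across 3 DP ranks
cu_tokens = torch.tensor([2, 5, 7])       # ranks own 2, 3 and 2 tokens
assert slice_for_rank(buffer, cu_tokens, 1).shape[0] == 3
```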

vllm_ascend/torchair/models/qwen3_moe.py

Lines changed: 9 additions & 3 deletions

@@ -56,6 +56,7 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding,
                                                         init_metadata_for_sp)
+from vllm_ascend.utils import vllm_version_is


 class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -311,9 +312,14 @@ def __init__(
                     quant_config=quant_config,
                     prefix=f"{prefix}.mlp")
             else:
-                self.mlp = Qwen3MoeSparseMoeBlock(config=config,
-                                                  quant_config=quant_config,
-                                                  prefix=f"{prefix}.mlp")
+                if vllm_version_is("0.10.2"):
+                    self.mlp = Qwen3MoeSparseMoeBlock(
+                        config=config,
+                        quant_config=quant_config,
+                        prefix=f"{prefix}.mlp")
+                else:
+                    self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
+                                                      prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,

vllm_ascend/torchair/ops/torchair_fused_moe.py

Lines changed: 6 additions & 2 deletions

@@ -1242,8 +1242,12 @@ def forward(self,
             router_logits = get_dp_group().all_gather(router_logits, 0)

         elif fused_moe_state == FusedMoEState.NaiveMulticast:
-            cu_tokens_across_dp_cpu = get_forward_context(
-            ).dp_metadata.cu_tokens_across_dp_cpu
+            if vllm_version_is("0.10.2"):
+                cu_tokens_across_dp_cpu = get_forward_context(
+                ).dp_metadata.cu_tokens_across_dp_cpu
+            else:
+                cu_tokens_across_dp_cpu = get_forward_context(
+                ).dp_metadata.cu_tokens_across_sp(1)
             hidden_states = self.naive_multicast(hidden_states,
                                                  cu_tokens_across_dp_cpu)
             if self.rm_router_logits:
