Commit f759f5e

1. e2e qwen graph mock ep to 16
2. add rope_forward ut

Signed-off-by: taoyuxiang <[email protected]>
1 parent ca1ba22 · commit f759f5e

File tree

2 files changed: +75 -5 lines changed

tests/e2e/multicard/test_torchair_graph_mode.py

Lines changed: 8 additions & 4 deletions
@@ -21,8 +21,10 @@
 """
 import os
 from typing import Dict
+from unittest.mock import patch

 from tests.e2e.conftest import VllmRunner
+from vllm_ascend.ascend_forward_context import _get_fused_moe_state

 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"

@@ -218,9 +220,11 @@ def _qwen_torchair_test_fixture(
         print(f"Generated text: {vllm_output[i][1]!r}")


-def test_e2e_qwen2_with_torchair():
-    _qwen_torchair_test_fixture("Qwen/Qwen2.5-0.5B-Instruct", 2, False)
+def test_e2e_qwen3_moe_with_torchair():

+    def stubbed_get_state(ep_size, with_prefill, is_deepseek_v3_r1):
+        return _get_fused_moe_state(16, with_prefill, is_deepseek_v3_r1)

-def test_e2e_qwen3_moe_with_torchair():
-    _qwen_torchair_test_fixture("Qwen/Qwen3-30B-A3B", 2, True)
+    with patch('vllm_ascend.ascend_forward_context._get_fused_moe_state',
+               side_effect=stubbed_get_state):
+        _qwen_torchair_test_fixture("Qwen/Qwen3-30B-A3B", 2, True)

tests/ut/ops/test_rotary_embedding.py

Lines changed: 67 additions & 1 deletion
@@ -1,15 +1,19 @@
 import math
+from unittest import mock
 from unittest.mock import MagicMock, patch

 import pytest
 import torch
+import torch_npu

 from tests.ut.base import TestBase
+from vllm_ascend.ops.rotary_embedding import __set_cos_sin_cache  # noqa E402
 from vllm_ascend.ops.rotary_embedding import \
     __set_cos_sin_cache as raw__set_cos_sin_cache
 from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled,
                                               native_rope_deepseek_forward,
-                                              rope_forward_oot, rotate_half,
+                                              rope_forward, rope_forward_oot,
+                                              rotate_half,
                                               yarn_find_correction_dim,
                                               yarn_get_mscale)

@@ -363,3 +367,65 @@ def fake_register_buffer(name, tensor, persistent=True):
         assert buf.shape == expected_shape
         assert buf.device == device
         assert buf.dtype == torch.float32
+
+
+class DummyConfig:
+
+    class TorchairGraphConfig:
+        enabled = True
+
+    torchair_graph_config = TorchairGraphConfig()
+
+
+class DummyModel:
+
+    def __init__(self, head_size, max_pos):
+        self.head_size = head_size
+        self.max_position_embeddings = max_pos
+        self.cos = torch.randn(max_pos, head_size)
+        self.sin = torch.randn(max_pos, head_size)
+
+    def embed(self, positions, weight):
+        B, S = positions.shape
+        return torch.ones(B, S, self.head_size) * 0.5
+
+
+@mock.patch("vllm_ascend.ops.rotary_embedding.get_ascend_config",
+            return_value=DummyConfig())
+@mock.patch.object(torch_npu, "npu_apply_rotary_pos_emb")
+@mock.patch("vllm_ascend.ops.rotary_embedding.__set_cos_sin_cache")
+def test_rope_forward_output_shape(mock_set_cache, mock_npu_apply,
+                                   mock_get_ascend_config):
+    batch_size = 2
+    seq_len = 4
+    num_heads = 3
+    head_size = 5
+
+    q = torch.randn(batch_size, seq_len, num_heads * head_size)
+    k = torch.randn_like(q)
+
+    positions = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
+
+    model = DummyModel(head_size=head_size, max_pos=100)
+
+    def fake_apply_rotary(q_in, k_in, cos, sin):
+        return q_in, k_in
+
+    mock_npu_apply.side_effect = fake_apply_rotary
+
+    q_out, k_out = rope_forward(
+        model,
+        positions=positions,
+        query=q,
+        key=k,
+        offsets=None,
+        is_neox_style_override=None,
+        max_seq_len=None,
+        is_prefill=False,  # no rope_forward_oot
+        is_qwen_torchair=True,  # go rotary
+    )
+
+    assert q_out.shape == (batch_size, 1, seq_len, num_heads * head_size)
+    assert k_out.shape == (batch_size, 1, seq_len, num_heads * head_size)
+
+    mock_set_cache.assert_not_called()
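The new unit test leans on two unittest.mock behaviours worth spelling out: stacked @mock.patch decorators are applied bottom-up, so the decorator closest to the function supplies the first injected argument (here mock_set_cache pairs with the innermost __set_cos_sin_cache patch), and a mock's side_effect both intercepts the call and supplies its return value (here fake_apply_rotary passes q and k through unchanged). A tiny self-contained sketch of the ordering rule, using stdlib targets rather than vllm-ascend ones:

import math
from unittest import mock

@mock.patch("math.floor")        # outermost patch -> last injected argument
@mock.patch("math.ceil")         # innermost patch -> first injected argument
def check_patch_order(mock_ceil, mock_floor):
    mock_ceil.return_value = 1
    mock_floor.return_value = 2
    assert math.ceil(0.4) == 1   # answered by mock_ceil
    assert math.floor(0.6) == 2  # answered by mock_floor

check_patch_order()

The added test can be run on its own with, for example, pytest tests/ut/ops/test_rotary_embedding.py::test_rope_forward_output_shape.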
