
Commit 7bec1a9

qwen3_moe/qwen25 support torchair graph (#2403)
### What this PR does / why we need it?
Added support for the TorchAir graph mode in qwen3_moe and qwen2.5.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

```python
llm = LLM(
    model=model,
    tensor_parallel_size=GPUs_per_dp_rank,
    enforce_eager=False,
    enable_expert_parallel=True,
    max_model_len=4096,
    max_num_seqs=16,
    trust_remote_code=trust_remote_code,
    gpu_memory_utilization=0.4,
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "use_cached_graph": False,
            "graph_batch_sizes_init": False,
            "graph_batch_sizes": [16]
        },
        "ascend_scheduler_config": {
            "enabled": True,
            "chunked_prefill_enabled": True,
        },
        "refresh": True,
    },
)
```

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@b87cb97

Signed-off-by: taoyuxiang <[email protected]>
1 parent 31ae249 commit 7bec1a9

File tree: 9 files changed, +1123 / -9 lines

tests/e2e/multicard/test_torchair_graph_mode.py

Lines changed: 62 additions & 0 deletions
```diff
@@ -162,3 +162,65 @@ def test_e2e_pangu_with_torchair():
         },
     }
     _pangu_torchair_test_fixture(additional_config)
+
+
+def _qwen_torchair_test_fixture(
+    model,
+    tp,
+    enable_expert_parallel,
+):
+    # The current access control does not support 16 cards,
+    # so the MC2 operator in Qwen's graph mode cannot run.
+    # Once 16-card support is available,
+    # this e2e can be switched to graph mode.
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    additional_config = {
+        "torchair_graph_config": {
+            "enabled": False,
+        },
+        "ascend_scheduler_config": {
+            "enabled": True,
+        },
+        "refresh": True,
+    }
+
+    with VllmRunner(
+            model,
+            dtype="half",
+            tensor_parallel_size=tp,
+            distributed_executor_backend="mp",
+            enforce_eager=True,
+            additional_config=additional_config,
+            enable_expert_parallel=enable_expert_parallel,
+    ) as vllm_model:
+        # use a greedy sampler to make sure the generated results are fixed
+        vllm_output = vllm_model.generate_greedy(example_prompts, 5)
+
+    # NOTE: vllm-ascend/pangu-pro-moe-pruing is only part of PanguProMoE
+    # with 2 hidden layers, thus the golden results seem inaccurate.
+    # This will only change if accuracy changes with the official weights
+    # of PanguProMoE.
+    golden_results = [
+        'Hello, my name is Remempondeprecatedmiot忱',
+        'The president of the United States is Remem下的一个 rever ceremoni Segnali',
+        'The capital of France is Rememvoud administrativ Remem投',
+        'The future of AI isotope Segnali Zoeken精细化 supus',
+    ]
+
+    assert len(golden_results) == len(vllm_output)
+    for i in range(len(vllm_output)):
+        print(f"Generated text: {vllm_output[i][1]!r}")
+
+
+def test_e2e_qwen2_with_torchair():
+    _qwen_torchair_test_fixture("Qwen/Qwen2.5-0.5B-Instruct", 2, False)
+
+
+def test_e2e_qwen3_moe_with_torchair():
+    _qwen_torchair_test_fixture("Qwen/Qwen3-30B-A3B", 2, True)
```

tests/ut/models/test_qwen3_moe.py

Lines changed: 52 additions & 0 deletions
```diff
@@ -12,11 +12,15 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
+import math
+import unittest
 
 import pytest
+import torch
 from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
 
 from vllm_ascend.models.qwen3_moe import CustomQwen3MoeForCausalLM
+from vllm_ascend.torchair.models.qwen3_moe import CustomQwen3MoeAttention
 
 
 class TestCustomQwen3MoeForCausalLM:
@@ -44,3 +48,51 @@ def test_packed_modules_mapping_structure(self):
             ]
         }
         assert CustomQwen3MoeForCausalLM.packed_modules_mapping == expected_mapping
+
+
+class DummyRMSNorm:
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        self.dim = dim
+        self.eps = eps
+
+    def __call__(self, x):
+        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
+        denom = (mean_sq + self.eps).sqrt()
+        return x / denom
+
+
+class TestCustomQwen3MoeAttention(unittest.TestCase):
+
+    def setUp(self):
+        self.batch = 2
+        self.seq_len = 3
+        self.q_size = 8
+        self.kv_size = 8
+        self.head_dim = 4
+        self.rms_eps = 1e-6
+
+        total_dim = self.q_size + 2 * self.kv_size
+
+        self.qkv = torch.arange(self.batch * self.seq_len * total_dim,
+                                dtype=torch.float32).reshape(
+                                    self.batch, self.seq_len, total_dim)
+
+    def test_constant_input_normalization(self):
+        ones_qkv = torch.ones((1, 1, self.q_size + 2 * self.kv_size),
+                              dtype=torch.float32)
+
+        q_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
+        k_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
+        q, k, v = CustomQwen3MoeAttention.normalize_qkv(
+            ones_qkv, self.q_size, self.kv_size, self.head_dim, q_norm, k_norm)
+
+        norm_val = 1.0 / math.sqrt(1.0 + self.rms_eps)
+
+        expected_q = torch.full((1, 1, self.q_size), norm_val)
+        expected_k = torch.full((1, 1, self.kv_size), norm_val)
+        expected_v = torch.ones((1, 1, self.kv_size), dtype=torch.float32)
+
+        self.assertTrue(torch.allclose(q, expected_q, atol=1e-6))
+        self.assertTrue(torch.allclose(k, expected_k, atol=1e-6))
+        self.assertTrue(torch.equal(v, expected_v))
```
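For readers unfamiliar with the helper under test: the unit test implies that `CustomQwen3MoeAttention.normalize_qkv` splits the fused QKV projection, applies per-head RMSNorm to the query and key slices, and leaves the value slice untouched. A minimal reference sketch consistent with the test's expectations (the actual static method lives in `vllm_ascend/torchair/models/qwen3_moe.py` and may differ in detail):

```python
import torch


def reference_normalize_qkv(qkv, q_size, kv_size, head_dim, q_norm, k_norm):
    # Hypothetical reference, inferred from the unit test above: split the
    # fused projection, RMS-normalize q and k per head of size head_dim,
    # and pass v through unchanged.
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    q = q_norm(q.view(*q.shape[:-1], -1, head_dim)).flatten(-2)
    k = k_norm(k.view(*k.shape[:-1], -1, head_dim)).flatten(-2)
    return q, k, v
```

With an all-ones input and `DummyRMSNorm`, every q/k element becomes 1/sqrt(1 + eps), which is exactly the `norm_val` the test asserts.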

tests/ut/test_ascend_config.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -232,7 +232,7 @@ def test_check_ascend_config_wrong_case(self):
 
     def test_check_torchair_supported(self):
         test_cases = [('deepseek_v3', True), ('PanguProMoE', True),
-                      ('qwen', False), ('llama', False)]
+                      ('qwen', True), ('llama', False)]
         for model_type, expected_output in test_cases:
             self.assertEqual(_check_torchair_supported(model_type),
                              expected_output)
```

vllm_ascend/ascend_config.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -17,7 +17,7 @@
 
 from vllm.logger import logger
 
-TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2"]
+TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"]
 
 
 def _check_torchair_supported(model_type: str):
@@ -162,7 +162,7 @@ def check_ascend_config(vllm_config, enforce_eager):
     else:
         # torchair_graph case
        if ascend_config.torchair_graph_config.enabled:
-            # torchair_graph is supported for deepseek/pangu model only.
+            # torchair_graph is supported for deepseek/pangu/qwen model only.
            if vllm_config.model_config:
                model_type = vllm_config.model_config.hf_config.model_type
                if not _check_torchair_supported(model_type):
```
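The allow-list is consumed by `_check_torchair_supported`. Below is a minimal sketch of a check that reproduces the expectations in `tests/ut/test_ascend_config.py` ('deepseek_v3' and 'PanguProMoE' supported, 'llama' not, and 'qwen' now supported); the real function in `vllm_ascend/ascend_config.py` may be implemented differently:

```python
TORCHAIR_MODEL_LIST = ["deepseek", "pangu", "kimi_k2", "qwen"]


def check_torchair_supported_sketch(model_type: str) -> bool:
    # Assumed behavior: case-insensitive substring match against the
    # allow-list, so "deepseek_v3" and "PanguProMoE" pass while "llama"
    # does not.
    return any(name in model_type.lower() for name in TORCHAIR_MODEL_LIST)
```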

vllm_ascend/ops/rotary_embedding.py

Lines changed: 94 additions & 4 deletions
```diff
@@ -19,6 +19,8 @@
 from typing import Optional, Tuple
 
 import torch
+import torch.nn.functional as F
+import torch_npu
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
@@ -37,17 +39,18 @@ def rope_forward_oot(
     query: torch.Tensor,
     key: torch.Tensor,
     offsets: Optional[torch.Tensor] = None,
-    is_neox_style_override: Optional[bool] = None
+    is_neox_style_override: Optional[bool] = None,
+    is_qwen_torchair: Optional[bool] = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    if get_ascend_config().torchair_graph_config.enabled:
+    if get_ascend_config(
+    ).torchair_graph_config.enabled and not is_qwen_torchair:
         return self.forward_native(
             positions,
             query,
             key,
             offsets,
         )
 
-    import torch_npu
     query_shape, key_shape = query.shape, key.shape
     if self.cos_sin_cache.device != query.device:
         self.cos_sin_cache = self.cos_sin_cache.to(query.device)
@@ -246,6 +249,92 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
     self.register_buffer("sin_cached", sin_cached, persistent=False)
 
 
+def __set_cos_sin_cache(self, seq_len, device, dtype):
+    inv_freq = 1.0 / (self.base**(torch.arange(
+        0, self.rotary_dim, 2, device=device, dtype=torch.float32) *
+                                  (1 / self.rotary_dim)))
+    self.register_buffer("inv_freq", inv_freq)
+
+    t = torch.arange(self.max_position_embeddings,
+                     device=self.inv_freq.device,
+                     dtype=torch.float32)
+    freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+
+    emb = torch.cat((freqs, freqs), dim=-1)
+    self.register_buffer("cos", emb.cos().to(dtype=dtype), persistent=False)
+    self.register_buffer("sin", emb.sin().to(dtype=dtype), persistent=False)
+    self.embed = F.embedding
+
+
+_original_re_init = RotaryEmbedding.__init__
+
+
+def qwen_rope_init_func(
+    self,
+    head_size: int,
+    rotary_dim: int,
+    max_position_embeddings: int,
+    base: float,
+    is_neox_style: bool,
+    dtype: torch.dtype,
+) -> None:
+    _original_re_init(self, head_size, rotary_dim, max_position_embeddings,
+                      base, is_neox_style, dtype)
+    if get_ascend_config().torchair_graph_config.enabled:
+        __set_cos_sin_cache(self,
+                            seq_len=max_position_embeddings,
+                            device="npu",
+                            dtype=dtype)
+
+
+def rope_forward(
+    self,
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+    is_neox_style_override: Optional[bool] = None,
+    max_seq_len: Optional[int] = None,
+    is_prefill: Optional[bool] = True,
+    is_qwen_torchair: Optional[bool] = False,
+):
+    if get_ascend_config().torchair_graph_config.enabled \
+            and is_qwen_torchair and not is_prefill:
+        if max_seq_len is not None and torch.gt(max_seq_len,
+                                                self.max_position_embeddings):
+            __set_cos_sin_cache(self,
+                                seq_len=max_seq_len,
+                                device=query.device,
+                                dtype=torch.float32)
+
+        # bsnd/bnsd
+        if positions is not None:
+            cos = self.embed(positions, self.cos)
+            sin = self.embed(positions, self.sin)
+            self.cos_embed = cos
+            self.sin_embed = sin
+        else:
+            cos = self.cos_embed
+            sin = self.sin_embed
+
+        query = query.view(*query.shape[:-1], -1, self.head_size).contiguous()
+        key = key.view(*key.shape[:-1], -1, self.head_size).contiguous()
+
+        cos = cos.unsqueeze(-2).unsqueeze(-2)
+        sin = sin.unsqueeze(-2).unsqueeze(-2)
+
+        query = query.unsqueeze(1)
+        key = key.unsqueeze(1)
+
+        q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(
+            query, key, cos, sin)
+        return q_embed.flatten(-2), k_embed.flatten(-2)
+    else:
+        return rope_forward_oot(self, positions, query, key, offsets,
+                                is_neox_style_override,
+                                is_qwen_torchair)  # type: ignore
+
+
 def deepseek_rope_init_func(
     self,
     head_size: int,
@@ -283,7 +372,8 @@ def deepseek_rope_init_func(
                               device="npu")
 
 
-RotaryEmbedding.forward_oot = rope_forward_oot
+RotaryEmbedding.__init__ = qwen_rope_init_func
+RotaryEmbedding.forward_oot = rope_forward
 
 # Note: we adopt the native huggingface deepseek rope initialization code from
 # https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for
```
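For context on what the torchair branch computes: `__set_cos_sin_cache` builds the standard RoPE tables (cos/sin of `t * inv_freq`, duplicated along the last dimension), and `torch_npu.npu_apply_rotary_pos_emb` fuses the rotation on the NPU. A CPU reference of the assumed neox-style rotation, useful for sanity-checking outputs off-device (a sketch, not the kernel's actual implementation):

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Neox-style half rotation: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_cpu(q, k, cos, sin):
    # CPU stand-in for the fused NPU call, assuming the usual formulation
    # with cos/sin built from cat((freqs, freqs), dim=-1) as in
    # __set_cos_sin_cache above.
    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed
```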
