Skip to content

Commit 8a6d5ee

Browse files
committed
add qwen3_moe test_constant_input_normalization
Signed-off-by: taoyuxiang <[email protected]>
1 parent 9747c93 commit 8a6d5ee

File tree

2 files changed

+60
-16
lines changed

2 files changed

+60
-16
lines changed

tests/ut/models/test_qwen3_moe.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@
1212
# limitations under the License.
1313
# This file is a part of the vllm-ascend project.
1414
#
15+
import math
16+
import unittest
1517

1618
import pytest
19+
import torch
1720
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
18-
19-
from vllm_ascend.models.qwen3_moe import CustomQwen3MoeForCausalLM
21+
from vllm_ascend.models.qwen3_moe import (CustomQwen3MoeAttention,
22+
CustomQwen3MoeForCausalLM)
2023

2124

2225
class TestCustomQwen3MoeForCausalLM:
@@ -44,3 +47,40 @@ def test_packed_modules_mapping_structure(self):
4447
]
4548
}
4649
assert CustomQwen3MoeForCausalLM.packed_modules_mapping == expected_mapping
50+
51+
52+
class TestNormalizeQKVWithFixedInput(unittest.TestCase):
53+
def setUp(self):
54+
self.batch = 2
55+
self.seq_len = 3
56+
self.q_size = 8
57+
self.kv_size = 8
58+
self.head_dim = 4
59+
self.rms_eps = 1e-6
60+
61+
total_dim = self.q_size + 2 * self.kv_size
62+
63+
self.qkv = torch.arange(
64+
self.batch * self.seq_len * total_dim,
65+
dtype=torch.float32
66+
).reshape(self.batch, self.seq_len, total_dim)
67+
68+
def test_constant_input_normalization(self):
69+
ones_qkv = torch.ones(
70+
(1, 1, self.q_size + 2 * self.kv_size),
71+
dtype=torch.float32
72+
)
73+
74+
q, k, v = CustomQwen3MoeAttention.normalize_qkv(
75+
ones_qkv, self.q_size, self.kv_size, self.head_dim, self.rms_eps
76+
)
77+
78+
norm_val = 1.0 / math.sqrt(1.0 + self.rms_eps)
79+
80+
expected_q = torch.full((1, 1, self.q_size), norm_val)
81+
expected_k = torch.full((1, 1, self.kv_size), norm_val)
82+
expected_v = torch.ones((1, 1, self.kv_size), dtype=torch.float32)
83+
84+
self.assertTrue(torch.allclose(q, expected_q, atol=1e-6))
85+
self.assertTrue(torch.allclose(k, expected_k, atol=1e-6))
86+
self.assertTrue(torch.equal(v, expected_v))

vllm_ascend/models/qwen3_moe.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -205,26 +205,30 @@ def __init__(
205205
ascend_config = get_ascend_config()
206206
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
207207

208+
@staticmethod
209+
def normalize_qkv(qkv: torch.Tensor, q_size: int, kv_size: int,
210+
head_dim: int, rms_norm_eps: float):
211+
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
212+
213+
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
214+
q_by_head = RMSNorm(head_dim, eps=rms_norm_eps)(q_by_head)
215+
q = q_by_head.view(q.shape)
216+
217+
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
218+
k_by_head = RMSNorm(head_dim, eps=rms_norm_eps)(k_by_head)
219+
k = k_by_head.view(k.shape)
220+
221+
return q, k, v
222+
208223
def forward(
209224
self,
210225
positions: torch.Tensor,
211226
hidden_states: torch.Tensor,
212227
kv_cache: Optional[torch.Tensor] = None,
213228
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
214-
qkv, _ = self.qkv_proj(hidden_states)
215-
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
216-
# Add qk-norm
217-
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
218-
self.head_dim)
219-
220-
q_by_head = self.q_norm(q_by_head)
221-
q = q_by_head.view(q.shape)
222-
223-
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
224-
self.head_dim)
225-
226-
k_by_head = self.k_norm(k_by_head)
227-
k = k_by_head.view(k.shape)
229+
q, k, v = self.normalize_qkv(self.qkv_proj(hidden_states), self.q_size,
230+
self.kv_size, self.head_dim,
231+
self.rms_norm_eps)
228232

229233
if (self.torchair_graph_enabled and attn_metadata is not None and
230234
attn_metadata.attn_state == AscendAttentionState.DecodeOnly):

0 commit comments

Comments
 (0)