
Commit 9faa4bf

NicholasTao (nicholastao) authored and committed
fix ut
Signed-off-by: taoyuxiang <[email protected]>
1 parent b7ec3b8 commit 9faa4bf

File tree

3 files changed: 23 additions, 8 deletions

tests/ut/models/test_qwen3_moe.py
tests/ut/ops/test_rotary_embedding.py
vllm_ascend/models/qwen3_moe.py


tests/ut/models/test_qwen3_moe.py

Lines changed: 15 additions & 1 deletion
@@ -50,6 +50,18 @@ def test_packed_modules_mapping_structure(self):
         assert CustomQwen3MoeForCausalLM.packed_modules_mapping == expected_mapping
 
 
+class DummyRMSNorm:
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        self.dim = dim
+        self.eps = eps
+
+    def __call__(self, x):
+        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
+        denom = (mean_sq + self.eps).sqrt()
+        return x / denom
+
+
 class TestCustomQwen3MoeAttention(unittest.TestCase):
 
     def setUp(self):
@@ -70,8 +82,10 @@ def test_constant_input_normalization(self):
         ones_qkv = torch.ones((1, 1, self.q_size + 2 * self.kv_size),
                               dtype=torch.float32)
 
+        q_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
+        k_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
         q, k, v = CustomQwen3MoeAttention.normalize_qkv(
-            ones_qkv, self.q_size, self.kv_size, self.head_dim, self.rms_eps)
+            ones_qkv, self.q_size, self.kv_size, self.head_dim, q_norm, k_norm)
 
         norm_val = 1.0 / math.sqrt(1.0 + self.rms_eps)
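
For reference (not part of the diff): the expected value in test_constant_input_normalization follows directly from DummyRMSNorm. For an all-ones input the mean square is exactly 1, so every element is scaled to 1 / sqrt(1 + eps), which is what norm_val encodes. A minimal standalone check, with head_dim and eps picked here purely for illustration (the test's real values come from its setUp()):

```python
# Standalone sanity check of the DummyRMSNorm used by the test.
# head_dim and eps below are illustrative assumptions, not the test's values.
import math

import torch


class DummyRMSNorm:

    def __init__(self, dim: int, eps: float = 1e-6):
        self.dim = dim
        self.eps = eps

    def __call__(self, x):
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        denom = (mean_sq + self.eps).sqrt()
        return x / denom


head_dim, eps = 128, 1e-6
norm = DummyRMSNorm(head_dim, eps)
out = norm(torch.ones(1, head_dim))
expected = 1.0 / math.sqrt(1.0 + eps)
assert torch.allclose(out, torch.full_like(out, expected))
```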

tests/ut/ops/test_rotary_embedding.py

Lines changed: 4 additions & 3 deletions
@@ -4,8 +4,9 @@
 import torch
 
 from tests.ut.base import TestBase
-from vllm_ascend.ops.rotary_embedding import (__set_cos_sin_cache,
-                                              custom_rotary_embedding_enabled,
+from vllm_ascend.ops.rotary_embedding import \
+    __set_cos_sin_cache as raw__set_cos_sin_cache
+from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled,
                                               native_rope_deepseek_forward,
                                               rope_forward_oot, rotate_half,
                                               yarn_find_correction_dim,
@@ -327,7 +328,7 @@ def __init__(self, base, rotary_dim, max_position_embeddings):
         self.max_position_embeddings = max_position_embeddings
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
-        return __set_cos_sin_cache(self, seq_len, device, dtype)
+        return raw__set_cos_sin_cache(self, seq_len, device, dtype)
 
 
 class TestSetCosSinCache(TestBase):
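
A note on the aliased import (my reading of the change, not stated in the commit message): because __set_cos_sin_cache starts with two underscores, writing that name inside a class body triggers Python name mangling, so the dummy model's _set_cos_sin_cache method would actually look up _ClassName__set_cos_sin_cache and fail. Importing it under the module-level alias raw__set_cos_sin_cache sidesteps the mangling. A small self-contained illustration with made-up names:

```python
# Illustration of the name-mangling problem the aliased import avoids.
# Every name here is made up; only the mechanism mirrors the test file.

def __helper(x):            # stands in for rotary_embedding.__set_cos_sin_cache
    return x + 1


raw__helper = __helper      # module-level alias, like raw__set_cos_sin_cache


class DummyModel:

    def via_alias(self, x):
        return raw__helper(x)   # plain global lookup: works

    def via_dunder(self, x):
        # Inside a class body, __helper is rewritten to _DummyModel__helper,
        # so this raises NameError even though __helper exists at module level.
        return __helper(x)


print(DummyModel().via_alias(1))    # -> 2
try:
    DummyModel().via_dunder(1)
except NameError as exc:
    print(exc)                      # name '_DummyModel__helper' is not defined
```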

vllm_ascend/models/qwen3_moe.py

Lines changed: 4 additions & 4 deletions
@@ -207,15 +207,15 @@ def __init__(
 
     @staticmethod
     def normalize_qkv(qkv: torch.Tensor, q_size: int, kv_size: int,
-                      head_dim: int, rms_norm_eps: float):
+                      head_dim: int, q_norm, k_norm):
         q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
 
         q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
-        q_by_head = RMSNorm(head_dim, eps=rms_norm_eps)(q_by_head)
+        q_by_head = q_norm(q_by_head)
         q = q_by_head.view(q.shape)
 
         k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
-        k_by_head = RMSNorm(head_dim, eps=rms_norm_eps)(k_by_head)
+        k_by_head = k_norm(k_by_head)
         k = k_by_head.view(k.shape)
 
         return q, k, v
@@ -228,7 +228,7 @@ def forward(
         attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = self.normalize_qkv(qkv, self.q_size, self.kv_size,
-                                     self.head_dim, self.rms_norm_eps)
+                                     self.head_dim, self.q_norm, self.k_norm)
 
         if (self.torchair_graph_enabled and attn_metadata is not None and
                 attn_metadata.attn_state == AscendAttentionState.DecodeOnly):
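
The model-side change means normalize_qkv now receives the query/key norm modules from the caller (self.q_norm and self.k_norm, built once in __init__) instead of constructing a fresh RMSNorm(head_dim, eps=rms_norm_eps) on every call, which would carry fresh, unlearned weights. A rough sketch of the new call pattern, using torch.nn.RMSNorm (PyTorch >= 2.4) as a stand-in for vLLM's RMSNorm and made-up sizes:

```python
# Sketch of the injected-norm call pattern; sizes and norm modules are illustrative.
import torch


def normalize_qkv(qkv, q_size, kv_size, head_dim, q_norm, k_norm):
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

    # Reshape to (..., num_heads, head_dim), normalize per head, reshape back.
    q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
    q = q_norm(q_by_head).view(q.shape)

    k_by_head = k.view(*k.shape[:-1], k.shape[-1] // head_dim, head_dim)
    k = k_norm(k_by_head).view(k.shape)

    return q, k, v


head_dim, q_size, kv_size = 8, 32, 16          # 4 query heads, 2 KV heads
q_norm = torch.nn.RMSNorm(head_dim, eps=1e-6)  # owned by the caller
k_norm = torch.nn.RMSNorm(head_dim, eps=1e-6)
qkv = torch.randn(1, 1, q_size + 2 * kv_size)
q, k, v = normalize_qkv(qkv, q_size, kv_size, head_dim, q_norm, k_norm)
print(q.shape, k.shape, v.shape)  # (1, 1, 32), (1, 1, 16), (1, 1, 16)
```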

0 commit comments
