 import pytest
 import torch
 from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
+
 from vllm_ascend.models.qwen3_moe import (CustomQwen3MoeAttention,
                                           CustomQwen3MoeForCausalLM)
 
@@ -49,7 +50,8 @@ def test_packed_modules_mapping_structure(self):
         assert CustomQwen3MoeForCausalLM.packed_modules_mapping == expected_mapping
 
 
-class TestNormalizeQKVWithFixedInput(unittest.TestCase):
+class TestCustomQwen3MoeAttention(unittest.TestCase):
+
     def setUp(self):
         self.batch = 2
         self.seq_len = 3
@@ -60,20 +62,16 @@ def setUp(self):
 
         total_dim = self.q_size + 2 * self.kv_size
 
-        self.qkv = torch.arange(
-            self.batch * self.seq_len * total_dim,
-            dtype=torch.float32
-        ).reshape(self.batch, self.seq_len, total_dim)
+        self.qkv = torch.arange(self.batch * self.seq_len * total_dim,
+                                dtype=torch.float32).reshape(
+                                    self.batch, self.seq_len, total_dim)
 
     def test_constant_input_normalization(self):
-        ones_qkv = torch.ones(
-            (1, 1, self.q_size + 2 * self.kv_size),
-            dtype=torch.float32
-        )
+        ones_qkv = torch.ones((1, 1, self.q_size + 2 * self.kv_size),
+                              dtype=torch.float32)
 
         q, k, v = CustomQwen3MoeAttention.normalize_qkv(
-            ones_qkv, self.q_size, self.kv_size, self.head_dim, self.rms_eps
-        )
+            ones_qkv, self.q_size, self.kv_size, self.head_dim, self.rms_eps)
 
         norm_val = 1.0 / math.sqrt(1.0 + self.rms_eps)
 
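For reference, the expected constant in test_constant_input_normalization follows from plain RMSNorm with unit weight: for an all-ones vector, mean(x ** 2) == 1, so every normalized element equals 1 / sqrt(1 + eps). A minimal standalone sketch of that arithmetic, assuming the qk-norm here behaves like a standard RMSNorm; the rms_norm_unit_weight helper and the eps / head_dim values are illustrative, not taken from the PR:

    import math

    import torch


    def rms_norm_unit_weight(x: torch.Tensor, eps: float) -> torch.Tensor:
        # Plain RMSNorm with the learnable weight fixed to 1:
        # x / sqrt(mean(x ** 2) + eps) along the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


    eps = 1e-6        # illustrative value, not the test's rms_eps
    head_dim = 128    # illustrative value, not the test's head_dim
    ones = torch.ones(head_dim)

    # For constant input, mean(x ** 2) == 1, so the output is 1 / sqrt(1 + eps),
    # matching the norm_val computed in the test above.
    expected = 1.0 / math.sqrt(1.0 + eps)
    assert torch.allclose(rms_norm_unit_weight(ones, eps),
                          torch.full((head_dim,), expected))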