Skip to content

Commit b7ec3b8

Browse files
committed
fix yapf
Signed-off-by: taoyuxiang <[email protected]>
1 parent 7d701cd commit b7ec3b8

File tree

2 files changed

+10
-14
lines changed

2 files changed

+10
-14
lines changed

tests/ut/models/test_qwen3_moe.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pytest
1919
import torch
2020
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
21+
2122
from vllm_ascend.models.qwen3_moe import (CustomQwen3MoeAttention,
2223
CustomQwen3MoeForCausalLM)
2324

@@ -49,7 +50,8 @@ def test_packed_modules_mapping_structure(self):
4950
assert CustomQwen3MoeForCausalLM.packed_modules_mapping == expected_mapping
5051

5152

52-
class TestNormalizeQKVWithFixedInput(unittest.TestCase):
53+
class TestCustomQwen3MoeAttention(unittest.TestCase):
54+
5355
def setUp(self):
5456
self.batch = 2
5557
self.seq_len = 3
@@ -60,20 +62,16 @@ def setUp(self):
6062

6163
total_dim = self.q_size + 2 * self.kv_size
6264

63-
self.qkv = torch.arange(
64-
self.batch * self.seq_len * total_dim,
65-
dtype=torch.float32
66-
).reshape(self.batch, self.seq_len, total_dim)
65+
self.qkv = torch.arange(self.batch * self.seq_len * total_dim,
66+
dtype=torch.float32).reshape(
67+
self.batch, self.seq_len, total_dim)
6768

6869
def test_constant_input_normalization(self):
69-
ones_qkv = torch.ones(
70-
(1, 1, self.q_size + 2 * self.kv_size),
71-
dtype=torch.float32
72-
)
70+
ones_qkv = torch.ones((1, 1, self.q_size + 2 * self.kv_size),
71+
dtype=torch.float32)
7372

7473
q, k, v = CustomQwen3MoeAttention.normalize_qkv(
75-
ones_qkv, self.q_size, self.kv_size, self.head_dim, self.rms_eps
76-
)
74+
ones_qkv, self.q_size, self.kv_size, self.head_dim, self.rms_eps)
7775

7876
norm_val = 1.0 / math.sqrt(1.0 + self.rms_eps)
7977

tests/ut/ops/test_rotary_embedding.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,9 +341,7 @@ def test_set_cos_sin_cache(self):
341341
# mock out register_buffer
342342
model.register_buffer = MagicMock()
343343
# call the private method via name mangling
344-
model._set_cos_sin_cache(seq_len=8,
345-
device="cpu",
346-
dtype=torch.float32)
344+
model._set_cos_sin_cache(seq_len=8, device="cpu", dtype=torch.float32)
347345
# expect three calls: inv_freq, cos, sin
348346
assert model.register_buffer.call_count == 3
349347
names = [call.args[0] for call in model.register_buffer.call_args_list]

0 commit comments

Comments
 (0)