Skip to content

Commit 06aa682

Browse files
committed
refact attn metadata build
Signed-off-by: weiguihua2 <[email protected]>
1 parent 35cb714 commit 06aa682

File tree

4 files changed

+16
-20
lines changed

4 files changed

+16
-20
lines changed

tests/ut/attention/test_attention_v1.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ def setUp(self):
7272
self.mock_vllm_config.model_config.max_model_len = 640
7373
self.mock_vllm_config.cache_config.block_size = 64
7474
self.mock_device = 'cpu:0'
75-
self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config, self.mock_device)
75+
self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config,
76+
self.mock_device)
7677

7778
def test_reorder_batch(self):
7879
mock_input_batch = MagicMock()
@@ -104,18 +105,14 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
104105
positions=torch.tensor([10, 10]),
105106
attn_mask=torch.ones((10, 10)),
106107
spec_attn_mask=None,
107-
attn_state=AscendAttentionState.PrefillNoCache
108-
)
108+
attn_state=AscendAttentionState.PrefillNoCache)
109109

110110
mock_nz_tensor = MagicMock()
111111
mock_model = MagicMock()
112112
mock_nd_to_nz_2d.return_value = mock_nz_tensor
113113
mock_npu_format_cast.return_value = mock_nz_tensor
114114

115-
self.builder.build(
116-
common_attn_metadata,
117-
mock_model
118-
)
115+
self.builder.build(common_attn_metadata, mock_model)
119116

120117
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
121118
@patch('torch_npu.npu_format_cast')
@@ -139,8 +136,7 @@ def test_build_chunked_prefill(self, mock_ascend_attention_state,
139136
positions=torch.tensor([10, 10]),
140137
attn_mask=torch.ones((15, 15)),
141138
spec_attn_mask=None,
142-
attn_state=AscendAttentionState.ChunkedPrefill
143-
)
139+
attn_state=AscendAttentionState.ChunkedPrefill)
144140

145141
mock_ascend_attention_state = MagicMock()
146142
mock_ascend_attention_state.PrefillNoCache = 0
@@ -169,8 +165,7 @@ def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
169165
positions=torch.tensor([10, 10]),
170166
attn_mask=torch.ones((15, 15)),
171167
spec_attn_mask=None,
172-
attn_state=AscendAttentionState.ChunkedPrefill
173-
)
168+
attn_state=AscendAttentionState.ChunkedPrefill)
174169
mock_model = MagicMock()
175170

176171
self.builder.build(common_attn_metadata, mock_model)

tests/ut/attention/test_mla_v1.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from unittest.mock import MagicMock, patch
22

3-
import numpy as np
43
import torch
54
from vllm.distributed.parallel_state import GroupCoordinator
65
from vllm.model_executor.layers.linear import LinearBase
@@ -195,9 +194,11 @@ def test_ascend_mla_metadata_builder_default(self):
195194
return_value=ascend_config):
196195
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
197196

198-
self.assertEqual(builder.block_size, mock_vllm_config.cache_config.block_size)
199-
self.assertEqual(builder.chunked_prefill_enabled,
200-
mock_vllm_config.scheduler_config.chunked_prefill_enabled)
197+
self.assertEqual(builder.block_size,
198+
mock_vllm_config.cache_config.block_size)
199+
self.assertEqual(
200+
builder.chunked_prefill_enabled,
201+
mock_vllm_config.scheduler_config.chunked_prefill_enabled)
201202
self.assertEqual(builder.torchair_graph_enabled, True)
202203

203204
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
@@ -363,7 +364,7 @@ def test_build_dummy(self, mock_ascend_config):
363364
num_reqs=3,
364365
num_actual_tokens=3,
365366
decode_token_per_req=1,
366-
actual_seq_lengths_q=[0,1,2],
367+
actual_seq_lengths_q=[0, 1, 2],
367368
attn_mask=torch.zeros((1, 1), dtype=torch.bool),
368369
spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool),
369370
)

vllm_ascend/attention/attention_v1_torchair.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,8 @@ def _get_graph_runner_block_tables(
173173
max_blocks = self.max_blocks
174174

175175
graph_block_tables = torch.zeros((num_seqs, max_blocks),
176-
dtype=block_tables.dtype,
177-
device=block_tables.device)
176+
dtype=block_tables.dtype,
177+
device=block_tables.device)
178178

179179
num_blocks = block_tables.size(1)
180180
if num_blocks <= max_blocks:

vllm_ascend/attention/mla_v1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,8 @@ def _get_graph_runner_block_tables(
282282
max_blocks = self.max_blocks
283283

284284
graph_block_tables = torch.zeros((num_seqs, max_blocks),
285-
dtype=block_tables.dtype,
286-
device=block_tables.device)
285+
dtype=block_tables.dtype,
286+
device=block_tables.device)
287287

288288
num_blocks = block_tables.size(1)
289289
if num_blocks <= max_blocks:

0 commit comments

Comments
 (0)