Skip to content

Commit da2d4da

Browse files
committed
refactor attn metadata build
Signed-off-by: weiguihua2 <[email protected]> refact attn metadata build Signed-off-by: weiguihua2 <[email protected]> refact attn metadata build Signed-off-by: weiguihua2 <[email protected]> refact attn metadata build Signed-off-by: weiguihua2 <[email protected]> refact attn metadata build Signed-off-by: weiguihua2 <[email protected]> refact attn metadata build Signed-off-by: weiguihua2 <[email protected]> refact attn metadata build Signed-off-by: weiguihua2 <[email protected]> refact attn metadata build Signed-off-by: weiguihua2 <[email protected]> refact attn metadata build Signed-off-by: weiguihua2 <[email protected]> refact attn metadata build Signed-off-by: weiguihua2 <[email protected]> refact attn metadata build Signed-off-by: weiguihua2 <[email protected]> refact attn metadata build Signed-off-by: weiguihua2 <[email protected]> refact attn metadata build Signed-off-by: weiguihua2 <[email protected]> refact attn metadata build Signed-off-by: weiguihua2 <[email protected]> refact model runner Signed-off-by: weiguihua2 <[email protected]> refact model runner Signed-off-by: weiguihua2 <[email protected]> refact model runner Signed-off-by: weiguihua2 <[email protected]> refact model runner Signed-off-by: weiguihua2 <[email protected]>
1 parent 9554116 commit da2d4da

File tree

13 files changed

+598
-390
lines changed

13 files changed

+598
-390
lines changed

tests/ut/attention/test_attention_v1.py

Lines changed: 59 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
AscendAttentionState,
1010
AscendMetadata,
1111
CommonAttentionState)
12+
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
1213

1314

1415
class TestAscendAttentionBackend(TestBase):
@@ -67,8 +68,12 @@ def test_copy_blocks(self):
6768
class TestAscendAttentionMetadataBuilder(TestBase):
6869

6970
def setUp(self):
70-
self.mock_runner = MagicMock()
71-
self.builder = AscendAttentionMetadataBuilder(self.mock_runner)
71+
self.mock_vllm_config = MagicMock()
72+
self.mock_vllm_config.model_config.max_model_len = 640
73+
self.mock_vllm_config.cache_config.block_size = 64
74+
self.mock_device = 'cpu:0'
75+
self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config,
76+
self.mock_device)
7277

7378
def test_reorder_batch(self):
7479
mock_input_batch = MagicMock()
@@ -86,31 +91,28 @@ def test_reorder_batch(self):
8691
def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
8792
mock_npu_format_cast,
8893
mock_ascend_metadata):
89-
num_reqs = 2
90-
num_actual_tokens = 10
91-
max_query_len = 5
92-
93-
self.mock_runner.input_batch.block_table = [MagicMock()]
94-
self.mock_runner.input_batch.block_table[
95-
0].get_device_tensor.return_value = torch.zeros((10, 10))
96-
self.mock_runner.max_num_blocks_per_req = 10
97-
self.mock_runner.query_lens = torch.tensor([3, 4])
98-
self.mock_runner.seq_lens_cpu = torch.tensor([5, 6])
99-
self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
100-
self.mock_runner.device = 'cpu:0'
101-
self.mock_runner.attn_mask = torch.ones((10, 10))
102-
self.mock_runner.attn_state = AscendAttentionState.PrefillNoCache
103-
self.mock_runner.query_start_loc_cpu = torch.tensor([0, 3, 7])
94+
common_attn_metadata = AscendCommonAttentionMetadata(
95+
query_start_loc=torch.tensor([0, 3, 7]),
96+
query_start_loc_cpu=torch.tensor([0, 3, 7]),
97+
seq_lens_cpu=torch.tensor([5, 6]),
98+
num_reqs=2,
99+
num_actual_tokens=10,
100+
max_query_len=5,
101+
decode_token_per_req=torch.tensor([1, 1]),
102+
block_table_tensor=torch.zeros((10, 10)),
103+
slot_mapping_cpu=torch.tensor(range(20)),
104+
actual_seq_lengths_q=torch.tensor([0, 1]),
105+
positions=torch.tensor([10, 10]),
106+
attn_mask=torch.ones((10, 10)),
107+
spec_attn_mask=None,
108+
attn_state=AscendAttentionState.PrefillNoCache)
104109

105110
mock_nz_tensor = MagicMock()
111+
mock_model = MagicMock()
106112
mock_nd_to_nz_2d.return_value = mock_nz_tensor
107113
mock_npu_format_cast.return_value = mock_nz_tensor
108114

109-
self.builder.build(
110-
num_reqs,
111-
num_actual_tokens,
112-
max_query_len,
113-
)
115+
self.builder.build(common_attn_metadata, mock_model)
114116

115117
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
116118
@patch('torch_npu.npu_format_cast')
@@ -120,51 +122,53 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
120122
def test_build_chunked_prefill(self, mock_ascend_attention_state,
121123
mock_is_310p, mock_nd_to_nz_spec,
122124
mock_npu_format_cast, mock_ascend_metadata):
123-
num_reqs = 3
124-
num_actual_tokens = 15
125-
max_query_len = 6
126-
127-
self.mock_runner.input_batch.block_table = [MagicMock()]
128-
self.mock_runner.input_batch.block_table[
129-
0].get_device_tensor.return_value = torch.zeros((10, 10))
130-
self.mock_runner.max_num_blocks_per_req = 10
131-
self.mock_runner.query_lens = torch.tensor([2, 3, 4])
132-
self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6])
133-
self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
134-
self.mock_runner.device = 'cpu:0'
135-
self.mock_runner.attn_mask = torch.ones((15, 15))
136-
self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
137-
self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])
125+
common_attn_metadata = AscendCommonAttentionMetadata(
126+
query_start_loc=torch.tensor([0, 2, 5, 9]),
127+
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
128+
seq_lens_cpu=torch.tensor([4, 5, 6]),
129+
num_reqs=3,
130+
num_actual_tokens=15,
131+
max_query_len=6,
132+
decode_token_per_req=torch.tensor([1, 1, 1]),
133+
block_table_tensor=torch.zeros((10, 10)),
134+
slot_mapping_cpu=torch.tensor(range(20)),
135+
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
136+
positions=torch.tensor([10, 10]),
137+
attn_mask=torch.ones((15, 15)),
138+
spec_attn_mask=None,
139+
attn_state=AscendAttentionState.ChunkedPrefill)
138140

139141
mock_ascend_attention_state = MagicMock()
140142
mock_ascend_attention_state.PrefillNoCache = 0
141143

142144
mock_nz_tensor = MagicMock()
145+
mock_model = MagicMock()
143146
mock_nd_to_nz_spec.return_value = mock_nz_tensor
144147
mock_npu_format_cast.return_value = mock_nz_tensor
145148

146-
self.builder.build(num_reqs, num_actual_tokens, max_query_len)
149+
self.builder.build(common_attn_metadata, mock_model)
147150

148151
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
149152
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
150153
def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
151-
num_reqs = 3
152-
num_actual_tokens = 15
153-
max_query_len = 6
154-
155-
self.mock_runner.input_batch.block_table = [MagicMock()]
156-
self.mock_runner.input_batch.block_table[
157-
0].get_device_tensor.return_value = torch.zeros((10, 10))
158-
self.mock_runner.max_num_blocks_per_req = 10
159-
self.mock_runner.query_lens = torch.tensor([2, 3, 4])
160-
self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6])
161-
self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
162-
self.mock_runner.device = 'cpu:0'
163-
self.mock_runner.attn_mask = torch.ones((15, 15))
164-
self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
165-
self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])
166-
167-
self.builder.build(num_reqs, num_actual_tokens, max_query_len)
154+
common_attn_metadata = AscendCommonAttentionMetadata(
155+
query_start_loc=torch.tensor([0, 2, 5, 9]),
156+
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
157+
seq_lens_cpu=torch.tensor([4, 5, 6]),
158+
num_reqs=3,
159+
num_actual_tokens=15,
160+
max_query_len=6,
161+
decode_token_per_req=torch.tensor([1, 1, 1]),
162+
block_table_tensor=torch.zeros((10, 10)),
163+
slot_mapping_cpu=torch.tensor(range(20)),
164+
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
165+
positions=torch.tensor([10, 10]),
166+
attn_mask=torch.ones((15, 15)),
167+
spec_attn_mask=None,
168+
attn_state=AscendAttentionState.ChunkedPrefill)
169+
mock_model = MagicMock()
170+
171+
self.builder.build(common_attn_metadata, mock_model)
168172

169173

170174
class TestAscendAttentionBackendImpl(TestBase):

tests/ut/attention/test_mla_v1.py

Lines changed: 76 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
from unittest.mock import MagicMock, patch
22

3-
import numpy as np
43
import torch
54
from vllm.distributed.parallel_state import GroupCoordinator
65
from vllm.model_executor.layers.linear import LinearBase
@@ -12,6 +11,7 @@
1211
AscendMLAImpl, AscendMLAMetadata,
1312
AscendMLAMetadataBuilder,
1413
AscendMLAPrefillMetadata)
14+
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
1515

1616

1717
class TestAscendMLABackend(TestBase):
@@ -178,40 +178,41 @@ def test_ascend_mla_metadata_default(self):
178178
class TestAscendMLAMetadataBuilder(TestBase):
179179

180180
def test_ascend_mla_metadata_builder_default(self):
181-
runner = MagicMock()
182-
runner.scheduler_config = MagicMock()
183-
runner.model_config = MagicMock()
184-
runner.scheduler_config.max_num_seqs = 4
185-
runner.model_config.max_model_len = 1024
186-
runner.model_config.get_head_size.return_value = 64
187-
runner.model_config.dtype = torch.float16
188-
runner.chunked_prefill_enabled = False
189-
runner.device = "cpu"
190-
runner.block_size = 16
191-
runner.decode_token_per_req = 1
181+
mock_vllm_config = MagicMock()
182+
mock_vllm_config.model_config.max_model_len = 1024
183+
mock_vllm_config.model_config.get_head_size.return_value = 64
184+
mock_vllm_config.model_config.dtype = torch.float16
185+
mock_vllm_config.cache_config.block_size = 16
186+
mock_vllm_config.scheduler_config.max_num_seqs = 4
187+
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
188+
mock_device = 'cpu'
192189

193190
ascend_config = MagicMock()
194191
ascend_config.torchair_graph_config = MagicMock()
195192
ascend_config.torchair_graph_config.enabled = True
196193
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
197194
return_value=ascend_config):
198-
builder = AscendMLAMetadataBuilder(runner)
195+
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
199196

200-
self.assertEqual(builder.runner, runner)
201-
self.assertEqual(builder.block_size, runner.block_size)
202-
self.assertEqual(builder.chunked_prefill_enabled,
203-
runner.chunked_prefill_enabled)
197+
self.assertEqual(builder.block_size,
198+
mock_vllm_config.cache_config.block_size)
199+
self.assertEqual(
200+
builder.chunked_prefill_enabled,
201+
mock_vllm_config.scheduler_config.chunked_prefill_enabled)
204202
self.assertEqual(builder.torchair_graph_enabled, True)
205203

206204
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
207205
def test_reorder_batch_with_torchair_graph(self, ascend_config):
208-
runner = MagicMock()
209-
runner.chunked_prefill_enabled = False
210-
runner.decode_token_per_req = 1
206+
mock_vllm_config = MagicMock()
207+
mock_vllm_config.model_config.max_model_len = 1024
208+
mock_vllm_config.cache_config.block_size = 16
209+
mock_vllm_config.scheduler_config.max_num_seqs = 4
210+
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
211+
mock_device = 'cpu'
211212
ascend_config.torchair_graph_config = MagicMock()
212213
ascend_config.torchair_graph_config.enabled = True
213214

214-
builder = AscendMLAMetadataBuilder(runner)
215+
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
215216

216217
input_batch = MagicMock()
217218
input_batch.req_ids = [0, 1, 2, 3]
@@ -230,22 +231,23 @@ def test_reorder_batch_with_torchair_graph(self, ascend_config):
230231
modified = builder.reorder_batch(input_batch, scheduler_output)
231232

232233
self.assertFalse(modified)
233-
self.assertEqual(builder._num_decodes, 4)
234-
self.assertEqual(builder._num_prefills, 0)
235-
self.assertEqual(builder._num_decode_tokens, 7)
236-
self.assertEqual(builder._num_prefill_tokens, 0)
237234
input_batch.swap_states.assert_not_called()
238235

239236
def test_reorder_batch_without_torchair_graph(self):
240237
ascend_config = MagicMock()
241-
runner = MagicMock()
242-
runner.chunked_prefill_enabled = False
243-
runner.decode_token_per_req = 1
244238
ascend_config.torchair_graph_config = MagicMock()
245239
ascend_config.torchair_graph_config.enabled = False
240+
241+
mock_vllm_config = MagicMock()
242+
mock_vllm_config.model_config.max_model_len = 1024
243+
mock_vllm_config.cache_config.block_size = 16
244+
mock_vllm_config.scheduler_config.max_num_seqs = 4
245+
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
246+
mock_device = 'cpu'
247+
246248
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
247249
return_value=ascend_config):
248-
builder = AscendMLAMetadataBuilder(runner)
250+
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
249251

250252
input_batch = MagicMock()
251253
input_batch.req_ids = [0, 1, 2, 3]
@@ -264,22 +266,20 @@ def test_reorder_batch_without_torchair_graph(self):
264266
modified = builder.reorder_batch(input_batch, scheduler_output)
265267

266268
self.assertTrue(modified)
267-
self.assertEqual(builder._num_decodes, 2)
268-
self.assertEqual(builder._num_prefills, 2)
269-
self.assertEqual(builder._num_decode_tokens, 2)
270-
self.assertEqual(builder._num_prefill_tokens, 5)
271269
input_batch.swap_states.assert_called_once_with(1, 2)
272270

273271
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
274272
def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
275273
ascend_config = MagicMock()
276274
mock_ascend_config.return_value = ascend_config
277275
ascend_config.torchair_graph_config.enabled = False
278-
runner = MagicMock()
279-
runner.graph_block_tables = torch.zeros((8, 64), dtype=torch.int32)
280-
runner.chunked_prefill_enabled = False
281-
runner.decode_token_per_req = 1
282-
builder = AscendMLAMetadataBuilder(runner=runner)
276+
mock_vllm_config = MagicMock()
277+
mock_vllm_config.model_config.max_model_len = 1024
278+
mock_vllm_config.cache_config.block_size = 16
279+
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
280+
mock_device = 'cpu'
281+
282+
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
283283
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
284284

285285
result = builder._get_graph_runner_block_tables(3, block_tables)
@@ -292,11 +292,13 @@ def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
292292
ascend_config = MagicMock()
293293
mock_ascend_config.return_value = ascend_config
294294
ascend_config.torchair_graph_config.enabled = False
295-
runner = MagicMock()
296-
runner.graph_block_tables = torch.zeros((8, 4), dtype=torch.int32)
297-
runner.chunked_prefill_enabled = False
298-
runner.decode_token_per_req = 1
299-
builder = AscendMLAMetadataBuilder(runner=runner)
295+
mock_vllm_config = MagicMock()
296+
mock_vllm_config.model_config.max_model_len = 64
297+
mock_vllm_config.cache_config.block_size = 16
298+
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
299+
mock_device = 'cpu'
300+
301+
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
300302
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
301303

302304
result = builder._get_graph_runner_block_tables(3, block_tables)
@@ -310,11 +312,13 @@ def test_get_graph_runner_block_tables_from_numpy(self,
310312
ascend_config = MagicMock()
311313
mock_ascend_config.return_value = ascend_config
312314
ascend_config.torchair_graph_config.enabled = False
313-
runner = MagicMock()
314-
runner.graph_block_tables = np.zeros((8, 64), dtype=np.int32)
315-
runner.chunked_prefill_enabled = False
316-
runner.decode_token_per_req = 1
317-
builder = AscendMLAMetadataBuilder(runner=runner)
315+
mock_vllm_config = MagicMock()
316+
mock_vllm_config.model_config.max_model_len = 1024
317+
mock_vllm_config.cache_config.block_size = 16
318+
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
319+
mock_device = 'cpu'
320+
321+
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
318322

319323
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
320324

@@ -329,38 +333,45 @@ def test_build_dummy(self, mock_ascend_config):
329333
ascend_config = MagicMock()
330334
mock_ascend_config.return_value = ascend_config
331335
ascend_config.torchair_graph_config.enabled = False
332-
runner = MagicMock()
333-
runner.model_config = MagicMock()
334-
runner.device = "cpu"
335-
runner.graph_block_tables = torch.zeros((8, 64), dtype=torch.int32)
336-
runner.model_config.get_head_size.return_value = 64
337-
runner.chunked_prefill_enabled = False
338-
runner.attn_mask = torch.zeros((1, 1), dtype=torch.bool)
339-
runner.spec_attn_mask = torch.zeros((1, 1), dtype=torch.bool)
340-
runner.dtype = torch.float16
341-
runner.decode_token_per_req = 1
342-
343-
builder = AscendMLAMetadataBuilder(runner=runner,
336+
337+
mock_vllm_config = MagicMock()
338+
mock_vllm_config.model_config.max_model_len = 1024
339+
mock_vllm_config.cache_config.block_size = 16
340+
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
341+
mock_vllm_config.get_head_size.return_value = 64
342+
mock_vllm_config.model_config.dtype = torch.float16
343+
mock_device = 'cpu'
344+
345+
builder = AscendMLAMetadataBuilder(mock_vllm_config,
346+
mock_device,
344347
metadata_cls=AscendMLAMetadata)
345348
builder.rope_dim = 64
346349

347350
with patch.object(builder,
348351
"_get_graph_runner_block_tables",
349352
side_effect=lambda x, y: y):
350-
metadata = builder.build_torchair_graph_dummy(3, 3)
353+
common_attn_metadata = TorchairCommonAttentionMetadata(
354+
num_reqs=3,
355+
num_actual_tokens=3,
356+
decode_token_per_req=1,
357+
actual_seq_lengths_q=[0, 1, 2],
358+
attn_mask=torch.zeros((1, 1), dtype=torch.bool),
359+
spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool),
360+
)
361+
metadata = builder.build_torchair_graph_dummy(common_attn_metadata)
351362

352363
sin_golden = torch.ones(3,
353364
1,
354365
1,
355366
64,
356-
dtype=runner.dtype,
357-
device=runner.device)
367+
dtype=torch.float16,
368+
device=mock_device)
358369
cos_golden = torch.ones(3,
359370
1,
360371
1,
361372
64,
362-
dtype=runner.dtype,
363-
device=runner.device)
373+
dtype=torch.float16,
374+
device=mock_device)
364375

365376
self.assertIsInstance(metadata, AscendMLAMetadata)
366377
self.assertEqual(metadata.num_input_tokens, 3)

0 commit comments

Comments
 (0)