from unittest.mock import MagicMock, patch

-import numpy as np
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase
                                          AscendMLAImpl, AscendMLAMetadata,
                                          AscendMLAMetadataBuilder,
                                          AscendMLAPrefillMetadata)
+from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata


class TestAscendMLABackend(TestBase):
@@ -178,40 +178,41 @@ def test_ascend_mla_metadata_default(self):
class TestAscendMLAMetadataBuilder(TestBase):

    def test_ascend_mla_metadata_builder_default(self):
-        runner = MagicMock()
-        runner.scheduler_config = MagicMock()
-        runner.model_config = MagicMock()
-        runner.scheduler_config.max_num_seqs = 4
-        runner.model_config.max_model_len = 1024
-        runner.model_config.get_head_size.return_value = 64
-        runner.model_config.dtype = torch.float16
-        runner.chunked_prefill_enabled = False
-        runner.device = "cpu"
-        runner.block_size = 16
-        runner.decode_token_per_req = 1
+        mock_vllm_config = MagicMock()
+        mock_vllm_config.model_config.max_model_len = 1024
+        mock_vllm_config.model_config.get_head_size.return_value = 64
+        mock_vllm_config.model_config.dtype = torch.float16
+        mock_vllm_config.cache_config.block_size = 16
+        mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
+        mock_device = 'cpu'

        ascend_config = MagicMock()
        ascend_config.torchair_graph_config = MagicMock()
        ascend_config.torchair_graph_config.enabled = True
        with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
                   return_value=ascend_config):
-            builder = AscendMLAMetadataBuilder(runner)
+            builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)

-        self.assertEqual(builder.runner, runner)
-        self.assertEqual(builder.block_size, runner.block_size)
-        self.assertEqual(builder.chunked_prefill_enabled,
-                         runner.chunked_prefill_enabled)
+        self.assertEqual(builder.block_size,
+                         mock_vllm_config.cache_config.block_size)
+        self.assertEqual(
+            builder.chunked_prefill_enabled,
+            mock_vllm_config.scheduler_config.chunked_prefill_enabled)
        self.assertEqual(builder.torchair_graph_enabled, True)

    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
    def test_reorder_batch_with_torchair_graph(self, ascend_config):
-        runner = MagicMock()
-        runner.chunked_prefill_enabled = False
-        runner.decode_token_per_req = 1
+        mock_vllm_config = MagicMock()
+        mock_vllm_config.model_config.max_model_len = 1024
+        mock_vllm_config.cache_config.block_size = 16
+        mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
+        mock_device = 'cpu'
        ascend_config.torchair_graph_config = MagicMock()
        ascend_config.torchair_graph_config.enabled = True

-        builder = AscendMLAMetadataBuilder(runner)
+        builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)

        input_batch = MagicMock()
        input_batch.req_ids = [0, 1, 2, 3]
@@ -230,22 +231,23 @@ def test_reorder_batch_with_torchair_graph(self, ascend_config):
        modified = builder.reorder_batch(input_batch, scheduler_output)

        self.assertFalse(modified)
-        self.assertEqual(builder._num_decodes, 4)
-        self.assertEqual(builder._num_prefills, 0)
-        self.assertEqual(builder._num_decode_tokens, 7)
-        self.assertEqual(builder._num_prefill_tokens, 0)
        input_batch.swap_states.assert_not_called()

    def test_reorder_batch_without_torchair_graph(self):
        ascend_config = MagicMock()
-        runner = MagicMock()
-        runner.chunked_prefill_enabled = False
-        runner.decode_token_per_req = 1
        ascend_config.torchair_graph_config = MagicMock()
        ascend_config.torchair_graph_config.enabled = False
+
+        mock_vllm_config = MagicMock()
+        mock_vllm_config.model_config.max_model_len = 1024
+        mock_vllm_config.cache_config.block_size = 16
+        mock_vllm_config.scheduler_config.max_num_seqs = 4
+        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
+        mock_device = 'cpu'
+
        with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
                   return_value=ascend_config):
-            builder = AscendMLAMetadataBuilder(runner)
+            builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)

        input_batch = MagicMock()
        input_batch.req_ids = [0, 1, 2, 3]
@@ -264,22 +266,20 @@ def test_reorder_batch_without_torchair_graph(self):
        modified = builder.reorder_batch(input_batch, scheduler_output)

        self.assertTrue(modified)
-        self.assertEqual(builder._num_decodes, 2)
-        self.assertEqual(builder._num_prefills, 2)
-        self.assertEqual(builder._num_decode_tokens, 2)
-        self.assertEqual(builder._num_prefill_tokens, 5)
        input_batch.swap_states.assert_called_once_with(1, 2)

    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
    def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
        ascend_config = MagicMock()
        mock_ascend_config.return_value = ascend_config
        ascend_config.torchair_graph_config.enabled = False
-        runner = MagicMock()
-        runner.graph_block_tables = torch.zeros((8, 64), dtype=torch.int32)
-        runner.chunked_prefill_enabled = False
-        runner.decode_token_per_req = 1
-        builder = AscendMLAMetadataBuilder(runner=runner)
+        mock_vllm_config = MagicMock()
+        mock_vllm_config.model_config.max_model_len = 1024
+        mock_vllm_config.cache_config.block_size = 16
+        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
+        mock_device = 'cpu'
+
+        builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
        block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

        result = builder._get_graph_runner_block_tables(3, block_tables)
@@ -292,11 +292,13 @@ def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
        ascend_config = MagicMock()
        mock_ascend_config.return_value = ascend_config
        ascend_config.torchair_graph_config.enabled = False
-        runner = MagicMock()
-        runner.graph_block_tables = torch.zeros((8, 4), dtype=torch.int32)
-        runner.chunked_prefill_enabled = False
-        runner.decode_token_per_req = 1
-        builder = AscendMLAMetadataBuilder(runner=runner)
+        mock_vllm_config = MagicMock()
+        mock_vllm_config.model_config.max_model_len = 64
+        mock_vllm_config.cache_config.block_size = 16
+        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
+        mock_device = 'cpu'
+
+        builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
        block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

        result = builder._get_graph_runner_block_tables(3, block_tables)
@@ -310,11 +312,13 @@ def test_get_graph_runner_block_tables_from_numpy(self,
        ascend_config = MagicMock()
        mock_ascend_config.return_value = ascend_config
        ascend_config.torchair_graph_config.enabled = False
-        runner = MagicMock()
-        runner.graph_block_tables = np.zeros((8, 64), dtype=np.int32)
-        runner.chunked_prefill_enabled = False
-        runner.decode_token_per_req = 1
-        builder = AscendMLAMetadataBuilder(runner=runner)
+        mock_vllm_config = MagicMock()
+        mock_vllm_config.model_config.max_model_len = 1024
+        mock_vllm_config.cache_config.block_size = 16
+        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
+        mock_device = 'cpu'
+
+        builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)

        block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

@@ -329,38 +333,45 @@ def test_build_dummy(self, mock_ascend_config):
        ascend_config = MagicMock()
        mock_ascend_config.return_value = ascend_config
        ascend_config.torchair_graph_config.enabled = False
-        runner = MagicMock()
-        runner.model_config = MagicMock()
-        runner.device = "cpu"
-        runner.graph_block_tables = torch.zeros((8, 64), dtype=torch.int32)
-        runner.model_config.get_head_size.return_value = 64
-        runner.chunked_prefill_enabled = False
-        runner.attn_mask = torch.zeros((1, 1), dtype=torch.bool)
-        runner.spec_attn_mask = torch.zeros((1, 1), dtype=torch.bool)
-        runner.dtype = torch.float16
-        runner.decode_token_per_req = 1
-
-        builder = AscendMLAMetadataBuilder(runner=runner,
+
+        mock_vllm_config = MagicMock()
+        mock_vllm_config.model_config.max_model_len = 1024
+        mock_vllm_config.cache_config.block_size = 16
+        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
+        mock_vllm_config.model_config.get_head_size.return_value = 64
+        mock_vllm_config.model_config.dtype = torch.float16
+        mock_device = 'cpu'
+
+        builder = AscendMLAMetadataBuilder(mock_vllm_config,
+                                           mock_device,
                                           metadata_cls=AscendMLAMetadata)
        builder.rope_dim = 64

        with patch.object(builder,
                          "_get_graph_runner_block_tables",
                          side_effect=lambda x, y: y):
-            metadata = builder.build_torchair_graph_dummy(3, 3)
+            common_attn_metadata = TorchairCommonAttentionMetadata(
+                num_reqs=3,
+                num_actual_tokens=3,
+                decode_token_per_req=1,
+                actual_seq_lengths_q=[0, 1, 2],
+                attn_mask=torch.zeros((1, 1), dtype=torch.bool),
+                spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool),
+            )
+            metadata = builder.build_torchair_graph_dummy(common_attn_metadata)

        sin_golden = torch.ones(3,
                                1,
                                1,
                                64,
-                                dtype=runner.dtype,
-                                device=runner.device)
+                                dtype=torch.float16,
+                                device=mock_device)
        cos_golden = torch.ones(3,
                                1,
                                1,
                                64,
-                                dtype=runner.dtype,
-                                device=runner.device)
+                                dtype=torch.float16,
+                                device=mock_device)

        self.assertIsInstance(metadata, AscendMLAMetadata)
        self.assertEqual(metadata.num_input_tokens, 3)
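Taken together, the diff migrates AscendMLAMetadataBuilder from a runner-backed constructor to an explicit (vllm_config, device) pair, with the per-step state that previously lived on the runner now passed in through TorchairCommonAttentionMetadata. A minimal sketch of the new call pattern follows; it mirrors the MagicMock stand-ins and the get_ascend_config patch used by the tests above (test scaffolding, not production wiring), and the field values are illustrative only:

```python
from unittest.mock import MagicMock, patch

import torch

from vllm_ascend.attention.mla_v1 import (AscendMLAMetadata,
                                          AscendMLAMetadataBuilder)
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata

# Stub the vLLM config tree, exactly as the tests above do.
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False

ascend_config = MagicMock()
ascend_config.torchair_graph_config.enabled = False

with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
           return_value=ascend_config):
    # New signature: (vllm_config, device) replaces the old runner argument.
    builder = AscendMLAMetadataBuilder(mock_vllm_config, 'cpu',
                                       metadata_cls=AscendMLAMetadata)
builder.rope_dim = 64  # set explicitly, as in test_build_dummy

# Per-request state is now carried by an explicit metadata dataclass
# instead of being read off the model runner.
common_attn_metadata = TorchairCommonAttentionMetadata(
    num_reqs=3,
    num_actual_tokens=3,
    decode_token_per_req=1,
    actual_seq_lengths_q=[0, 1, 2],
    attn_mask=torch.zeros((1, 1), dtype=torch.bool),
    spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool),
)
metadata = builder.build_torchair_graph_dummy(common_attn_metadata)
```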