114 changes: 59 additions & 55 deletions tests/ut/attention/test_attention_v1.py
@@ -9,6 +9,7 @@
AscendAttentionState,
AscendMetadata,
CommonAttentionState)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata


class TestAscendAttentionBackend(TestBase):
@@ -67,8 +68,12 @@ def test_copy_blocks(self):
class TestAscendAttentionMetadataBuilder(TestBase):

def setUp(self):
self.mock_runner = MagicMock()
self.builder = AscendAttentionMetadataBuilder(self.mock_runner)
self.mock_vllm_config = MagicMock()
self.mock_vllm_config.model_config.max_model_len = 640
self.mock_vllm_config.cache_config.block_size = 64
self.mock_device = 'cpu:0'
self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config,
self.mock_device)

def test_reorder_batch(self):
mock_input_batch = MagicMock()
@@ -86,31 +91,28 @@ def test_reorder_batch(self):
def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
mock_npu_format_cast,
mock_ascend_metadata):
num_reqs = 2
num_actual_tokens = 10
max_query_len = 5

self.mock_runner.input_batch.block_table = [MagicMock()]
self.mock_runner.input_batch.block_table[
0].get_device_tensor.return_value = torch.zeros((10, 10))
self.mock_runner.max_num_blocks_per_req = 10
self.mock_runner.query_lens = torch.tensor([3, 4])
self.mock_runner.seq_lens_cpu = torch.tensor([5, 6])
self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
self.mock_runner.device = 'cpu:0'
self.mock_runner.attn_mask = torch.ones((10, 10))
self.mock_runner.attn_state = AscendAttentionState.PrefillNoCache
self.mock_runner.query_start_loc_cpu = torch.tensor([0, 3, 7])
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 3, 7]),
query_start_loc_cpu=torch.tensor([0, 3, 7]),
seq_lens_cpu=torch.tensor([5, 6]),
num_reqs=2,
num_actual_tokens=10,
max_query_len=5,
decode_token_per_req=torch.tensor([1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((10, 10)),
spec_attn_mask=None,
attn_state=AscendAttentionState.PrefillNoCache)

mock_nz_tensor = MagicMock()
mock_model = MagicMock()
mock_nd_to_nz_2d.return_value = mock_nz_tensor
mock_npu_format_cast.return_value = mock_nz_tensor

self.builder.build(
num_reqs,
num_actual_tokens,
max_query_len,
)
self.builder.build(common_attn_metadata, mock_model)

@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@patch('torch_npu.npu_format_cast')
@@ -120,51 +122,53 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
def test_build_chunked_prefill(self, mock_ascend_attention_state,
mock_is_310p, mock_nd_to_nz_spec,
mock_npu_format_cast, mock_ascend_metadata):
num_reqs = 3
num_actual_tokens = 15
max_query_len = 6

self.mock_runner.input_batch.block_table = [MagicMock()]
self.mock_runner.input_batch.block_table[
0].get_device_tensor.return_value = torch.zeros((10, 10))
self.mock_runner.max_num_blocks_per_req = 10
self.mock_runner.query_lens = torch.tensor([2, 3, 4])
self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6])
self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
self.mock_runner.device = 'cpu:0'
self.mock_runner.attn_mask = torch.ones((15, 15))
self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 2, 5, 9]),
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
seq_lens_cpu=torch.tensor([4, 5, 6]),
num_reqs=3,
num_actual_tokens=15,
max_query_len=6,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill)

mock_ascend_attention_state = MagicMock()
mock_ascend_attention_state.PrefillNoCache = 0

mock_nz_tensor = MagicMock()
mock_model = MagicMock()
mock_nd_to_nz_spec.return_value = mock_nz_tensor
mock_npu_format_cast.return_value = mock_nz_tensor

self.builder.build(num_reqs, num_actual_tokens, max_query_len)
self.builder.build(common_attn_metadata, mock_model)

@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
num_reqs = 3
num_actual_tokens = 15
max_query_len = 6

self.mock_runner.input_batch.block_table = [MagicMock()]
self.mock_runner.input_batch.block_table[
0].get_device_tensor.return_value = torch.zeros((10, 10))
self.mock_runner.max_num_blocks_per_req = 10
self.mock_runner.query_lens = torch.tensor([2, 3, 4])
self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6])
self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
self.mock_runner.device = 'cpu:0'
self.mock_runner.attn_mask = torch.ones((15, 15))
self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])

self.builder.build(num_reqs, num_actual_tokens, max_query_len)
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 2, 5, 9]),
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
seq_lens_cpu=torch.tensor([4, 5, 6]),
num_reqs=3,
num_actual_tokens=15,
max_query_len=6,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill)
mock_model = MagicMock()

self.builder.build(common_attn_metadata, mock_model)


class TestAscendAttentionBackendImpl(TestBase):
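Taken together, these hunks exercise the refactored v1 attention builder API: AscendAttentionMetadataBuilder is now constructed from a vllm_config plus a device instead of a mocked model runner, and build() consumes a single AscendCommonAttentionMetadata bundle plus the model rather than loose num_reqs/num_actual_tokens/max_query_len arguments. A minimal sketch of the new call pattern, mirroring the mocks above; the field set of AscendCommonAttentionMetadata is assumed from this diff, and the real tests also patch is_310p and AscendMetadata around build():

# Sketch only: constructor and field names come from the hunks above; the
# import paths follow the test modules and may differ in other versions.
from unittest.mock import MagicMock

import torch

from vllm_ascend.attention.attention_v1 import (AscendAttentionMetadataBuilder,
                                                AscendAttentionState)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata

vllm_config = MagicMock()
vllm_config.model_config.max_model_len = 640
vllm_config.cache_config.block_size = 64
builder = AscendAttentionMetadataBuilder(vllm_config, 'cpu:0')

common_attn_metadata = AscendCommonAttentionMetadata(
    query_start_loc=torch.tensor([0, 3, 7]),
    query_start_loc_cpu=torch.tensor([0, 3, 7]),
    seq_lens_cpu=torch.tensor([5, 6]),
    num_reqs=2,
    num_actual_tokens=10,
    max_query_len=5,
    decode_token_per_req=torch.tensor([1, 1]),
    block_table_tensor=torch.zeros((10, 10)),
    slot_mapping_cpu=torch.tensor(range(20)),
    actual_seq_lengths_q=torch.tensor([0, 1]),
    positions=torch.tensor([10, 10]),
    attn_mask=torch.ones((10, 10)),
    spec_attn_mask=None,
    attn_state=AscendAttentionState.PrefillNoCache)

# The model argument is opaque to the builder in these tests, so a mock
# suffices; outside a patched test environment this path needs torch_npu.
metadata = builder.build(common_attn_metadata, MagicMock())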
141 changes: 76 additions & 65 deletions tests/ut/attention/test_mla_v1.py
@@ -1,6 +1,5 @@
from unittest.mock import MagicMock, patch

import numpy as np
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase
@@ -12,6 +11,7 @@
AscendMLAImpl, AscendMLAMetadata,
AscendMLAMetadataBuilder,
AscendMLAPrefillMetadata)
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata


class TestAscendMLABackend(TestBase):
@@ -178,40 +178,41 @@ def test_ascend_mla_metadata_default(self):
class TestAscendMLAMetadataBuilder(TestBase):

def test_ascend_mla_metadata_builder_default(self):
runner = MagicMock()
runner.scheduler_config = MagicMock()
runner.model_config = MagicMock()
runner.scheduler_config.max_num_seqs = 4
runner.model_config.max_model_len = 1024
runner.model_config.get_head_size.return_value = 64
runner.model_config.dtype = torch.float16
runner.chunked_prefill_enabled = False
runner.device = "cpu"
runner.block_size = 16
runner.decode_token_per_req = 1
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'

ascend_config = MagicMock()
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = True
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
return_value=ascend_config):
builder = AscendMLAMetadataBuilder(runner)
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)

self.assertEqual(builder.runner, runner)
self.assertEqual(builder.block_size, runner.block_size)
self.assertEqual(builder.chunked_prefill_enabled,
runner.chunked_prefill_enabled)
self.assertEqual(builder.block_size,
mock_vllm_config.cache_config.block_size)
self.assertEqual(
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.chunked_prefill_enabled)
self.assertEqual(builder.torchair_graph_enabled, True)

@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_reorder_batch_with_torchair_graph(self, ascend_config):
runner = MagicMock()
runner.chunked_prefill_enabled = False
runner.decode_token_per_req = 1
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = True

builder = AscendMLAMetadataBuilder(runner)
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)

input_batch = MagicMock()
input_batch.req_ids = [0, 1, 2, 3]
@@ -230,22 +231,23 @@ def test_reorder_batch_with_torchair_graph(self, ascend_config):
modified = builder.reorder_batch(input_batch, scheduler_output)

self.assertFalse(modified)
self.assertEqual(builder._num_decodes, 4)
self.assertEqual(builder._num_prefills, 0)
self.assertEqual(builder._num_decode_tokens, 7)
self.assertEqual(builder._num_prefill_tokens, 0)
input_batch.swap_states.assert_not_called()

def test_reorder_batch_without_torchair_graph(self):
ascend_config = MagicMock()
runner = MagicMock()
runner.chunked_prefill_enabled = False
runner.decode_token_per_req = 1
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = False

mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'

with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
return_value=ascend_config):
builder = AscendMLAMetadataBuilder(runner)
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)

input_batch = MagicMock()
input_batch.req_ids = [0, 1, 2, 3]
@@ -264,22 +266,20 @@ def test_reorder_batch_without_torchair_graph(self):
modified = builder.reorder_batch(input_batch, scheduler_output)

self.assertTrue(modified)
self.assertEqual(builder._num_decodes, 2)
self.assertEqual(builder._num_prefills, 2)
self.assertEqual(builder._num_decode_tokens, 2)
self.assertEqual(builder._num_prefill_tokens, 5)
input_batch.swap_states.assert_called_once_with(1, 2)

@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
runner = MagicMock()
runner.graph_block_tables = torch.zeros((8, 64), dtype=torch.int32)
runner.chunked_prefill_enabled = False
runner.decode_token_per_req = 1
builder = AscendMLAMetadataBuilder(runner=runner)
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'

builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

result = builder._get_graph_runner_block_tables(3, block_tables)
@@ -292,11 +292,13 @@ def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
runner = MagicMock()
runner.graph_block_tables = torch.zeros((8, 4), dtype=torch.int32)
runner.chunked_prefill_enabled = False
runner.decode_token_per_req = 1
builder = AscendMLAMetadataBuilder(runner=runner)
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 64
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'

builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

result = builder._get_graph_runner_block_tables(3, block_tables)
@@ -310,11 +312,13 @@ def test_get_graph_runner_block_tables_from_numpy(self,
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
runner = MagicMock()
runner.graph_block_tables = np.zeros((8, 64), dtype=np.int32)
runner.chunked_prefill_enabled = False
runner.decode_token_per_req = 1
builder = AscendMLAMetadataBuilder(runner=runner)
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'

builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)

block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

@@ -329,38 +333,45 @@ def test_build_dummy(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
runner = MagicMock()
runner.model_config = MagicMock()
runner.device = "cpu"
runner.graph_block_tables = torch.zeros((8, 64), dtype=torch.int32)
runner.model_config.get_head_size.return_value = 64
runner.chunked_prefill_enabled = False
runner.attn_mask = torch.zeros((1, 1), dtype=torch.bool)
runner.spec_attn_mask = torch.zeros((1, 1), dtype=torch.bool)
runner.dtype = torch.float16
runner.decode_token_per_req = 1

builder = AscendMLAMetadataBuilder(runner=runner,

mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_device = 'cpu'

builder = AscendMLAMetadataBuilder(mock_vllm_config,
mock_device,
metadata_cls=AscendMLAMetadata)
builder.rope_dim = 64

with patch.object(builder,
"_get_graph_runner_block_tables",
side_effect=lambda x, y: y):
metadata = builder.build_torchair_graph_dummy(3, 3)
common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=3,
num_actual_tokens=3,
decode_token_per_req=1,
actual_seq_lengths_q=[0, 1, 2],
attn_mask=torch.zeros((1, 1), dtype=torch.bool),
spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool),
)
metadata = builder.build_torchair_graph_dummy(common_attn_metadata)

sin_golden = torch.ones(3,
1,
1,
64,
dtype=runner.dtype,
device=runner.device)
dtype=torch.float16,
device=mock_device)
cos_golden = torch.ones(3,
1,
1,
64,
dtype=runner.dtype,
device=runner.device)
dtype=torch.float16,
device=mock_device)

self.assertIsInstance(metadata, AscendMLAMetadata)
self.assertEqual(metadata.num_input_tokens, 3)
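The MLA-side changes follow the same shape: AscendMLAMetadataBuilder now takes (vllm_config, device), and the torchair dummy build consumes a TorchairCommonAttentionMetadata instead of reading state off the runner. A sketch under the same assumptions, with field names taken from this diff; note the real test additionally patches _get_graph_runner_block_tables before calling build_torchair_graph_dummy:

from unittest.mock import MagicMock, patch

import torch

from vllm_ascend.attention.mla_v1 import (AscendMLAMetadata,
                                          AscendMLAMetadataBuilder)
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata

vllm_config = MagicMock()
vllm_config.model_config.max_model_len = 1024
vllm_config.model_config.dtype = torch.float16
vllm_config.cache_config.block_size = 16
vllm_config.scheduler_config.chunked_prefill_enabled = False

ascend_config = MagicMock()
ascend_config.torchair_graph_config.enabled = False

# get_ascend_config is patched exactly as in the tests above.
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
           return_value=ascend_config):
    builder = AscendMLAMetadataBuilder(vllm_config, 'cpu',
                                       metadata_cls=AscendMLAMetadata)

common_attn_metadata = TorchairCommonAttentionMetadata(
    num_reqs=3,
    num_actual_tokens=3,
    decode_token_per_req=1,
    actual_seq_lengths_q=[0, 1, 2],
    attn_mask=torch.zeros((1, 1), dtype=torch.bool),
    spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool),
)
metadata = builder.build_torchair_graph_dummy(common_attn_metadata)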