Skip to content

Commit b2e121c

Browse files
committed
refact attn metadata build
Signed-off-by: weiguihua2 <[email protected]>
1 parent 0ad18c3 commit b2e121c

File tree

4 files changed

+182
-122
lines changed

4 files changed

+182
-122
lines changed

tests/ut/attention/test_attention_v1.py

Lines changed: 62 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
AscendAttentionState,
1010
AscendMetadata,
1111
CommonAttentionState)
12+
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
1213

1314

1415
class TestAscendAttentionBackend(TestBase):
@@ -67,8 +68,11 @@ def test_copy_blocks(self):
6768
class TestAscendAttentionMetadataBuilder(TestBase):
6869

6970
def setUp(self):
70-
self.mock_runner = MagicMock()
71-
self.builder = AscendAttentionMetadataBuilder(self.mock_runner)
71+
self.mock_vllm_config = MagicMock()
72+
self.mock_vllm_config.model_config.max_model_len = 640
73+
self.mock_vllm_config.cache_config.block_size = 64
74+
self.mock_device = 'cpu:0'
75+
self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config, self.mock_device)
7276

7377
def test_reorder_batch(self):
7478
mock_input_batch = MagicMock()
@@ -86,30 +90,31 @@ def test_reorder_batch(self):
8690
def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
8791
mock_npu_format_cast,
8892
mock_ascend_metadata):
89-
num_reqs = 2
90-
num_actual_tokens = 10
91-
max_query_len = 5
92-
93-
self.mock_runner.input_batch.block_table = [MagicMock()]
94-
self.mock_runner.input_batch.block_table[
95-
0].get_device_tensor.return_value = torch.zeros((10, 10))
96-
self.mock_runner.max_num_blocks_per_req = 10
97-
self.mock_runner.query_lens = torch.tensor([3, 4])
98-
self.mock_runner.seq_lens_cpu = torch.tensor([5, 6])
99-
self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
100-
self.mock_runner.device = 'cpu:0'
101-
self.mock_runner.attn_mask = torch.ones((10, 10))
102-
self.mock_runner.attn_state = AscendAttentionState.PrefillNoCache
103-
self.mock_runner.query_start_loc_cpu = torch.tensor([0, 3, 7])
93+
common_attn_metadata = AscendCommonAttentionMetadata(
94+
query_start_loc=torch.tensor([0, 3, 7]),
95+
query_start_loc_cpu=torch.tensor([0, 3, 7]),
96+
seq_lens_cpu=torch.tensor([5, 6]),
97+
num_reqs=2,
98+
num_actual_tokens=10,
99+
max_query_len=5,
100+
decode_token_per_req=torch.tensor([1, 1]),
101+
block_table_tensor=torch.zeros((10, 10)),
102+
slot_mapping_cpu=torch.tensor(range(20)),
103+
actual_seq_lengths_q=torch.tensor([0, 1]),
104+
positions=torch.tensor([10, 10]),
105+
attn_mask=torch.ones((10, 10)),
106+
spec_attn_mask=None,
107+
attn_state=AscendAttentionState.PrefillNoCache
108+
)
104109

105110
mock_nz_tensor = MagicMock()
111+
mock_model = MagicMock()
106112
mock_nd_to_nz_2d.return_value = mock_nz_tensor
107113
mock_npu_format_cast.return_value = mock_nz_tensor
108114

109115
self.builder.build(
110-
num_reqs,
111-
num_actual_tokens,
112-
max_query_len,
116+
common_attn_metadata,
117+
mock_model
113118
)
114119

115120
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@@ -120,51 +125,55 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
120125
def test_build_chunked_prefill(self, mock_ascend_attention_state,
121126
mock_is_310p, mock_nd_to_nz_spec,
122127
mock_npu_format_cast, mock_ascend_metadata):
123-
num_reqs = 3
124-
num_actual_tokens = 15
125-
max_query_len = 6
126-
127-
self.mock_runner.input_batch.block_table = [MagicMock()]
128-
self.mock_runner.input_batch.block_table[
129-
0].get_device_tensor.return_value = torch.zeros((10, 10))
130-
self.mock_runner.max_num_blocks_per_req = 10
131-
self.mock_runner.query_lens = torch.tensor([2, 3, 4])
132-
self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6])
133-
self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
134-
self.mock_runner.device = 'cpu:0'
135-
self.mock_runner.attn_mask = torch.ones((15, 15))
136-
self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
137-
self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])
128+
common_attn_metadata = AscendCommonAttentionMetadata(
129+
query_start_loc=torch.tensor([0, 2, 5, 9]),
130+
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
131+
seq_lens_cpu=torch.tensor([4, 5, 6]),
132+
num_reqs=3,
133+
num_actual_tokens=15,
134+
max_query_len=6,
135+
decode_token_per_req=torch.tensor([1, 1, 1]),
136+
block_table_tensor=torch.zeros((10, 10)),
137+
slot_mapping_cpu=torch.tensor(range(20)),
138+
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
139+
positions=torch.tensor([10, 10]),
140+
attn_mask=torch.ones((15, 15)),
141+
spec_attn_mask=None,
142+
attn_state=AscendAttentionState.ChunkedPrefill
143+
)
138144

139145
mock_ascend_attention_state = MagicMock()
140146
mock_ascend_attention_state.PrefillNoCache = 0
141147

142148
mock_nz_tensor = MagicMock()
149+
mock_model = MagicMock()
143150
mock_nd_to_nz_spec.return_value = mock_nz_tensor
144151
mock_npu_format_cast.return_value = mock_nz_tensor
145152

146-
self.builder.build(num_reqs, num_actual_tokens, max_query_len)
153+
self.builder.build(common_attn_metadata, mock_model)
147154

148155
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
149156
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
150157
def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
151-
num_reqs = 3
152-
num_actual_tokens = 15
153-
max_query_len = 6
154-
155-
self.mock_runner.input_batch.block_table = [MagicMock()]
156-
self.mock_runner.input_batch.block_table[
157-
0].get_device_tensor.return_value = torch.zeros((10, 10))
158-
self.mock_runner.max_num_blocks_per_req = 10
159-
self.mock_runner.query_lens = torch.tensor([2, 3, 4])
160-
self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6])
161-
self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
162-
self.mock_runner.device = 'cpu:0'
163-
self.mock_runner.attn_mask = torch.ones((15, 15))
164-
self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
165-
self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])
166-
167-
self.builder.build(num_reqs, num_actual_tokens, max_query_len)
158+
common_attn_metadata = AscendCommonAttentionMetadata(
159+
query_start_loc=torch.tensor([0, 2, 5, 9]),
160+
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
161+
seq_lens_cpu=torch.tensor([4, 5, 6]),
162+
num_reqs=3,
163+
num_actual_tokens=15,
164+
max_query_len=6,
165+
decode_token_per_req=torch.tensor([1, 1, 1]),
166+
block_table_tensor=torch.zeros((10, 10)),
167+
slot_mapping_cpu=torch.tensor(range(20)),
168+
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
169+
positions=torch.tensor([10, 10]),
170+
attn_mask=torch.ones((15, 15)),
171+
spec_attn_mask=None,
172+
attn_state=AscendAttentionState.ChunkedPrefill
173+
)
174+
mock_model = MagicMock()
175+
176+
self.builder.build(common_attn_metadata, mock_model)
168177

169178

170179
class TestAscendAttentionBackendImpl(TestBase):

0 commit comments

Comments
 (0)