9
9
AscendAttentionState ,
10
10
AscendMetadata ,
11
11
CommonAttentionState )
12
+ from vllm_ascend .attention .utils import AscendCommonAttentionMetadata
12
13
13
14
14
15
class TestAscendAttentionBackend (TestBase ):
@@ -67,8 +68,11 @@ def test_copy_blocks(self):
67
68
class TestAscendAttentionMetadataBuilder (TestBase ):
68
69
69
70
def setUp(self):
    """Create an AscendAttentionMetadataBuilder wired to a mocked vllm config.

    NOTE(review): reconstructed from a diff-mangled paste — this is the
    post-change ("+") side of the hunk, where the builder takes a
    (vllm_config, device) pair instead of the old runner mock.
    """
    self.mock_vllm_config = MagicMock()
    # Values read by the builder off the config mock; presumably
    # max_model_len / block_size sizing — confirm against the builder.
    self.mock_vllm_config.model_config.max_model_len = 640
    self.mock_vllm_config.cache_config.block_size = 64
    self.mock_device = 'cpu:0'
    self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config,
                                                  self.mock_device)
72
76
73
77
def test_reorder_batch (self ):
74
78
mock_input_batch = MagicMock ()
@@ -86,30 +90,31 @@ def test_reorder_batch(self):
86
90
def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
                                mock_npu_format_cast,
                                mock_ascend_metadata):
    """build() in PrefillNoCache state: 2 requests, 10 actual tokens.

    NOTE(review): reconstructed from a diff-mangled paste — post-change
    ("+") side of the hunk. The mock_* parameters are injected by @patch
    decorators that fall outside this hunk's visible lines.
    """
    common_attn_metadata = AscendCommonAttentionMetadata(
        query_start_loc=torch.tensor([0, 3, 7]),
        query_start_loc_cpu=torch.tensor([0, 3, 7]),
        seq_lens_cpu=torch.tensor([5, 6]),
        num_reqs=2,
        num_actual_tokens=10,
        max_query_len=5,
        decode_token_per_req=torch.tensor([1, 1]),
        block_table_tensor=torch.zeros((10, 10)),
        slot_mapping_cpu=torch.tensor(range(20)),
        actual_seq_lengths_q=torch.tensor([0, 1]),
        positions=torch.tensor([10, 10]),
        attn_mask=torch.ones((10, 10)),
        spec_attn_mask=None,
        attn_state=AscendAttentionState.PrefillNoCache)

    mock_nz_tensor = MagicMock()
    mock_model = MagicMock()
    # Patched 310p helpers both funnel through to the same NZ-format mock.
    mock_nd_to_nz_2d.return_value = mock_nz_tensor
    mock_npu_format_cast.return_value = mock_nz_tensor

    self.builder.build(
        common_attn_metadata,
        mock_model,
    )
114
119
115
120
@patch ('vllm_ascend.attention.attention_v1.AscendMetadata' )
@@ -120,51 +125,55 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
120
125
def test_build_chunked_prefill(self, mock_ascend_attention_state,
                               mock_is_310p, mock_nd_to_nz_spec,
                               mock_npu_format_cast, mock_ascend_metadata):
    """build() in ChunkedPrefill state on 310p: 3 requests, 15 tokens.

    NOTE(review): reconstructed from a diff-mangled paste — post-change
    ("+") side of the hunk. The mock_* parameters are injected by @patch
    decorators that fall outside this hunk's visible lines.
    """
    common_attn_metadata = AscendCommonAttentionMetadata(
        query_start_loc=torch.tensor([0, 2, 5, 9]),
        query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
        seq_lens_cpu=torch.tensor([4, 5, 6]),
        num_reqs=3,
        num_actual_tokens=15,
        max_query_len=6,
        decode_token_per_req=torch.tensor([1, 1, 1]),
        block_table_tensor=torch.zeros((10, 10)),
        slot_mapping_cpu=torch.tensor(range(20)),
        actual_seq_lengths_q=torch.tensor([0, 1, 2]),
        positions=torch.tensor([10, 10]),
        attn_mask=torch.ones((15, 15)),
        spec_attn_mask=None,
        attn_state=AscendAttentionState.ChunkedPrefill)

    # Rebind the patched state enum locally, mirroring the original test.
    mock_ascend_attention_state = MagicMock()
    mock_ascend_attention_state.PrefillNoCache = 0

    mock_nz_tensor = MagicMock()
    mock_model = MagicMock()
    mock_nd_to_nz_spec.return_value = mock_nz_tensor
    mock_npu_format_cast.return_value = mock_nz_tensor

    self.builder.build(common_attn_metadata, mock_model)
147
154
148
155
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
    """build() in ChunkedPrefill state with is_310p patched to False.

    NOTE(review): reconstructed from a diff-mangled paste — post-change
    ("+") side of the hunk.
    """
    common_attn_metadata = AscendCommonAttentionMetadata(
        query_start_loc=torch.tensor([0, 2, 5, 9]),
        query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
        seq_lens_cpu=torch.tensor([4, 5, 6]),
        num_reqs=3,
        num_actual_tokens=15,
        max_query_len=6,
        decode_token_per_req=torch.tensor([1, 1, 1]),
        block_table_tensor=torch.zeros((10, 10)),
        slot_mapping_cpu=torch.tensor(range(20)),
        actual_seq_lengths_q=torch.tensor([0, 1, 2]),
        positions=torch.tensor([10, 10]),
        attn_mask=torch.ones((15, 15)),
        spec_attn_mask=None,
        attn_state=AscendAttentionState.ChunkedPrefill)
    mock_model = MagicMock()

    self.builder.build(common_attn_metadata, mock_model)
168
177
169
178
170
179
class TestAscendAttentionBackendImpl (TestBase ):
0 commit comments