                                                AscendAttentionState,
                                                AscendMetadata,
                                                CommonAttentionState)
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata


 class TestAscendAttentionBackend(TestBase):
@@ -67,8 +68,12 @@ def test_copy_blocks(self):
 class TestAscendAttentionMetadataBuilder(TestBase):

     def setUp(self):
-        self.mock_runner = MagicMock()
-        self.builder = AscendAttentionMetadataBuilder(self.mock_runner)
+        self.mock_vllm_config = MagicMock()
+        self.mock_vllm_config.model_config.max_model_len = 640
+        self.mock_vllm_config.cache_config.block_size = 64
+        self.mock_device = 'cpu:0'
+        self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config,
+                                                      self.mock_device)

     def test_reorder_batch(self):
         mock_input_batch = MagicMock()
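As the hunk above shows, the builder no longer receives the whole model runner: it is now constructed from the vLLM config plus a device. A minimal sketch of the new construction path, assuming `VllmConfig` exposes `model_config` and `cache_config` attributes as in upstream vLLM (the mock values simply mirror the test fixture above; this is not the canonical setup code):

```python
from unittest.mock import MagicMock

from vllm_ascend.attention.attention_v1 import AscendAttentionMetadataBuilder

# Construct the builder the new way: config + device, no runner object.
vllm_config = MagicMock()
vllm_config.model_config.max_model_len = 640  # mirrors the test fixture
vllm_config.cache_config.block_size = 64
builder = AscendAttentionMetadataBuilder(vllm_config, 'cpu:0')
```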
@@ -86,31 +91,28 @@ def test_reorder_batch(self):
     def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
                                     mock_npu_format_cast,
                                     mock_ascend_metadata):
-        num_reqs = 2
-        num_actual_tokens = 10
-        max_query_len = 5
-
-        self.mock_runner.input_batch.block_table = [MagicMock()]
-        self.mock_runner.input_batch.block_table[
-            0].get_device_tensor.return_value = torch.zeros((10, 10))
-        self.mock_runner.max_num_blocks_per_req = 10
-        self.mock_runner.query_lens = torch.tensor([3, 4])
-        self.mock_runner.seq_lens_cpu = torch.tensor([5, 6])
-        self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
-        self.mock_runner.device = 'cpu:0'
-        self.mock_runner.attn_mask = torch.ones((10, 10))
-        self.mock_runner.attn_state = AscendAttentionState.PrefillNoCache
-        self.mock_runner.query_start_loc_cpu = torch.tensor([0, 3, 7])
+        common_attn_metadata = AscendCommonAttentionMetadata(
+            query_start_loc=torch.tensor([0, 3, 7]),
+            query_start_loc_cpu=torch.tensor([0, 3, 7]),
+            seq_lens_cpu=torch.tensor([5, 6]),
+            num_reqs=2,
+            num_actual_tokens=10,
+            max_query_len=5,
+            decode_token_per_req=torch.tensor([1, 1]),
+            block_table_tensor=torch.zeros((10, 10)),
+            slot_mapping_cpu=torch.tensor(range(20)),
+            actual_seq_lengths_q=torch.tensor([0, 1]),
+            positions=torch.tensor([10, 10]),
+            attn_mask=torch.ones((10, 10)),
+            spec_attn_mask=None,
+            attn_state=AscendAttentionState.PrefillNoCache)

         mock_nz_tensor = MagicMock()
+        mock_model = MagicMock()
         mock_nd_to_nz_2d.return_value = mock_nz_tensor
         mock_npu_format_cast.return_value = mock_nz_tensor

-        self.builder.build(
-            num_reqs,
-            num_actual_tokens,
-            max_query_len,
-        )
+        self.builder.build(common_attn_metadata, mock_model)

     @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
     @patch('torch_npu.npu_format_cast')
@@ -120,51 +122,53 @@ def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
     def test_build_chunked_prefill(self, mock_ascend_attention_state,
                                    mock_is_310p, mock_nd_to_nz_spec,
                                    mock_npu_format_cast, mock_ascend_metadata):
-        num_reqs = 3
-        num_actual_tokens = 15
-        max_query_len = 6
-
-        self.mock_runner.input_batch.block_table = [MagicMock()]
-        self.mock_runner.input_batch.block_table[
-            0].get_device_tensor.return_value = torch.zeros((10, 10))
-        self.mock_runner.max_num_blocks_per_req = 10
-        self.mock_runner.query_lens = torch.tensor([2, 3, 4])
-        self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6])
-        self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
-        self.mock_runner.device = 'cpu:0'
-        self.mock_runner.attn_mask = torch.ones((15, 15))
-        self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
-        self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])
+        common_attn_metadata = AscendCommonAttentionMetadata(
+            query_start_loc=torch.tensor([0, 2, 5, 9]),
+            query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
+            seq_lens_cpu=torch.tensor([4, 5, 6]),
+            num_reqs=3,
+            num_actual_tokens=15,
+            max_query_len=6,
+            decode_token_per_req=torch.tensor([1, 1, 1]),
+            block_table_tensor=torch.zeros((10, 10)),
+            slot_mapping_cpu=torch.tensor(range(20)),
+            actual_seq_lengths_q=torch.tensor([0, 1, 2]),
+            positions=torch.tensor([10, 10]),
+            attn_mask=torch.ones((15, 15)),
+            spec_attn_mask=None,
+            attn_state=AscendAttentionState.ChunkedPrefill)

         mock_ascend_attention_state = MagicMock()
         mock_ascend_attention_state.PrefillNoCache = 0

         mock_nz_tensor = MagicMock()
+        mock_model = MagicMock()
         mock_nd_to_nz_spec.return_value = mock_nz_tensor
         mock_npu_format_cast.return_value = mock_nz_tensor

-        self.builder.build(num_reqs, num_actual_tokens, max_query_len)
+        self.builder.build(common_attn_metadata, mock_model)

     @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
     @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
     def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
-        num_reqs = 3
-        num_actual_tokens = 15
-        max_query_len = 6
-
-        self.mock_runner.input_batch.block_table = [MagicMock()]
-        self.mock_runner.input_batch.block_table[
-            0].get_device_tensor.return_value = torch.zeros((10, 10))
-        self.mock_runner.max_num_blocks_per_req = 10
-        self.mock_runner.query_lens = torch.tensor([2, 3, 4])
-        self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6])
-        self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
-        self.mock_runner.device = 'cpu:0'
-        self.mock_runner.attn_mask = torch.ones((15, 15))
-        self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
-        self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])
-
-        self.builder.build(num_reqs, num_actual_tokens, max_query_len)
+        common_attn_metadata = AscendCommonAttentionMetadata(
+            query_start_loc=torch.tensor([0, 2, 5, 9]),
+            query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
+            seq_lens_cpu=torch.tensor([4, 5, 6]),
+            num_reqs=3,
+            num_actual_tokens=15,
+            max_query_len=6,
+            decode_token_per_req=torch.tensor([1, 1, 1]),
+            block_table_tensor=torch.zeros((10, 10)),
+            slot_mapping_cpu=torch.tensor(range(20)),
+            actual_seq_lengths_q=torch.tensor([0, 1, 2]),
+            positions=torch.tensor([10, 10]),
+            attn_mask=torch.ones((15, 15)),
+            spec_attn_mask=None,
+            attn_state=AscendAttentionState.ChunkedPrefill)
+        mock_model = MagicMock()
+
+        self.builder.build(common_attn_metadata, mock_model)


 class TestAscendAttentionBackendImpl(TestBase):
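The tests above construct `AscendCommonAttentionMetadata` purely by keyword, so its shape can be read off the call sites. Below is a hedged sketch of a compatible definition, inferred only from these tests; the authoritative version lives in `vllm_ascend/attention/utils.py` and may differ in field types, defaults, and ordering:

```python
# Hedged sketch: a shape for AscendCommonAttentionMetadata compatible with
# the keyword arguments used in the tests above. Not the actual definition.
from dataclasses import dataclass
from typing import Optional

import torch

from vllm_ascend.attention.attention_v1 import AscendAttentionState


@dataclass
class AscendCommonAttentionMetadata:
    query_start_loc: torch.Tensor           # cumulative query offsets (device)
    query_start_loc_cpu: torch.Tensor       # the same offsets kept on CPU
    seq_lens_cpu: torch.Tensor              # per-request sequence lengths (CPU)
    num_reqs: int                           # requests in the batch
    num_actual_tokens: int                  # total scheduled tokens
    max_query_len: int                      # longest query in the batch
    decode_token_per_req: torch.Tensor      # decode tokens per request
    block_table_tensor: torch.Tensor        # (num_reqs, max_blocks) block table
    slot_mapping_cpu: torch.Tensor          # token -> KV-cache slot mapping
    actual_seq_lengths_q: torch.Tensor      # effective query lengths
    positions: torch.Tensor                 # token position ids
    attn_mask: torch.Tensor                 # dense attention mask
    spec_attn_mask: Optional[torch.Tensor]  # mask for speculative decoding
    attn_state: AscendAttentionState        # e.g. PrefillNoCache, ChunkedPrefill
```

This also explains the new `build(common_attn_metadata, model)` signature: everything `build` used to pull off the mocked runner now travels in one metadata object, so the tests no longer need a runner mock at all.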