@@ -49,14 +49,20 @@ def get_kv_cache_stride_order() -> tuple[int, ...]:
 
 
 @dataclass
-class DeepseekV32IndexerPrefillMetadata:
+class DeepseekV32IndexerPrefillChunkMetadata:
     block_table: torch.Tensor
-    query_start_loc: torch.Tensor
-    max_query_len: int
     cu_seqlen_ks: torch.Tensor
     cu_seqlen_ke: torch.Tensor
     cu_seq_lens: torch.Tensor
     total_seq_lens: int
+    token_start: int
+    token_end: int
+    num_reqs: int
+
+
+@dataclass
+class DeepseekV32IndexerPrefillMetadata:
+    chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
 
 
 @dataclass
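For reference, this turns the flat prefill metadata into a two-level structure: one `DeepseekV32IndexerPrefillChunkMetadata` per chunk, wrapped in a single `DeepseekV32IndexerPrefillMetadata`. A minimal sketch with made-up values (two requests with KV lengths 5 and 4, two new query tokens each):

```python
import torch

# Hypothetical chunk covering query tokens [0, 4) of the batch.
chunk = DeepseekV32IndexerPrefillChunkMetadata(
    block_table=torch.zeros(2, 4, dtype=torch.int32),            # placeholder
    cu_seqlen_ks=torch.tensor([0, 0, 5, 5], dtype=torch.int32),  # KV span starts
    cu_seqlen_ke=torch.tensor([4, 5, 8, 9], dtype=torch.int32),  # exclusive ends
    cu_seq_lens=torch.tensor([0, 5, 9], dtype=torch.int32),
    total_seq_lens=9,
    token_start=0,
    token_end=4,
    num_reqs=2,
)
prefill_metadata = DeepseekV32IndexerPrefillMetadata(chunks=[chunk])
```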
@@ -98,8 +104,8 @@ class DeepseekV32IndexerMetadata:
 
 # TODO (zyongye) optimize this, this is now vibe coded
 def kv_spans_from_batches(
-        start_seq_loc: torch.Tensor,
-        seq_len_per_batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor,
+        device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Args:
         start_seq_loc: 1D long tensor [B+1], cumulative counts of
@@ -122,15 +128,14 @@ def kv_spans_from_batches(
         are the **last** `counts[i]` positions of that sequence.
     """
     q = start_seq_loc.to(dtype=torch.long)
-    L = seq_len_per_batch.to(dtype=torch.long, device=q.device)
+    L = seq_len_per_batch.to(dtype=torch.long)
     assert q.dim() == 1 and L.dim() == 1
     assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"
 
     # Selected tokens per batch and totals
     counts = q[1:] - q[:-1]  # [B]
     N = int(q[-1].item())  # total selected tokens
     B = L.numel()
-    device = L.device
 
     if N == 0:
         return (torch.empty(0, dtype=torch.long, device=device),
@@ -140,8 +145,7 @@ def kv_spans_from_batches(
     kv_starts_per_batch = torch.cumsum(L, dim=0) - L  # [B]
 
     # For each selected token, which batch does it belong to?
-    batch_id = torch.repeat_interleave(torch.arange(B, device=device),
-                                       counts)  # [N]
+    batch_id = torch.repeat_interleave(torch.arange(B), counts)  # [N]
 
     # Map batch KV start to each token
     start_tensor = kv_starts_per_batch[batch_id]  # [N]
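To make the function's contract concrete, here is a hypothetical CPU-side check of `kv_spans_from_batches` with the new explicit `device` argument: batch 0 has KV length 5 with 2 selected tokens, batch 1 has KV length 4 with 1 selected token, and the KV cache is flattened as `[batch0 | batch1]`:

```python
import torch

starts, ends = kv_spans_from_batches(
    start_seq_loc=torch.tensor([0, 2, 3]),   # [B+1] cumulative selected counts
    seq_len_per_batch=torch.tensor([5, 4]),  # [B] KV lengths
    device=torch.device("cpu"),
)
# Selected tokens are the last counts[i] positions of each sequence:
print(starts.tolist(), ends.tolist())  # [0, 0, 5] [4, 5, 9]
```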
@@ -151,22 +155,51 @@ def kv_spans_from_batches(
     L_expand = torch.repeat_interleave(L, counts)  # [N]
     m_expand = torch.repeat_interleave(counts, counts)  # [N]
     # position within the selected block: 1..counts[b]
-    pos_within = (torch.arange(N, device=device, dtype=torch.long) -
+    pos_within = (torch.arange(N, dtype=torch.long) -
                   torch.repeat_interleave(q[:-1], counts) + 1)
 
     local_pos = L_expand - m_expand + pos_within  # [N], 1-based
     end_location = start_tensor + local_pos  # exclusive end
 
-    return start_tensor.int(), end_location.int()
+    return start_tensor.int().to(device), end_location.int().to(device)
 
 
 def get_max_prefill_buffer_size(vllm_config: VllmConfig):
     max_model_len = vllm_config.model_config.max_model_len
-    # max_num_batched_tokens = \
-    #     vllm_config.scheduler_config.max_num_batched_tokens
-    max_num_seq = vllm_config.scheduler_config.max_num_seqs
-    # NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
-    return max_model_len * max_num_seq
+    # NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
+    # May be tuned later.
+    return max_model_len * 2
+
+
+def split_prefill_chunks(seq_lens_cpu: torch.Tensor,
+                         max_prefill_buffer_size: int,
+                         reqs_start: int) -> list[tuple[int, int]]:
+    """
+    Split the prefill requests into chunks of (reqs_start, reqs_end) index
+    tuples such that the total sequence length of each chunk does not
+    exceed the maximum prefill buffer size.
+
+    Args:
+        seq_lens_cpu: The sequence lengths of the prefill requests.
+        max_prefill_buffer_size: The maximum prefill buffer size.
+        reqs_start: The start index of the prefill requests.
+
+    Returns:
+        A list of (reqs_start, reqs_end) tuples.
+    """
+    chunk_seq_ids = []
+    total_seq_lens = 0
+    for i in range(reqs_start, len(seq_lens_cpu)):
+        cur_seq_len = seq_lens_cpu[i].item()
+        assert cur_seq_len <= max_prefill_buffer_size
+        total_seq_lens += cur_seq_len
+        if total_seq_lens > max_prefill_buffer_size:
+            chunk_seq_ids.append((reqs_start, i))
+            reqs_start = i
+            total_seq_lens = cur_seq_len
+    if total_seq_lens > 0:
+        chunk_seq_ids.append((reqs_start, len(seq_lens_cpu)))
+    return chunk_seq_ids
 
 
 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
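The chunking itself is a simple greedy scan. A small sketch on hypothetical lengths: with a 10-token budget and prefill requests of sequence length `[4, 5, 3, 9]` (and `reqs_start=0`, i.e. no leading decode requests), the requests pack into three chunks:

```python
import torch

seq_lens_cpu = torch.tensor([4, 5, 3, 9])
print(split_prefill_chunks(seq_lens_cpu,
                           max_prefill_buffer_size=10,
                           reqs_start=0))
# [(0, 2), (2, 3), (3, 4)] -> chunk totals 9, 3, 9, each within the budget
```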
@@ -201,6 +234,33 @@ def __init__(self, *args, **kwargs):
                                  dtype=torch.int32,
                                  device=self.device)
 
+    def build_one_prefill_chunk(self, reqs_start, reqs_end,
+                                query_start_loc_cpu, seq_lens_cpu,
+                                block_table):
+        prefill_query_start_loc = query_start_loc_cpu[
+            reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
+        cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
+            prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
+            self.device)
+        token_start = query_start_loc_cpu[reqs_start].item()
+        token_end = query_start_loc_cpu[reqs_end].item()
+        total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
+        assert total_seq_lens <= self.max_prefill_buffer_size
+        cu_seq_lens = torch.cat([
+            torch.zeros(1, dtype=torch.int32),
+            seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
+        ]).to(torch.int32).to(self.device)
+        return DeepseekV32IndexerPrefillChunkMetadata(
+            cu_seqlen_ks=cu_seqlen_ks,
+            cu_seqlen_ke=cu_seqlen_ke,
+            cu_seq_lens=cu_seq_lens,
+            total_seq_lens=total_seq_lens,
+            block_table=block_table[reqs_start:reqs_end],
+            token_start=token_start,
+            token_end=token_end,
+            num_reqs=reqs_end - reqs_start,
+        )
+
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
@@ -209,11 +269,7 @@ def build(self,
         num_reqs = common_attn_metadata.num_reqs
         num_tokens = common_attn_metadata.num_actual_tokens
 
-        device = self.device
-        block_table_tensor = common_attn_metadata.block_table_tensor
-
-        query_start_loc = common_attn_metadata.query_start_loc
-
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
             split_decodes_and_prefills(
                 common_attn_metadata,
@@ -224,27 +280,20 @@ def build(self,
 
         prefill_metadata = None
         if num_prefills > 0:
-            reqs_start = num_decodes
-            prefill_query_start_loc = query_start_loc[
-                reqs_start:] - query_start_loc[reqs_start]
-            cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
-                prefill_query_start_loc,
-                common_attn_metadata.seq_lens[reqs_start:])
-            total_seq_lens = common_attn_metadata.seq_lens[reqs_start:].sum()
-            assert total_seq_lens < self.max_prefill_buffer_size
-            cu_seq_lens = torch.cat([
-                torch.zeros(1, dtype=torch.int32, device=device),
-                common_attn_metadata.seq_lens[reqs_start:].cumsum(dim=0)
-            ]).to(torch.int32).cuda()
-            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
-                block_table=block_table_tensor[reqs_start:, ...],
-                query_start_loc=prefill_query_start_loc,
-                max_query_len=common_attn_metadata.max_query_len,
-                cu_seqlen_ks=cu_seqlen_ks,
-                cu_seqlen_ke=cu_seqlen_ke,
-                cu_seq_lens=cu_seq_lens,
-                total_seq_lens=total_seq_lens,
+            chunk_seq_ids = split_prefill_chunks(
+                common_attn_metadata.seq_lens_cpu,
+                self.max_prefill_buffer_size,
+                num_decodes,
             )
+            chunks = [
+                self.build_one_prefill_chunk(
+                    reqs_start, reqs_end, query_start_loc_cpu,
+                    common_attn_metadata.seq_lens_cpu,
+                    common_attn_metadata.block_table_tensor)
+                for reqs_start, reqs_end in chunk_seq_ids
+            ]
+            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
+                chunks=chunks)
 
         decode_metadata = None
         if num_decodes > 0:
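Downstream, the indexer's forward pass can then iterate over the chunks and process each against a buffer of at most `max_prefill_buffer_size` flattened KV entries, instead of sizing one buffer for every prefill token in the batch. A rough consumer-side sketch, where `run_indexer_on_chunk` is a made-up stand-in for the actual kernel invocation:

```python
import torch

def forward_prefill(hidden_states, prefill_metadata):
    outputs = []
    for chunk in prefill_metadata.chunks:
        # token_start/token_end delimit this chunk's query tokens in the batch.
        chunk_q = hidden_states[chunk.token_start:chunk.token_end]
        outputs.append(run_indexer_on_chunk(chunk_q, chunk))  # hypothetical
    return torch.cat(outputs, dim=0)
```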