                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.sequence import SequenceGroupMetadata
 from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad
 
 if TYPE_CHECKING:
-    from vllm.worker.model_runner import (GPUModelRunnerBase,
-                                          ModelInputForGPUBuilder)
+    from vllm.worker.model_runner import ModelInputForGPUBuilder
 
 
 class FlashInferBackend(AttentionBackend):
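
Note that the surviving model_runner import sits under TYPE_CHECKING, so it is only evaluated by static type checkers; at runtime the builder refers to the class through quoted annotations, which avoids a circular import between the backend and the model runner. A minimal sketch of that pattern (the builder class name below is illustrative, not taken from this diff):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Resolved by type checkers only; never imported when the module runs.
        from vllm.worker.model_runner import ModelInputForGPUBuilder


    class SomeMetadataBuilder:

        def __init__(self, input_builder: "ModelInputForGPUBuilder") -> None:
            # The quoted annotation stays a string at runtime, so
            # model_runner can import this module without a cycle.
            self.input_builder = input_builder
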
@@ -216,6 +214,9 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
 
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
         self.sliding_window = input_builder.sliding_window
         self.block_size = input_builder.block_size
         self.use_v2_block_manager = (
@@ -238,26 +239,24 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         # paged_kv_last_page_len is the length of the last page of each request
         self.paged_kv_last_page_len: List[int] = []
 
-    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
-                      token_lens: List[int], seq_lens: List[int],
-                      curr_seq_lens: List[int], query_lens: List[int],
-                      context_lens: List[int],
-                      curr_sliding_window_blocks: List[int],
-                      prefix_cache_hit: bool, chunked_prefill_enabled: bool):
+    def _add_seq_group(
+            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
+            chunked_prefill_enabled: bool):
         """Add a sequence group to the metadata. Specifically update/append
         1. context length.
         2. block table.
         3. slot mapping.
         """
-        is_prompt = seq_group_metadata.is_prompt
-        block_tables = seq_group_metadata.block_tables
-        computed_block_nums = seq_group_metadata.computed_block_nums
+        is_prompt = inter_data.is_prompt
+        block_tables = inter_data.block_tables
+        computed_block_nums = inter_data.computed_block_nums
 
         for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
              curr_sliding_window_block) in zip(
-                 seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
-                 curr_seq_lens, query_lens, context_lens,
-                 curr_sliding_window_blocks):
+                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
+                 inter_data.orig_seq_lens, inter_data.seq_lens,
+                 inter_data.query_lens, inter_data.context_lens,
+                 inter_data.curr_sliding_window_blocks):
             self.context_lens.append(context_len)
             if is_prompt:
                 self.num_prefills += 1
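
The per-sequence-group state that used to arrive as parallel argument lists is now read off the InterDataForSeqGroup object. As a reading aid only, the fields the loop above consumes can be summarized with a stand-in like this (the real class lives in vllm/worker/model_runner.py and carries more state; the container types here are assumptions inferred from how the fields are used):

    from dataclasses import dataclass
    from typing import Dict, List, Optional


    @dataclass
    class InterDataForSeqGroupSketch:
        """Illustrative stand-in listing only the fields read above."""
        is_prompt: bool
        seq_ids: List[int]
        input_tokens: List[List[int]]      # len() of each entry -> token_len
        orig_seq_lens: List[int]           # bound to seq_len in the zip
        seq_lens: List[int]                # bound to curr_seq_len in the zip
        query_lens: List[int]
        context_lens: List[int]
        curr_sliding_window_blocks: List[int]
        block_tables: Optional[Dict[int, List[int]]]  # assumed keyed by seq_id
        computed_block_nums: Optional[List[int]]
        prefix_cache_hit: bool
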
@@ -275,7 +274,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
             # only allowing multiple of block_size chunk size.
             # NOTE: This only works for oooooooxxx style attention.
             block_table = []
-            if prefix_cache_hit:
+            if inter_data.prefix_cache_hit:
                 block_table = computed_block_nums
             elif ((chunked_prefill_enabled or not is_prompt)
                   and block_tables is not None):
@@ -290,8 +289,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                 self.use_v2_block_manager)
             compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                  seq_len, context_len, start_idx,
-                                 self.block_size,
-                                 seq_group_metadata.block_tables)
+                                 self.block_size, inter_data.block_tables)
 
             # It is not necessary to add paged_kv_indices, paged_kv_indptr,
             # and paged_kv_last_page_len for profile run because we will
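
For readers unfamiliar with the slot mapping above: each token's slot is its physical block number times the block size plus its offset within the block. A hand-rolled illustration of that arithmetic (not the vllm compute_slot_mapping implementation, which also handles profile runs, padding, and sliding-window start indices):

    block_size = 16
    block_table = [8, 3]             # logical block 0 -> physical 8, 1 -> 3
    context_len, seq_len = 14, 18    # tokens 14..17 are new in this step

    slot_mapping = []
    for token_idx in range(context_len, seq_len):
        physical_block = block_table[token_idx // block_size]
        slot_mapping.append(physical_block * block_size +
                            token_idx % block_size)

    print(slot_mapping)  # [142, 143, 48, 49]
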
@@ -317,9 +315,13 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                 last_page_len = self.block_size
             self.paged_kv_last_page_len.append(last_page_len)
 
-    def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
+    def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
-        device = runner.device
+        for inter_data in self.input_builder.inter_data_list:
+            self._add_seq_group(inter_data,
+                                self.input_builder.chunked_prefill_enabled)
+
+        device = self.runner.device
         use_captured_graph = cuda_graph_pad_size != -1
 
         max_query_len = max(query_lens)
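
From the caller's side the refactor looks roughly like this (hypothetical call site; the actual model_runner changes are not part of this excerpt):

    # Before: the runner was threaded through every build() call, and the
    # caller fed per-sequence-group data to the builder itself.
    attn_metadata = builder.build(runner, seq_lens, query_lens,
                                  cuda_graph_pad_size, batch_size)

    # After: the builder already holds self.runner (via input_builder.runner)
    # and walks input_builder.inter_data_list on its own inside build().
    attn_metadata = builder.build(seq_lens, query_lens,
                                  cuda_graph_pad_size, batch_size)
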
@@ -333,7 +335,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
 
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
-            input_block_tables = runner.graph_block_tables[:batch_size]
+            input_block_tables = self.runner.graph_block_tables[:batch_size]
             for i, block_table in enumerate(self.block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
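
The loop above simply left-packs each variable-length block table into the fixed-shape, zero-padded array that CUDA graph capture expects. The same idea in isolation, assuming graph_block_tables is a NumPy array as the slicing suggests (shapes and block ids below are made up):

    import numpy as np

    batch_size, max_blocks = 3, 4
    graph_block_tables = np.zeros((batch_size, max_blocks), dtype=np.int32)

    block_tables = [[7, 2], [], [5, 1, 9]]   # per-request physical block ids
    for i, block_table in enumerate(block_tables):
        if block_table:
            graph_block_tables[i, :len(block_table)] = block_table

    print(graph_block_tables)
    # [[7 2 0 0]
    #  [0 0 0 0]
    #  [5 1 9 0]]
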
@@ -377,7 +379,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
                                            dtype=torch.long,
                                            device=device)
 
-        logits_soft_cap = getattr(runner.model_config.hf_config,
+        logits_soft_cap = getattr(self.runner.model_config.hf_config,
                                   "attn_logit_softcapping", None)
 
         if len(self.paged_kv_indptr) > 0:
@@ -394,8 +396,8 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
             paged_kv_indptr_tensor = None
             paged_kv_last_page_len_tensor = None
 
-        kv_cache_dtype = get_kv_cache_torch_dtype(runner.kv_cache_dtype,
-                                                  runner.model_config.dtype)
+        kv_cache_dtype = get_kv_cache_torch_dtype(
+            self.runner.kv_cache_dtype, self.runner.model_config.dtype)
         return FlashInferMetadata(
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
@@ -406,11 +408,11 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
             paged_kv_indptr=paged_kv_indptr_tensor,
             paged_kv_indices=paged_kv_indices_tensor,
             paged_kv_last_page_len=paged_kv_last_page_len_tensor,
-            num_qo_heads=runner.model_config.get_num_attention_heads(
-                runner.parallel_config),
-            num_kv_heads=runner.model_config.get_num_kv_heads(
-                runner.parallel_config),
-            head_dim=runner.model_config.get_head_size(),
+            num_qo_heads=self.runner.model_config.get_num_attention_heads(
+                self.runner.parallel_config),
+            num_kv_heads=self.runner.model_config.get_num_kv_heads(
+                self.runner.parallel_config),
+            head_dim=self.runner.model_config.get_head_size(),
             page_size=self.block_size,
             seq_start_loc=seq_start_loc,
             query_start_loc=query_start_loc,
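
To make the paged-KV fields passed to FlashInferMetadata concrete: paged_kv_indices concatenates the page (block) numbers of all requests, paged_kv_indptr stores the CSR-style start offset of each request inside that list, and paged_kv_last_page_len records how many tokens occupy each request's final page (falling back to block_size when the length divides evenly, matching the last_page_len handling earlier in the diff). A small worked example under those assumptions, with made-up block numbers:

    page_size = 16
    seq_lens = [5, 40]
    block_tables = [[7], [2, 9, 4]]     # one physical page id per used page

    paged_kv_indptr = [0]
    paged_kv_indices = []
    paged_kv_last_page_len = []
    for seq_len, blocks in zip(seq_lens, block_tables):
        paged_kv_indices.extend(blocks)
        paged_kv_indptr.append(len(paged_kv_indices))
        last = seq_len % page_size
        paged_kv_last_page_len.append(last if last != 0 else page_size)

    print(paged_kv_indptr)          # [0, 1, 4]
    print(paged_kv_indices)         # [7, 2, 9, 4]
    print(paged_kv_last_page_len)   # [5, 8]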