Commit e0c1575

[Core] Modulize prepare input and attention metadata builder (#6596)
1 parent bdf5fd1 commit e0c1575

File tree

6 files changed (+409, -298 lines)

vllm/attention/backends/abstract.py

Lines changed: 3 additions & 17 deletions

@@ -7,7 +7,6 @@
 import torch

 if TYPE_CHECKING:
-    from vllm.sequence import SequenceGroupMetadata
     from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase


@@ -128,25 +127,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
     """Abstract class for attention metadata builders."""

     @abstractmethod
-    def __init__(self, input_builder) -> None:
+    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
         raise NotImplementedError

     @abstractmethod
-    def add_seq_group(self, seq_group_metadata: "SequenceGroupMetadata",
-                      token_lens: List[int], seq_lens: List[int],
-                      curr_seq_lens: List[int], query_lens: List[int],
-                      context_lens: List[int],
-                      curr_sliding_window_blocks: List[int],
-                      prefix_cache_hit: bool, chunked_prefill_enabled: bool):
-        """Add a sequence group to the metadata and update
-        corresponding fields (in Python objects).
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def build(self, runner: "ModelRunnerInputBuilderBase", seq_lens: List[int],
-              query_lens: List[int], cuda_graph_pad_size: int,
-              batch_size: int) -> T:
+    def build(self, seq_lens: List[int], query_lens: List[int],
+              cuda_graph_pad_size: int, batch_size: int) -> T:
         """Build attention metadata with on-device tensors."""
         raise NotImplementedError
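The net effect on the abstract interface: add_seq_group is gone from the contract, __init__ takes only the input builder, and build no longer receives the runner, so a concrete builder is constructed from the input builder and pulls everything else through it. A minimal sketch of a builder conforming to the new contract follows; ToyMetadata and ToyMetadataBuilder are illustrative names rather than vLLM classes, and only the two method signatures are taken from this diff.

from dataclasses import dataclass
from typing import List

from vllm.attention import AttentionMetadataBuilder


@dataclass
class ToyMetadata:
    """Illustrative metadata container (not a vLLM class)."""
    seq_lens: List[int]
    query_lens: List[int]


class ToyMetadataBuilder(AttentionMetadataBuilder[ToyMetadata]):

    def __init__(self, input_builder) -> None:
        # The builder keeps a handle to the input builder (and, through it,
        # the runner) instead of receiving the runner later in build().
        self.input_builder = input_builder
        self.runner = input_builder.runner

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> ToyMetadata:
        # Per-sequence-group state is pulled from the input builder here,
        # rather than pushed in by the model runner via add_seq_group().
        return ToyMetadata(seq_lens=seq_lens, query_lens=query_lens)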

vllm/attention/backends/flash_attn.py

Lines changed: 22 additions & 21 deletions

@@ -13,12 +13,10 @@
 from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
-from vllm.sequence import SequenceGroupMetadata
 from vllm.utils import make_tensor_with_pad

 if TYPE_CHECKING:
-    from vllm.worker.model_runner import (GPUModelRunnerBase,
-                                          ModelInputForGPUBuilder)
+    from vllm.worker.model_runner import ModelInputForGPUBuilder


 class FlashAttentionBackend(AttentionBackend):
@@ -212,30 +210,30 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0

+        self.input_builder = input_builder
+        self.runner = input_builder.runner
         self.sliding_window = input_builder.sliding_window
         self.block_size = input_builder.block_size
         self.use_v2_block_manager = (
             input_builder.scheduler_config.use_v2_block_manager)

-    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
-                      token_lens: List[int], seq_lens: List[int],
-                      curr_seq_lens: List[int], query_lens: List[int],
-                      context_lens: List[int],
-                      curr_sliding_window_blocks: List[int],
-                      prefix_cache_hit: bool, chunked_prefill_enabled: bool):
+    def _add_seq_group(
+            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
+            chunked_prefill_enabled: bool):
         """Add a sequence group to the metadata. Specifically update/append
         1. context length.
         2. block table.
         3. slot mapping.
         """
-        is_prompt = seq_group_metadata.is_prompt
-        block_tables = seq_group_metadata.block_tables
+        is_prompt = inter_data.is_prompt
+        block_tables = inter_data.block_tables

         for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
              curr_sliding_window_block) in zip(
-                 seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
-                 curr_seq_lens, query_lens, context_lens,
-                 curr_sliding_window_blocks):
+                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
+                 inter_data.orig_seq_lens, inter_data.seq_lens,
+                 inter_data.query_lens, inter_data.context_lens,
+                 inter_data.curr_sliding_window_blocks):
             self.context_lens.append(context_len)

             if is_prompt:
@@ -254,7 +252,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
             # only allowing multiple of block_size chunk size.
             # NOTE: This only works for oooooooxxx style attention.
             block_table = []
-            if prefix_cache_hit:
+            if inter_data.prefix_cache_hit:
                 # NOTE(woosuk): For flash-attn, the block table should
                 # include the entries for the incoming prefill tokens.
                 block_table = block_tables[seq_id]
@@ -270,16 +268,19 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                 self.use_v2_block_manager)
             compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                  seq_len, context_len, start_idx,
-                                 self.block_size,
-                                 seq_group_metadata.block_tables)
+                                 self.block_size, inter_data.block_tables)

-    def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
+    def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
         """Build attention metadata with on-device tensors."""
-        device = runner.device
+        for inter_data in self.input_builder.inter_data_list:
+            self._add_seq_group(inter_data,
+                                self.input_builder.chunked_prefill_enabled)
+
+        device = self.runner.device
         use_captured_graph = cuda_graph_pad_size != -1

-        logits_soft_cap = getattr(runner.model_config.hf_config,
+        logits_soft_cap = getattr(self.runner.model_config.hf_config,
                                   "attn_logit_softcapping", None)
         if logits_soft_cap is not None:
             raise ValueError(
@@ -300,7 +301,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,

         # The shape of graph_block_tables is
         # [max batch size, max context len // block size].
-        input_block_tables = runner.graph_block_tables[:batch_size]
+        input_block_tables = self.runner.graph_block_tables[:batch_size]
         for i, block_table in enumerate(self.block_tables):
             if block_table:
                 input_block_tables[i, :len(block_table)] = block_table
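The practical consequence of the new _add_seq_group is visible in the zip: the parallel lists the model runner used to pass in are now read off a single per-sequence-group object, and token lengths are derived on the fly from inter_data.input_tokens. A self-contained illustration of that iteration pattern follows; the InterData dataclass is a stand-in for ModelInputForGPUBuilder.InterDataForSeqGroup carrying only the fields used in the zip, not the real class.

from dataclasses import dataclass
from typing import List


@dataclass
class InterData:
    """Stand-in for ModelInputForGPUBuilder.InterDataForSeqGroup."""
    seq_ids: List[int]
    input_tokens: List[List[int]]   # per-sequence token ids
    orig_seq_lens: List[int]
    seq_lens: List[int]
    query_lens: List[int]
    context_lens: List[int]
    curr_sliding_window_blocks: List[int]


def iterate_seq_group(inter_data: InterData) -> None:
    # Mirrors the zip in _add_seq_group: token_len is now computed from
    # inter_data.input_tokens instead of arriving as a separate token_lens list.
    for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
         curr_sliding_window_block) in zip(
             inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
             inter_data.orig_seq_lens, inter_data.seq_lens,
             inter_data.query_lens, inter_data.context_lens,
             inter_data.curr_sliding_window_blocks):
        print(seq_id, token_len, seq_len, curr_seq_len, query_len,
              context_len, curr_sliding_window_block)


iterate_seq_group(
    InterData(seq_ids=[0], input_tokens=[[11, 12, 13]], orig_seq_lens=[3],
              seq_lens=[3], query_lens=[3], context_lens=[0],
              curr_sliding_window_blocks=[0]))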

vllm/attention/backends/flashinfer.py

Lines changed: 31 additions & 29 deletions

@@ -21,12 +21,10 @@
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.sequence import SequenceGroupMetadata
 from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad

 if TYPE_CHECKING:
-    from vllm.worker.model_runner import (GPUModelRunnerBase,
-                                          ModelInputForGPUBuilder)
+    from vllm.worker.model_runner import ModelInputForGPUBuilder


 class FlashInferBackend(AttentionBackend):
@@ -216,6 +214,9 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0

+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
         self.sliding_window = input_builder.sliding_window
         self.block_size = input_builder.block_size
         self.use_v2_block_manager = (
@@ -238,26 +239,24 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         # paged_kv_last_page_len is the length of the last page of each request
         self.paged_kv_last_page_len: List[int] = []

-    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
-                      token_lens: List[int], seq_lens: List[int],
-                      curr_seq_lens: List[int], query_lens: List[int],
-                      context_lens: List[int],
-                      curr_sliding_window_blocks: List[int],
-                      prefix_cache_hit: bool, chunked_prefill_enabled: bool):
+    def _add_seq_group(
+            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
+            chunked_prefill_enabled: bool):
         """Add a sequence group to the metadata. Specifically update/append
         1. context length.
         2. block table.
         3. slot mapping.
         """
-        is_prompt = seq_group_metadata.is_prompt
-        block_tables = seq_group_metadata.block_tables
-        computed_block_nums = seq_group_metadata.computed_block_nums
+        is_prompt = inter_data.is_prompt
+        block_tables = inter_data.block_tables
+        computed_block_nums = inter_data.computed_block_nums

         for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
              curr_sliding_window_block) in zip(
-                 seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
-                 curr_seq_lens, query_lens, context_lens,
-                 curr_sliding_window_blocks):
+                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
+                 inter_data.orig_seq_lens, inter_data.seq_lens,
+                 inter_data.query_lens, inter_data.context_lens,
+                 inter_data.curr_sliding_window_blocks):
             self.context_lens.append(context_len)
             if is_prompt:
                 self.num_prefills += 1
@@ -275,7 +274,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
             # only allowing multiple of block_size chunk size.
             # NOTE: This only works for oooooooxxx style attention.
             block_table = []
-            if prefix_cache_hit:
+            if inter_data.prefix_cache_hit:
                 block_table = computed_block_nums
             elif ((chunked_prefill_enabled or not is_prompt)
                   and block_tables is not None):
@@ -290,8 +289,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                 self.use_v2_block_manager)
             compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                  seq_len, context_len, start_idx,
-                                 self.block_size,
-                                 seq_group_metadata.block_tables)
+                                 self.block_size, inter_data.block_tables)

             # It is not necessary to add paged_kv_indices, paged_kv_indptr,
             # and paged_kv_last_page_len for profile run because we will
@@ -317,9 +315,13 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                 last_page_len = self.block_size
             self.paged_kv_last_page_len.append(last_page_len)

-    def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
+    def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
-        device = runner.device
+        for inter_data in self.input_builder.inter_data_list:
+            self._add_seq_group(inter_data,
+                                self.input_builder.chunked_prefill_enabled)
+
+        device = self.runner.device
         use_captured_graph = cuda_graph_pad_size != -1

         max_query_len = max(query_lens)
@@ -333,7 +335,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,

         # The shape of graph_block_tables is
         # [max batch size, max context len // block size].
-        input_block_tables = runner.graph_block_tables[:batch_size]
+        input_block_tables = self.runner.graph_block_tables[:batch_size]
         for i, block_table in enumerate(self.block_tables):
             if block_table:
                 input_block_tables[i, :len(block_table)] = block_table
@@ -377,7 +379,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
                                              dtype=torch.long,
                                              device=device)

-        logits_soft_cap = getattr(runner.model_config.hf_config,
+        logits_soft_cap = getattr(self.runner.model_config.hf_config,
                                   "attn_logit_softcapping", None)

         if len(self.paged_kv_indptr) > 0:
@@ -394,8 +396,8 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
             paged_kv_indptr_tensor = None
             paged_kv_last_page_len_tensor = None

-        kv_cache_dtype = get_kv_cache_torch_dtype(runner.kv_cache_dtype,
-                                                  runner.model_config.dtype)
+        kv_cache_dtype = get_kv_cache_torch_dtype(
+            self.runner.kv_cache_dtype, self.runner.model_config.dtype)
         return FlashInferMetadata(
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
@@ -406,11 +408,11 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
             paged_kv_indptr=paged_kv_indptr_tensor,
             paged_kv_indices=paged_kv_indices_tensor,
             paged_kv_last_page_len=paged_kv_last_page_len_tensor,
-            num_qo_heads=runner.model_config.get_num_attention_heads(
-                runner.parallel_config),
-            num_kv_heads=runner.model_config.get_num_kv_heads(
-                runner.parallel_config),
-            head_dim=runner.model_config.get_head_size(),
+            num_qo_heads=self.runner.model_config.get_num_attention_heads(
+                self.runner.parallel_config),
+            num_kv_heads=self.runner.model_config.get_num_kv_heads(
+                self.runner.parallel_config),
+            head_dim=self.runner.model_config.get_head_size(),
             page_size=self.block_size,
             seq_start_loc=seq_start_loc,
             query_start_loc=query_start_loc,
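For orientation, only the tail of the paged_kv_* bookkeeping is visible in the hunks above (last_page_len = self.block_size handles the case where the final page is exactly full). FlashInfer's paged KV layout needs, per request, the flat list of physical page numbers, a prefix-sum pointer into that list, and the fill level of the last page. A hand-rolled sketch of that bookkeeping is below; it illustrates the layout and is not code copied from this file.

from typing import List, Tuple


def paged_kv_bookkeeping(
        block_tables: List[List[int]], seq_lens: List[int],
        page_size: int) -> Tuple[List[int], List[int], List[int]]:
    """Flatten per-request block tables into FlashInfer-style indices."""
    paged_kv_indices: List[int] = []
    paged_kv_indptr: List[int] = [0]
    paged_kv_last_page_len: List[int] = []
    for blocks, seq_len in zip(block_tables, seq_lens):
        paged_kv_indices.extend(blocks)
        paged_kv_indptr.append(paged_kv_indptr[-1] + len(blocks))
        # A completely full last page is reported as page_size, not 0,
        # matching the `last_page_len = self.block_size` branch above.
        paged_kv_last_page_len.append(seq_len % page_size or page_size)
    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len


# Two requests of lengths 20 and 16 with page_size 16:
print(paged_kv_bookkeeping([[7, 3], [9]], [20, 16], 16))
# -> ([7, 3, 9], [0, 2, 3], [4, 16])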

vllm/attention/backends/utils.py

Lines changed: 25 additions & 24 deletions

@@ -4,7 +4,6 @@
 import torch

 from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
-from vllm.sequence import SequenceGroupMetadata
 from vllm.utils import make_tensor_with_pad

 # Error string(s) for encoder/decoder
@@ -15,8 +14,7 @@
 PAD_SLOT_ID = -1

 if TYPE_CHECKING:
-    from vllm.worker.model_runner import (GPUModelRunnerBase,
-                                          ModelInputForGPUBuilder)
+    from vllm.worker.model_runner import ModelInputForGPUBuilder


 def is_block_tables_empty(block_tables: Union[None, Dict]):
@@ -95,26 +93,27 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0

+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
         self.sliding_window = input_builder.sliding_window
         self.block_size = input_builder.block_size
         self.use_v2_block_manager = (
             input_builder.scheduler_config.use_v2_block_manager)

-    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
-                      token_lens: List[int], seq_lens: List[int],
-                      curr_seq_lens: List[int], query_lens: List[int],
-                      context_lens: List[int],
-                      curr_sliding_window_blocks: List[int], prefix_cache_hit,
-                      chunked_prefill_enabled):
-        is_prompt = seq_group_metadata.is_prompt
-        block_tables = seq_group_metadata.block_tables
-        computed_block_nums = seq_group_metadata.computed_block_nums
+    def _add_seq_group(
+            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
+            chunked_prefill_enabled: bool):
+        is_prompt = inter_data.is_prompt
+        block_tables = inter_data.block_tables
+        computed_block_nums = inter_data.computed_block_nums

         for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
              curr_sliding_window_block) in zip(
-                 seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
-                 curr_seq_lens, query_lens, context_lens,
-                 curr_sliding_window_blocks):
+                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
+                 inter_data.orig_seq_lens, inter_data.seq_lens,
+                 inter_data.query_lens, inter_data.context_lens,
+                 inter_data.curr_sliding_window_blocks):
             self.context_lens.append(context_len)
             if is_prompt:
                 self.num_prefills += 1
@@ -132,7 +131,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
             # only allowing multiple of block_size chunk size.
             # NOTE: This only works for oooooooxxx style attention.
             block_table = []
-            if prefix_cache_hit:
+            if inter_data.prefix_cache_hit:
                 block_table = computed_block_nums
             elif ((chunked_prefill_enabled or not is_prompt)
                   and block_tables is not None):
@@ -146,16 +145,18 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                 self.use_v2_block_manager)
             compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                  seq_len, context_len, start_idx,
-                                 self.block_size,
-                                 seq_group_metadata.block_tables)
+                                 self.block_size, inter_data.block_tables)
+
+    def build(self, seq_lens: List[int], query_lens: List[int],
+              cuda_graph_pad_size: int, batch_size: int):
+        for inter_data in self.input_builder.inter_data_list:
+            self._add_seq_group(inter_data,
+                                self.input_builder.chunked_prefill_enabled)

-    def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int],
-              query_lens: List[int], cuda_graph_pad_size: int,
-              batch_size: int):
-        device = runner.device
+        device = self.runner.device
         use_captured_graph = cuda_graph_pad_size != -1

-        logits_soft_cap = getattr(runner.model_config.hf_config,
+        logits_soft_cap = getattr(self.runner.model_config.hf_config,
                                   "attn_logit_softcapping", None)
         if logits_soft_cap is not None:
             raise ValueError(
@@ -176,7 +177,7 @@ def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int],

         # The shape of graph_block_tables is
         # [max batch size, max context len // block size].
-        input_block_tables = runner.graph_block_tables[:batch_size]
+        input_block_tables = self.runner.graph_block_tables[:batch_size]
         for i, block_table in enumerate(self.block_tables):
             if block_table:
                 input_block_tables[i, :len(block_table)] = block_table
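In all three builders the CUDA-graph path ends the same way: the runner-owned graph_block_tables array has a fixed shape, it is sliced to the padded batch size, and each request's shorter, ragged block table is copied into its row, leaving the remaining entries zero. A small NumPy illustration of just that copy, with made-up shapes:

import numpy as np

# Preallocated once by the runner:
# [max batch size, max context len // block size].
graph_block_tables = np.zeros((8, 4), dtype=np.int32)

block_tables = [[5, 2], [7], []]   # per-request block tables, ragged
batch_size = 4                     # batch size after CUDA-graph padding

input_block_tables = graph_block_tables[:batch_size]
for i, block_table in enumerate(block_tables):
    if block_table:
        input_block_tables[i, :len(block_table)] = block_table

print(input_block_tables)
# [[5 2 0 0]
#  [7 0 0 0]
#  [0 0 0 0]
#  [0 0 0 0]]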

vllm/utils.py

Lines changed: 5 additions & 0 deletions

@@ -719,6 +719,11 @@ def merge_dicts(dict1: Dict[K, List[T]],
     return dict(merged_dict)


+def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
+    """Flatten a list of lists to a single list."""
+    return [item for sublist in lists for item in sublist]
+
+
 def init_cached_hf_modules() -> None:
     """
     Lazy initialization of the Hugging Face modules.
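The new helper is a plain list-comprehension flatten. A quick usage check (outputs shown in the comments):

from vllm.utils import flatten_2d_lists

print(flatten_2d_lists([[1, 2], [3], []]))    # [1, 2, 3]
print(flatten_2d_lists([["a"], ["b", "c"]]))  # ['a', 'b', 'c']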
