# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from collections import defaultdict
from dataclasses import dataclass
from itertools import accumulate
-from typing import Dict, List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder)
from vllm.attention.backends.utils import CommonAttentionState
-from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d

# Placeholder attention backend for models like Mamba and pooling models that
@@ -141,8 +139,6 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
-           multi_modal_placeholder_index_maps=self.
-           multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
@@ -178,7 +174,6 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
-           multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=seq_lens_tensor,
@@ -210,9 +205,6 @@ def prepare(self):
210
205
self .prefill_seq_lens : List [int ] = []
211
206
self .context_lens : List [int ] = []
212
207
self .curr_seq_lens : List [int ] = []
213
- self .multimodal_placeholder_maps : Dict [
214
- str ,
215
- MultiModalPlaceholderMap ] = defaultdict (MultiModalPlaceholderMap )
216
208
self .num_prefills = 0
217
209
self .num_prefill_tokens = 0
218
210
self .num_decode_tokens = 0
@@ -232,12 +224,6 @@ def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool):
            self.context_lens.append(context_len)

            if is_prompt:
-               mm_maps = inter_data.multi_modal_placeholder_maps
-               if mm_maps:
-                   for modality, placeholders in mm_maps.items():
-                       self.multimodal_placeholder_maps[modality].extend(
-                           placeholders)
-
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
@@ -295,20 +281,13 @@ def build(self, seq_lens: List[int], query_lens: List[int],
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)

-       placeholder_index_maps = {
-           modality: placeholder_map.index_map()
-           for modality, placeholder_map in
-           self.multimodal_placeholder_maps.items()
-       }
-
        # Placeholders
        slot_mapping_tensor = torch.empty(0)
        block_tables = torch.empty(0)

        return PlaceholderAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
-           multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=True,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,