
Commit 8e5314a

[V1] Add disable_chunked_mm_input arg to disable partial mm input prefill (#15837)
Signed-off-by: mgoin <[email protected]>
1 parent 87918e4 commit 8e5314a

File tree

4 files changed: +80 -0 lines changed

  tests/v1/core/test_scheduler.py
  vllm/config.py
  vllm/engine/arg_utils.py
  vllm/v1/core/sched/scheduler.py

tests/v1/core/test_scheduler.py

Lines changed: 45 additions & 0 deletions
@@ -24,6 +24,7 @@ def create_scheduler(
    max_num_batched_tokens: int = 8192,
    enable_prefix_caching: Optional[bool] = None,
    long_prefill_token_threshold: int = 0,
+   disable_chunked_mm_input: bool = False,
) -> Scheduler:
    '''Create scheduler under test.

@@ -43,6 +44,7 @@ def create_scheduler(
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_num_batched_tokens,
        long_prefill_token_threshold=long_prefill_token_threshold,
+       disable_chunked_mm_input=disable_chunked_mm_input,
    )
    model_config = ModelConfig(
        model=model,

@@ -278,6 +280,49 @@ def test_schedule_partial_requests():
    assert requests[2].request_id not in output.num_scheduled_tokens


+def test_no_mm_input_chunking():
+    # Disable multimodal input chunking.
+    scheduler = create_scheduler(
+        model="llava-hf/llava-1.5-7b-hf",
+        max_num_batched_tokens=1024,
+        disable_chunked_mm_input=True,
+    )
+    mm_positions = [[PlaceholderRange(offset=400, length=800)]]
+    requests = create_requests(num_requests=1,
+                               num_tokens=1200,
+                               mm_positions=mm_positions)
+    for request in requests:
+        scheduler.add_request(request)
+
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 1
+    assert len(output.scheduled_cached_reqs) == 0
+    assert len(output.finished_req_ids) == 0
+    # We want to only see the 400 text tokens at the start scheduled
+    assert output.num_scheduled_tokens[requests[0].request_id] == 400
+
+    req_to_index = {
+        request.request_id: i
+        for i, request in enumerate(requests)
+    }
+    model_runner_output = ModelRunnerOutput(
+        req_ids=[request.request_id for request in requests],
+        req_id_to_index=req_to_index,
+        sampled_token_ids=[[] for _ in range(len(requests))],
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    scheduler.update_from_output(output, model_runner_output)
+
+    output = scheduler.schedule()
+    assert len(scheduler.running) == 1
+    assert len(output.scheduled_new_reqs) == 0
+    assert len(output.scheduled_cached_reqs) == 1
+    assert len(output.finished_req_ids) == 0
+    assert output.num_scheduled_tokens[requests[0].request_id] == 800
+
+
@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
    """Test scheduling behavior with concurrent partial requests.

vllm/config.py

Lines changed: 8 additions & 0 deletions
@@ -1721,6 +1721,14 @@ class SchedulerConfig:

    chunked_prefill_enabled: bool = field(init=False)

+   # If set to true and chunked prefill is enabled, we do not want to
+   # partially schedule a multimodal item. Only used in V1.
+   # This ensures that if a request has a mixed prompt
+   # (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
+   # some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
+   # it will be scheduled as TTTT in one step and IIIIIIIIII in the next.
+   disable_chunked_mm_input: bool = False
+
    # scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
    # or "mod.custom_class".
    scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler"
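For orientation (not part of the commit), the new field rides on SchedulerConfig like the other scheduling knobs, so a unit-test-style construction can set it directly. A minimal sketch, assuming the remaining SchedulerConfig fields keep workable defaults; the test's create_scheduler helper above passes the flag the same way:

# Rough sketch only: other SchedulerConfig fields are assumed to default sanely.
from vllm.config import SchedulerConfig

scheduler_config = SchedulerConfig(
    max_num_batched_tokens=1024,
    max_model_len=1024,
    long_prefill_token_threshold=0,
    disable_chunked_mm_input=True,
)
assert scheduler_config.disable_chunked_mm_input is True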

vllm/engine/arg_utils.py

Lines changed: 16 additions & 0 deletions
@@ -179,6 +179,7 @@ class EngineArgs:

    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: Optional[bool] = None
+   disable_chunked_mm_input: bool = False

    guided_decoding_backend: str = 'xgrammar'
    logits_processor_pattern: Optional[str] = None

@@ -1017,6 +1018,20 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
            "Note that even if this is set to False, cascade attention will be "
            "only used when the heuristic tells that it's beneficial.")

+       parser.add_argument(
+           "--disable-chunked-mm-input",
+           action=StoreBoolean,
+           default=EngineArgs.disable_chunked_mm_input,
+           nargs="?",
+           const="False",
+           help="Disable multimodal input chunking attention for V1. "
+           "If set to true and chunked prefill is enabled, we do not want to"
+           " partially schedule a multimodal item. This ensures that if a "
+           "request has a mixed prompt (like text tokens TTTT followed by "
+           "image tokens IIIIIIIIII) where only some image tokens can be "
+           "scheduled (like TTTTIIIII, leaving IIIII), it will be scheduled "
+           "as TTTT in one step and IIIIIIIIII in the next.")
+
        return parser

    @classmethod

@@ -1261,6 +1276,7 @@ def create_engine_config(
            num_lookahead_slots=num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
+           disable_chunked_mm_input=self.disable_chunked_mm_input,
            is_multimodal_model=model_config.is_multimodal_model,
            preemption_mode=self.preemption_mode,
            num_scheduler_steps=self.num_scheduler_steps,
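A hedged end-to-end sketch (not part of the commit) of exercising the new flag through the parser wiring above; it assumes FlexibleArgumentParser and EngineArgs.from_cli_args behave as in current vLLM, and passes an explicit boolean value to the StoreBoolean action rather than relying on the bare-flag const:

# Sketch only: drives the new CLI flag through EngineArgs.
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args([
    "--model", "llava-hf/llava-1.5-7b-hf",
    "--disable-chunked-mm-input", "true",  # explicit value for StoreBoolean
])
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.disable_chunked_mm_input is True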

vllm/v1/core/sched/scheduler.py

Lines changed: 11 additions & 0 deletions
@@ -522,6 +522,17 @@ def _try_schedule_encoder_inputs(
            if self.encoder_cache_manager.has_cache(request, i):
                # The encoder input is already computed and cached.
                continue
+
+           # If no encoder input chunking is allowed, we do not want to
+           # partially schedule a multimodal item. If the scheduled range would
+           # only cover part of the mm input, roll back to before the mm item.
+           if (self.scheduler_config.disable_chunked_mm_input
+                   and num_computed_tokens < start_pos
+                   and (num_computed_tokens + num_new_tokens)
+                   < (start_pos + num_encoder_tokens)):
+               num_new_tokens = start_pos - num_computed_tokens
+               break
+
            if (not self.encoder_cache_manager.can_allocate(request, i)
                    or num_encoder_tokens > encoder_budget):
                # The encoder cache is full or the encoder budget is exhausted.
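To make the rollback condition concrete, here is a standalone restatement (illustrative only, not vLLM code) of the arithmetic in _try_schedule_encoder_inputs, checked against the numbers used by the new test:

# Standalone illustration of the rollback above: if the scheduled window would
# end inside a multimodal item, stop right before the item so it can be
# scheduled whole in a later step.
def clamp_to_mm_boundary(num_computed_tokens: int, num_new_tokens: int,
                         start_pos: int, num_encoder_tokens: int) -> int:
    scheduled_end = num_computed_tokens + num_new_tokens
    item_end = start_pos + num_encoder_tokens
    if num_computed_tokens < start_pos and scheduled_end < item_end:
        # The window starts before the mm item but would split it: roll back.
        return start_pos - num_computed_tokens
    return num_new_tokens

# Numbers from test_no_mm_input_chunking: 400 text tokens, an 800-token image,
# and a 1024-token step budget.
assert clamp_to_mm_boundary(0, 1024, 400, 800) == 400   # step 1: text only
assert clamp_to_mm_boundary(400, 800, 400, 800) == 800  # step 2: whole image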
