Commit c9b38be

[Spec Decode] Make propose_draft_token_ids non-blocking for lower TTFT (#23041)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 0dd3f4f commit c9b38be
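
What this change does: previously the draft (speculative) token ids were returned inside ModelRunnerOutput, so the scheduler's update_from_output had to wait for the drafter before sampled tokens could be streamed out. This commit moves drafts into a separate DraftTokenIds payload: the engine core takes it from the executor in a new post_step() hook and applies it via scheduler.update_draft_token_ids(), so sampled tokens reach clients without blocking on draft proposal, lowering time-to-first-token (TTFT).

A minimal sketch of the DraftTokenIds container, inferred from its usage in the hunks below; the actual definition lives in vllm/v1/outputs.py, which is not among the hunks shown here:

from dataclasses import dataclass

@dataclass
class DraftTokenIds:
    # req_ids[i] owns draft_token_ids[i] (field names taken from the
    # scheduler code below; exact types are an assumption).
    req_ids: list[str]
    draft_token_ids: list[list[int]]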

13 files changed: +100 −64 lines changed

tests/v1/core/test_async_scheduler.py

Lines changed: 0 additions & 1 deletion
@@ -22,7 +22,6 @@ def _make_model_runner_output(
         for i, req_id in enumerate(req_ids)
     },
     sampled_token_ids=[[i] for i in range(len(req_ids))],
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],

tests/v1/core/test_scheduler.py

Lines changed: 3 additions & 23 deletions
@@ -14,7 +14,7 @@
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec)
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.v1.structured_output.request import StructuredOutputRequest
@@ -158,7 +158,6 @@ def test_schedule_partial_requests():
     # Only the first request has a sampled token id because
     # the rest requests are still being prefilled.
     sampled_token_ids=[[0], [], []],
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
@@ -209,7 +208,6 @@ def test_no_mm_input_chunking():
     req_ids=[request.request_id for request in requests],
     req_id_to_index=req_to_index,
     sampled_token_ids=[[] for _ in range(len(requests))],
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
@@ -273,7 +271,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
     req_ids=[request.request_id for request in requests],
     req_id_to_index=req_to_index,
     sampled_token_ids=[[] for _ in range(len(requests))],
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
@@ -298,7 +295,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
     req_ids=[request.request_id for request in requests],
     req_id_to_index=req_to_index,
     sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
@@ -355,7 +351,6 @@ def test_stop_via_update_from_output():
     sampled_token_ids=[[EOS_TOKEN_ID],
                        [10,
                         11]],  # First request hits EOS, second continues
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[])
@@ -409,7 +404,6 @@ def test_stop_via_update_from_output():
     },
     sampled_token_ids=[[10, 42, 12],
                        [13, 14]],  # First request hits stop token
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[])
@@ -462,7 +456,6 @@ def test_stop_via_update_from_output():
     },
     sampled_token_ids=[[10, 11, 12],
                        [13]],  # First request exceeds max_tokens
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[])
@@ -505,7 +498,6 @@ def test_stop_via_update_from_output():
     req_ids=[requests[0].request_id],
     req_id_to_index={requests[0].request_id: 0},
     sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[])
@@ -554,7 +546,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
     req_ids=[requests[0].request_id],
     req_id_to_index={requests[0].request_id: 0},
     sampled_token_ids=[[0]],
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
@@ -572,7 +563,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
     req_ids=[requests[1].request_id],
     req_id_to_index={requests[1].request_id: 0},
     sampled_token_ids=[[0]],
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
@@ -608,7 +598,6 @@ def test_preempt_during_execution():
     req_ids=[requests[0].request_id],
     req_id_to_index={requests[0].request_id: 0},
     sampled_token_ids=[[0]],
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
@@ -626,7 +615,6 @@ def test_preempt_during_execution():
     req_ids=[requests[1].request_id],
     req_id_to_index={requests[1].request_id: 0},
     sampled_token_ids=[[42]],
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
@@ -682,13 +670,14 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
     req_ids=req_ids,
     req_id_to_index=req_to_index,
     sampled_token_ids=[[0] for _ in range(len(requests))],
-    spec_token_ids=spec_tokens,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
 )
 engine_core_outputs = scheduler.update_from_output(output,
                                                    model_runner_output)
+draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
+scheduler.update_draft_token_ids(draft_token_ids)
 
 for i in range(len(requests)):
     running_req = scheduler.running[i]
@@ -722,7 +711,6 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
     req_ids=req_ids,
     req_id_to_index=req_to_index,
     sampled_token_ids=output_tokens,
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
@@ -851,7 +839,6 @@ def test_kv_connector_basic():
     req_ids=req_ids,
     req_id_to_index=req_to_index,
     sampled_token_ids=[[1000]] * len(req_ids),
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
@@ -898,7 +885,6 @@ def test_kv_connector_basic():
     req_ids=req_ids,
     req_id_to_index=req_to_index,
     sampled_token_ids=[[1000]] * len(req_ids),
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
@@ -966,7 +952,6 @@ def test_kv_connector_unable_to_allocate():
     req_ids=req_ids,
     req_id_to_index=req_to_index,
     sampled_token_ids=[[1000]] * len(req_ids),
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
@@ -1048,7 +1033,6 @@ def test_kv_connector_handles_preemption():
     req_ids=req_ids,
     req_id_to_index=req_to_index,
     sampled_token_ids=[[1000]] * len(req_ids),
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
@@ -1142,7 +1126,6 @@ def make_output(scheduler: Scheduler):
         for i, req in enumerate(scheduler.running)
     },
     sampled_token_ids=[[1000]] * len(scheduler.running),
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
@@ -1468,7 +1451,6 @@ def test_priority_scheduling_preemption():
         for i, req in enumerate(low_priority_requests)
     },
     sampled_token_ids=[[100] for _ in low_priority_requests],
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
@@ -1541,7 +1523,6 @@ def test_priority_scheduling_no_preemption_when_space_available():
         for i, req in enumerate(low_priority_requests)
    },
     sampled_token_ids=[[100] for _ in low_priority_requests],
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
@@ -1783,7 +1764,6 @@ def test_priority_scheduling_heap_property():
     req_ids=[req.req_id],
     req_id_to_index={req.req_id: 0},
     sampled_token_ids=[[100]],
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=[],
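
The spec-decoding test above shows the new calling convention: ModelRunnerOutput no longer carries a spec_token_ids field, and drafts are applied in a separate call after update_from_output. A condensed sketch reusing the names from the test:

model_runner_output = ModelRunnerOutput(
    req_ids=req_ids,
    req_id_to_index=req_to_index,
    sampled_token_ids=[[0] for _ in range(len(requests))],
    logprobs=None,
    prompt_logprobs_dict={},
    pooler_output=[],
)  # no spec_token_ids argument anymore
scheduler.update_from_output(output, model_runner_output)
# Drafts arrive out of band, after the step:
scheduler.update_draft_token_ids(DraftTokenIds(req_ids, spec_tokens))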

tests/v1/kv_connector/unit/utils.py

Lines changed: 0 additions & 1 deletion
@@ -200,7 +200,6 @@ def create_model_runner_output(
     req_ids=req_ids,
     req_id_to_index=req_id_to_index,
     sampled_token_ids=sampled_token_ids,
-    spec_token_ids=None,
     logprobs=None,
     prompt_logprobs_dict={},
     pooler_output=None,

vllm/v1/core/sched/interface.py

Lines changed: 9 additions & 1 deletion
@@ -9,7 +9,7 @@
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.engine import EngineCoreOutputs
 from vllm.v1.metrics.stats import SchedulerStats
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 
 
@@ -61,6 +61,14 @@ def update_from_output(
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def update_draft_token_ids(
+        self,
+        draft_token_ids: "DraftTokenIds",
+    ) -> None:
+        """Update the draft token ids for the scheduled requests."""
+        raise NotImplementedError
+
     @abstractmethod
     def add_request(self, request: "Request") -> None:
         """Add a new request to the scheduler's internal queue.

vllm/v1/core/sched/scheduler.py

Lines changed: 25 additions & 14 deletions
@@ -30,7 +30,7 @@
 EngineCoreOutputs)
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import SchedulerStats
-from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
+from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager
@@ -141,7 +141,6 @@ def __init__(
             cache_size=encoder_cache_size)
 
         speculative_config = vllm_config.speculative_config
-
         self.use_eagle = False
         self.num_spec_tokens = self.num_lookahead_tokens = 0
         if speculative_config:
@@ -760,7 +759,6 @@ def update_from_output(
         model_runner_output: ModelRunnerOutput,
     ) -> dict[int, EngineCoreOutputs]:
         sampled_token_ids = model_runner_output.sampled_token_ids
-        spec_token_ids = model_runner_output.spec_token_ids
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
@@ -845,20 +843,9 @@ def update_from_output(
                 request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
                     req_id, new_token_ids)
 
-            # spec_token_ids comes from the model runner output
             if num_nans_in_logits is not None and req_id in num_nans_in_logits:
                 request.num_nans_in_logits = num_nans_in_logits[req_id]
 
-            # Add newly generated spec token ids to the request.
-            if spec_token_ids is not None:
-                if self.structured_output_manager.should_advance(request):
-                    metadata = request.structured_output_request
-                    # Needs to happen after new_token_ids are accepted.
-                    request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
-                        spec_token_ids[req_index])
-                else:
-                    request.spec_token_ids = spec_token_ids[req_index]
-
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
             if new_token_ids or pooler_output is not None \
@@ -963,6 +950,30 @@ def _free_encoder_inputs(self, request: Request) -> None:
             self.encoder_cache_manager.free_encoder_input(
                 request, input_id)
 
+    def update_draft_token_ids(
+        self,
+        draft_token_ids: DraftTokenIds,
+    ) -> None:
+        for req_id, spec_token_ids in zip(
+                draft_token_ids.req_ids,
+                draft_token_ids.draft_token_ids,
+        ):
+            request = self.requests.get(req_id)
+            if request is None or request.is_finished():
+                # The request may have been finished. Skip.
+                continue
+
+            # Add newly generated spec token ids to the request.
+            if not spec_token_ids:
+                # NOTE(woosuk): request.spec_token_ids should be updated.
+                request.spec_token_ids.clear()
+            elif self.structured_output_manager.should_advance(request):
+                metadata = request.structured_output_request
+                request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
+                    spec_token_ids)
+            else:
+                request.spec_token_ids = spec_token_ids
+
     def get_request_counts(self) -> tuple[int, int]:
         """Returns (num_running_reqs, num_waiting_reqs)."""
         return len(self.running), len(self.waiting)
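
One subtlety in update_draft_token_ids above: because drafts now arrive out of band, a request can still hold the previous step's proposal when the new proposal is empty, hence the explicit clear(). An illustrative timeline (token values are made up):

# step N:   drafter proposes [5, 7] -> request.spec_token_ids = [5, 7]
# step N+1: drafter proposes []     -> request.spec_token_ids.clear()
# Without the clear(), the next schedule would re-speculate on the
# stale [5, 7] from step N.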

vllm/v1/engine/core.py

Lines changed: 10 additions & 0 deletions
@@ -126,6 +126,7 @@ def __init__(self,
             > 1,
             log_stats=self.log_stats,
         )
+        self.use_spec_decode = vllm_config.speculative_config is not None
 
         self.mm_input_cache_server = MultiModalInputCacheServer(
             vllm_config.model_config, MULTIMODAL_REGISTRY)
@@ -294,6 +295,13 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
         return (engine_core_outputs,
                 scheduler_output.total_num_scheduled_tokens > 0)
 
+    def post_step(self, model_executed: bool) -> None:
+        if self.use_spec_decode and model_executed:
+            # Take the draft token ids.
+            draft_token_ids = self.model_executor.take_draft_token_ids()
+            if draft_token_ids is not None:
+                self.scheduler.update_draft_token_ids(draft_token_ids)
+
     def step_with_batch_queue(
             self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
         """Schedule and execute batches with the batch queue.
@@ -746,6 +754,8 @@ def _process_engine_step(self) -> bool:
         # Put EngineCoreOutputs into the output queue.
         for output in (outputs.items() if outputs else ()):
             self.output_queue.put_nowait(output)
+        # Post-step hook.
+        self.post_step(model_executed)
 
         return model_executed
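
Putting the engine-core pieces together: step() now returns sampled tokens without waiting on the drafter, and post_step() drains the drafts for the next iteration. A minimal sketch of the control flow, assuming an EngineCore instance named core (send_to_client is a hypothetical helper standing in for the output queue):

outputs, model_executed = core.step()  # sampled tokens, no wait on drafter
for client_id, engine_core_outputs in (outputs.items() if outputs else ()):
    send_to_client(client_id, engine_core_outputs)  # stream out immediately
core.post_step(model_executed)  # only now take and apply draft token ids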

vllm/v1/executor/abstract.py

Lines changed: 6 additions & 2 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from concurrent.futures import Future
-from typing import Callable, Union
+from typing import Callable, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -14,7 +14,7 @@
 from vllm.executor.uniproc_executor import (  # noqa
     UniProcExecutor as UniProcExecutorV0)
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
 FailureCallback = Callable[[], None]
 
@@ -88,6 +88,10 @@ def execute_model(
                                     args=(scheduler_output, ))
         return output[0]
 
+    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
+        output = self.collective_rpc("take_draft_token_ids")
+        return output[0]
+
     @property
     def max_concurrent_batches(self) -> int:
         return 1
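
collective_rpc("take_draft_token_ids") dispatches to a worker method of the same name, which is not part of the hunks shown here. A hypothetical sketch of the worker side, assuming take-and-reset semantics so each proposal is consumed exactly once (the class and attribute names are invented for illustration):

from typing import Optional

class Worker:  # hypothetical stand-in for the real worker class
    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        # Hand over the most recent proposal exactly once; post_step()
        # treats None as "no new drafts this step".
        drafts, self._draft_token_ids = self._draft_token_ids, None
        return drafts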

vllm/v1/executor/multiproc_executor.py

Lines changed: 7 additions & 1 deletion
@@ -33,7 +33,7 @@
     get_loopback_ip, get_mp_context, get_open_port,
     set_process_title)
 from vllm.v1.executor.abstract import Executor, FailureCallback
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -191,6 +191,12 @@ def execute_model(
                 outputs, self.output_rank)
         return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
 
+    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
+        # OPTIMIZATION: Get output only from a single worker (output_rank)
+        outputs = self.collective_rpc("take_draft_token_ids",
+                                      unique_reply_rank=self.output_rank)
+        return outputs[0]
+
     def collective_rpc(self,
                        method: Union[str, Callable],
                        timeout: Optional[float] = None,
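
Design note: per the OPTIMIZATION comment above, only one worker's reply is needed, so replying from output_rank alone via unique_reply_rank avoids serializing and shipping draft payloads from every other rank, presumably because the drafts are identical (or only meaningful) on that rank.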
