@@ -14,7 +14,7 @@
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec)
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.v1.structured_output.request import StructuredOutputRequest
@@ -158,7 +158,6 @@ def test_schedule_partial_requests():
         # Only the first request has a sampled token id because
         # the rest requests are still being prefilled.
         sampled_token_ids=[[0], [], []],
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -209,7 +208,6 @@ def test_no_mm_input_chunking():
         req_ids=[request.request_id for request in requests],
         req_id_to_index=req_to_index,
         sampled_token_ids=[[] for _ in range(len(requests))],
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -273,7 +271,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
         req_ids=[request.request_id for request in requests],
         req_id_to_index=req_to_index,
         sampled_token_ids=[[] for _ in range(len(requests))],
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -298,7 +295,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
         req_ids=[request.request_id for request in requests],
         req_id_to_index=req_to_index,
         sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -355,7 +351,6 @@ def test_stop_via_update_from_output():
         sampled_token_ids=[[EOS_TOKEN_ID],
                            [10,
                             11]],  # First request hits EOS, second continues
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[])
@@ -409,7 +404,6 @@ def test_stop_via_update_from_output():
         },
         sampled_token_ids=[[10, 42, 12],
                            [13, 14]],  # First request hits stop token
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[])
@@ -462,7 +456,6 @@ def test_stop_via_update_from_output():
         },
         sampled_token_ids=[[10, 11, 12],
                            [13]],  # First request exceeds max_tokens
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[])
@@ -505,7 +498,6 @@ def test_stop_via_update_from_output():
         req_ids=[requests[0].request_id],
         req_id_to_index={requests[0].request_id: 0},
         sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[])
@@ -554,7 +546,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
         req_ids=[requests[0].request_id],
         req_id_to_index={requests[0].request_id: 0},
         sampled_token_ids=[[0]],
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -572,7 +563,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
         req_ids=[requests[1].request_id],
         req_id_to_index={requests[1].request_id: 0},
         sampled_token_ids=[[0]],
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -608,7 +598,6 @@ def test_preempt_during_execution():
         req_ids=[requests[0].request_id],
         req_id_to_index={requests[0].request_id: 0},
         sampled_token_ids=[[0]],
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -626,7 +615,6 @@ def test_preempt_during_execution():
         req_ids=[requests[1].request_id],
         req_id_to_index={requests[1].request_id: 0},
         sampled_token_ids=[[42]],
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -682,13 +670,14 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
         req_ids=req_ids,
         req_id_to_index=req_to_index,
         sampled_token_ids=[[0] for _ in range(len(requests))],
-        spec_token_ids=spec_tokens,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
     )
     engine_core_outputs = scheduler.update_from_output(output,
                                                        model_runner_output)
+    draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
+    scheduler.update_draft_token_ids(draft_token_ids)
 
     for i in range(len(requests)):
         running_req = scheduler.running[i]
@@ -722,7 +711,6 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
         req_ids=req_ids,
         req_id_to_index=req_to_index,
         sampled_token_ids=output_tokens,
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -851,7 +839,6 @@ def test_kv_connector_basic():
         req_ids=req_ids,
         req_id_to_index=req_to_index,
         sampled_token_ids=[[1000]] * len(req_ids),
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -898,7 +885,6 @@ def test_kv_connector_basic():
         req_ids=req_ids,
         req_id_to_index=req_to_index,
         sampled_token_ids=[[1000]] * len(req_ids),
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -966,7 +952,6 @@ def test_kv_connector_unable_to_allocate():
         req_ids=req_ids,
         req_id_to_index=req_to_index,
         sampled_token_ids=[[1000]] * len(req_ids),
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -1048,7 +1033,6 @@ def test_kv_connector_handles_preemption():
         req_ids=req_ids,
         req_id_to_index=req_to_index,
         sampled_token_ids=[[1000]] * len(req_ids),
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -1142,7 +1126,6 @@ def make_output(scheduler: Scheduler):
             for i, req in enumerate(scheduler.running)
         },
         sampled_token_ids=[[1000]] * len(scheduler.running),
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -1468,7 +1451,6 @@ def test_priority_scheduling_preemption():
             for i, req in enumerate(low_priority_requests)
         },
         sampled_token_ids=[[100] for _ in low_priority_requests],
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -1541,7 +1523,6 @@ def test_priority_scheduling_no_preemption_when_space_available():
             for i, req in enumerate(low_priority_requests)
         },
         sampled_token_ids=[[100] for _ in low_priority_requests],
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
@@ -1783,7 +1764,6 @@ def test_priority_scheduling_heap_property():
         req_ids=[req.req_id],
         req_id_to_index={req.req_id: 0},
         sampled_token_ids=[[100]],
-        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
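Taken together, these hunks adapt the tests to an interface change: ModelRunnerOutput no longer carries a spec_token_ids field, and speculative draft tokens are instead handed to the scheduler through a separate DraftTokenIds object via Scheduler.update_draft_token_ids(). The sketch below illustrates that flow using only the calls and keyword arguments visible in this diff; the report_step() wrapper and its parameter names are hypothetical, added purely for illustration.

# Minimal sketch of the updated flow, assuming the constructor arguments
# shown in the diff above. report_step() and its parameters are hypothetical.
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput


def report_step(scheduler, scheduler_output, req_ids, req_to_index,
                sampled_token_ids, spec_tokens):
    # Per-step model output no longer includes speculative tokens.
    model_runner_output = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=sampled_token_ids,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)
    # Draft tokens are now reported to the scheduler on a separate path.
    scheduler.update_draft_token_ids(DraftTokenIds(req_ids, spec_tokens))
    return engine_core_outputs

In the updated test_schedule_spec_decoding_stats, this pairing of update_from_output followed by update_draft_token_ids replaces the old spec_token_ids keyword argument.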