 import torch

 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker

 from .utils import (assert_logprobs_dict_allclose, create_batch,
-                    create_execute_model_data,
                     create_seq_group_metadata_from_prompts, create_worker,
                     patch_execute_model_with_seeds, zero_kv_cache)

@@ -105,31 +104,32 @@ def test_same_output_for_single_step():

     final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]

-    multi_step_execute_model_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
-
-    single_step_execute_model_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
+    multi_step_seq_group = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)

     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
     actual_output, _ = multi_step_worker.sampler_output(
-        **multi_step_execute_model_data.to_dict(), sample_len=num_steps)
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=multi_step_seq_group),
+        sample_len=num_steps)
     assert len(actual_output) == num_steps
     actual_output = actual_output[0]

+    single_step_seq_group = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
+
     zero_kv_cache(worker.cache_engine)
     set_random_seed(seed)
     expected_output = worker.execute_model(
-        **single_step_execute_model_data.to_dict(), )[0]
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=single_step_seq_group))[0]

     actual_token_ids = [
         output.samples[0].output_token for output in actual_output
@@ -193,19 +193,20 @@ def test_same_output_for_multi_step():
     worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)

     continuations = [[1] for _ in prompts]
-    execute_model_data = create_execute_model_data(
-        create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            continuations=continuations,
-            final_prompt_lens=final_prompt_lens), )
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=continuations,
+        final_prompt_lens=final_prompt_lens)

     # Run multi-step.
     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
     multi_step_output, _ = multi_step_worker.sampler_output(
-        **execute_model_data.to_dict(), sample_len=num_steps)
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list),
+        sample_len=num_steps)

     # Run single-step repeatedly.
     zero_kv_cache(worker.cache_engine)
@@ -215,16 +216,16 @@ def test_same_output_for_multi_step():

     for _ in multi_step_output:

-        execute_model_data = create_execute_model_data(
-            create_seq_group_metadata_from_prompts(
-                prompts,
-                num_gpu_blocks,
-                block_size,
-                continuations=continuations,
-                final_prompt_lens=final_prompt_lens))
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            continuations=continuations,
+            final_prompt_lens=final_prompt_lens)

         single_step_output.extend(
-            worker.execute_model(**execute_model_data.to_dict(), ))
+            worker.execute_model(execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list)))

         # Append output tokens to new sequence data.
         for i, seq_group_output in enumerate(single_step_output[-1]):
@@ -304,12 +305,11 @@ def test_draft_proposals_full_speculation_len():
         ) for _ in range(k)
     ], True

-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -340,14 +340,13 @@ def test_draft_proposals_no_speculations():
         max_proposal_len=prompt_len + k - 1,
     )

-    execute_model_data, _, _ = create_batch(batch_size,
-                                            k,
-                                            prompt_len=prompt_len)
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 prompt_len=prompt_len)

-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -409,17 +408,16 @@ def test_draft_proposals_mixed_k():
         ) for _ in range(k)
     ], True

-    execute_model_data, _, _ = create_batch(
+    seq_group_metadata_list, _, _ = create_batch(
         batch_size,
         k,
         prompt_len=prompt_len,
         prev_output_token_len=prev_output_token_len,
     )

-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)