Commit bc8ad68
[Misc][Refactor] Introduce ExecuteModelRequest (#4540)
1 parent 344bf7c · commit bc8ad68

23 files changed: +359 −515 lines
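
This refactor replaces the test-only execute_model_data containers (built with create_execute_model_data and splatted into worker calls via **data.to_dict()) with a single ExecuteModelRequest object passed as one keyword argument. Below is a minimal sketch of the new call shape, using only names that appear in the diffs that follow; the placeholder values are illustrative, not part of the commit.

from vllm.sequence import ExecuteModelRequest

# Placeholder: in the tests this list comes from
# create_seq_group_metadata_from_prompts(...).
seq_group_metadata_list = []
proposal_len = 5

# Old shape (removed by this commit):
#     proposals = proposer.get_proposals(
#         **execute_model_data.to_dict(), proposal_len=proposal_len)

# New shape: one typed request carries the batch metadata, and the
# speculation length rides along as num_lookahead_slots.
request = ExecuteModelRequest(
    seq_group_metadata_list=seq_group_metadata_list,
    num_lookahead_slots=proposal_len)
# proposals = proposer.get_proposals(execute_model_req=request)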

tests/spec_decode/test_multi_step_worker.py

Lines changed: 48 additions & 50 deletions
@@ -5,13 +5,12 @@
 import torch
 
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
 
 from .utils import (assert_logprobs_dict_allclose, create_batch,
-                    create_execute_model_data,
                     create_seq_group_metadata_from_prompts, create_worker,
                     patch_execute_model_with_seeds, zero_kv_cache)

@@ -105,31 +104,32 @@ def test_same_output_for_single_step():
 
     final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
 
-    multi_step_execute_model_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
-
-    single_step_execute_model_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
+    multi_step_seq_group = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
 
     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
     actual_output, _ = multi_step_worker.sampler_output(
-        **multi_step_execute_model_data.to_dict(), sample_len=num_steps)
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=multi_step_seq_group),
+        sample_len=num_steps)
     assert len(actual_output) == num_steps
     actual_output = actual_output[0]
 
+    single_step_seq_group = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
+
     zero_kv_cache(worker.cache_engine)
     set_random_seed(seed)
     expected_output = worker.execute_model(
-        **single_step_execute_model_data.to_dict(), )[0]
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=single_step_seq_group))[0]
 
     actual_token_ids = [
         output.samples[0].output_token for output in actual_output
@@ -193,19 +193,20 @@ def test_same_output_for_multi_step():
     worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
 
     continuations = [[1] for _ in prompts]
-    execute_model_data = create_execute_model_data(
-        create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            continuations=continuations,
-            final_prompt_lens=final_prompt_lens), )
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=continuations,
+        final_prompt_lens=final_prompt_lens)
 
     # Run multi-step.
     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
     multi_step_output, _ = multi_step_worker.sampler_output(
-        **execute_model_data.to_dict(), sample_len=num_steps)
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list),
+        sample_len=num_steps)
 
     # Run single-step repeatedly.
     zero_kv_cache(worker.cache_engine)
@@ -215,16 +216,16 @@ def test_same_output_for_multi_step():
 
     for _ in multi_step_output:
 
-        execute_model_data = create_execute_model_data(
-            create_seq_group_metadata_from_prompts(
-                prompts,
-                num_gpu_blocks,
-                block_size,
-                continuations=continuations,
-                final_prompt_lens=final_prompt_lens))
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            continuations=continuations,
+            final_prompt_lens=final_prompt_lens)
 
         single_step_output.extend(
-            worker.execute_model(**execute_model_data.to_dict(), ))
+            worker.execute_model(execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list)))
 
         # Append output tokens to new sequence data.
         for i, seq_group_output in enumerate(single_step_output[-1]):
@@ -304,12 +305,11 @@ def test_draft_proposals_full_speculation_len():
         ) for _ in range(k)
     ], True
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -340,14 +340,13 @@ def test_draft_proposals_no_speculations():
         max_proposal_len=prompt_len + k - 1,
     )
 
-    execute_model_data, _, _ = create_batch(batch_size,
-                                            k,
-                                            prompt_len=prompt_len)
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 prompt_len=prompt_len)
 
-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -409,17 +408,16 @@ def test_draft_proposals_mixed_k():
         ) for _ in range(k)
     ], True
 
-    execute_model_data, _, _ = create_batch(
+    seq_group_metadata_list, _, _ = create_batch(
        batch_size,
        k,
        prompt_len=prompt_len,
        prev_output_token_len=prev_output_token_len,
    )
 
-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
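
The worker entry points above follow the same pattern. As a rough sketch (assuming worker objects and metadata built by the test helpers, as in the diffs), both now take the request as a single keyword:

from vllm.sequence import ExecuteModelRequest

def run_both_paths(worker, multi_step_worker, seq_group_metadata_list,
                   num_steps):
    # Sketch only: both entry points accept the same request type;
    # sampler_output additionally takes sample_len for multi-step runs.
    request = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list)
    single_step = worker.execute_model(execute_model_req=request)
    multi_step, _ = multi_step_worker.sampler_output(
        execute_model_req=request, sample_len=num_steps)
    return single_step, multi_step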

tests/spec_decode/test_ngram_worker.py

Lines changed: 29 additions & 35 deletions
@@ -1,10 +1,10 @@
 import torch
 
+from vllm.sequence import ExecuteModelRequest
 from vllm.spec_decode.ngram_worker import NGramWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 
-from .utils import (create_execute_model_data,
-                    create_seq_group_metadata_from_prompts, create_worker)
+from .utils import create_seq_group_metadata_from_prompts, create_worker
 
 
 def test_ngram_algo_correctness_for_single_no_match():
@@ -44,17 +44,15 @@ def test_ngram_algo_correctness_for_single_no_match():
 
     proposal_len = 5
     final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
-    ngram_sampler_output_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
-
-    proposals = proposer.get_proposals(
-        **ngram_sampler_output_data.to_dict(),
-        proposal_len=proposal_len,
-    )
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
+
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=proposal_len), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -113,17 +111,15 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
 
     proposal_len = 5
     final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
-    ngram_sampler_output_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
-
-    proposals = proposer.get_proposals(
-        **ngram_sampler_output_data.to_dict(),
-        proposal_len=proposal_len,
-    )
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
+
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=proposal_len), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -185,17 +181,15 @@ def test_ngram_algo_correctness_for_batches_match_all():
 
     proposal_len = 5
     final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
-    ngram_sampler_output_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
-
-    proposals = proposer.get_proposals(
-        **ngram_sampler_output_data.to_dict(),
-        proposal_len=proposal_len,
-    )
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
+
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=proposal_len), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
