Commit cd3aa15

Fix broken worker test (#1900)

1 parent 9b29497 commit cd3aa15

1 file changed: +14 -9 lines changed

tests/worker/test_worker.py renamed to tests/worker/test_model_runner.py

Lines changed: 14 additions & 9 deletions
@@ -2,18 +2,19 @@
 import torch
 
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.worker.worker import Worker
+from vllm.worker.model_runner import ModelRunner
 
 
-def test_worker_prepare_inputs_for_prompt():
-    worker = Worker(None, None, None)
-    worker.block_size = 16
+def test_prepare_prompt():
+    model_runner = ModelRunner(None, None, None)
+    model_runner.set_block_size(16)
+
     batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
     for i in range(batch_size):
         # make sure all tokens fit into one block
-        prompt_len = i % (worker.block_size - 1) + 1
+        prompt_len = i % (model_runner.block_size - 1) + 1
         prompt_lens.append(prompt_len)
         seq_data = list(range(prompt_len))
         seq_group_metadata_list.append(
@@ -24,19 +25,23 @@ def test_worker_prepare_inputs_for_prompt():
                 sampling_params=SamplingParams(temperature=0),
                 block_tables={0: [1]},
             ))
+
     expected_selected_token_indices = []
     selected_token_start_idx = 0
     max_seq_len = max(prompt_lens)
     for prompt_len in prompt_lens:
         expected_selected_token_indices.append(selected_token_start_idx +
                                                prompt_len - 1)
         selected_token_start_idx += max_seq_len
-    input_tokens, input_positions, input_metadata = worker._prepare_inputs(
+    input_tokens, input_positions, _ = model_runner._prepare_prompt(
         seq_group_metadata_list)
-    assert input_tokens.shape == input_positions.shape == (batch_size,
-                                                           max_seq_len)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
+    assert input_tokens.shape == (batch_size, max_seq_len)
+    assert input_positions.shape == (batch_size, max_seq_len)
     torch.testing.assert_close(input_tokens, input_positions)
-    actual = input_metadata.selected_token_indices
+
+    actual = sampling_metadata.selected_token_indices
     expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
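The index bookkeeping in the test leans on two facts worth spelling out. First, the test's shape assertions imply that every prompt is right-padded to max_seq_len in the (batch_size, max_seq_len) batch, so the last real token of the i-th prompt sits at flattened index i * max_seq_len + prompt_len - 1, which is exactly what the loop over prompt_lens accumulates. Second, each prompt's token ids are list(range(prompt_len)), so token ids coincide with their positions, which is why torch.testing.assert_close(input_tokens, input_positions) is expected to hold. A minimal standalone sketch of the index computation (illustrative only; expected_indices is a made-up helper, not vLLM code):

# Standalone sketch (not vLLM code): the expected sampling indices for a
# batch of prompts right-padded to the longest prompt length.
def expected_indices(prompt_lens):
    max_seq_len = max(prompt_lens)
    indices = []
    start = 0
    for prompt_len in prompt_lens:
        # Sample only from the last real token of each prompt.
        indices.append(start + prompt_len - 1)
        # Each padded row occupies max_seq_len slots in the flattened batch.
        start += max_seq_len
    return indices

# Example: expected_indices([3, 1, 2]) -> [2, 3, 7] (max_seq_len = 3)

Running pytest tests/worker/test_model_runner.py::test_prepare_prompt should select just the renamed test.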