Commit 8efe23f

Fix input_metadata.selected_token_indices in worker prepare_inputs (#1546)

1 parent 06458a0
2 files changed: +47 −1 lines
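
What the fix pins down: `_prepare_inputs` pads every prompt in the batch to the longest prompt length and flattens the result, so `input_metadata.selected_token_indices` must point at the last real token of each prompt (the token whose logits feed sampling). A minimal sketch of that arithmetic, using a hypothetical helper rather than vLLM code, mirroring exactly what the new test asserts:

    # Sketch only: hypothetical helper, not part of vLLM. Rows are padded
    # to max_seq_len and flattened row-major, so the last real token of
    # prompt i sits at i * max_seq_len + prompt_len - 1.
    def selected_indices(prompt_lens):
        max_seq_len = max(prompt_lens)
        return [
            i * max_seq_len + prompt_len - 1
            for i, prompt_len in enumerate(prompt_lens)
        ]

    # With prompts of lengths 3, 1, 2 padded to width 3:
    assert selected_indices([3, 1, 2]) == [2, 3, 7]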

tests/worker/test_worker.py (new file)

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
# pylint: disable=protected-access
import random

import torch

from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.worker import Worker


def test_worker_prepare_inputs_for_prompt():
    worker = Worker(None, None, None)
    worker.block_size = 16
    batch_size = random.randint(1, 256)
    prompt_lens = []
    seq_group_metadata_list = []
    for i in range(batch_size):
        # make sure all tokens fit into one block
        prompt_len = i % (worker.block_size - 1) + 1
        prompt_lens.append(prompt_len)
        seq_data = list(range(prompt_len))
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData(seq_data)},
                sampling_params=SamplingParams(temperature=0),
                block_tables={0: [1]},
            ))
    expected_selected_token_indices = []
    selected_token_start_idx = 0
    max_seq_len = max(prompt_lens)
    for prompt_len in prompt_lens:
        expected_selected_token_indices.append(selected_token_start_idx +
                                               prompt_len - 1)
        selected_token_start_idx += max_seq_len
    input_tokens, input_positions, input_metadata = worker._prepare_inputs(
        seq_group_metadata_list)
    assert input_tokens.shape == input_positions.shape == (batch_size,
                                                           max_seq_len)
    torch.testing.assert_close(input_tokens, input_positions)
    actual = input_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    torch.testing.assert_close(actual, expected)
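
To run just this test, one option (assuming pytest is available in the environment) is to invoke it programmatically; the `-k` filter below matches the test function above:

    import pytest

    # Run only the new worker test; equivalent to the command line
    # `pytest tests/worker/test_worker.py -k prepare_inputs_for_prompt`.
    pytest.main(["tests/worker/test_worker.py", "-k",
                 "prepare_inputs_for_prompt"])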

vllm/worker/worker.py

Lines changed: 3 additions & 1 deletion

@@ -211,12 +211,14 @@ def _prepare_inputs(
         context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
         max_seq_len = max(prompt_lens) if prompt_lens else 1
-        for seq_group_metadata in seq_group_metadata_list:
+        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
             if seq_group_metadata.is_prompt:
                 # We need to do this in this loop as we need to know max_seq_len
                 assert len(
                     seq_ids) == 1, "Prompt input should have only one seq."
                 sampling_params = seq_group_metadata.sampling_params
+                assert len(prompt_lens) == len(seq_group_metadata_list)
+                prompt_len = prompt_lens[i]
                 if sampling_params.prompt_logprobs is not None:
                     selected_token_indices.extend(
                         range(selected_token_start_idx,
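
The hunk is truncated, but the shape of the fix is visible: the loop now carries its own index, and each prompt's length is read back from the `prompt_lens` list collected earlier in the function, rather than relying on a `prompt_len` bound elsewhere. A simplified sketch (assumed control flow, not the actual method body) of how that per-prompt length drives the index bookkeeping, consistent with what the test expects:

    # Simplified sketch of the bookkeeping; names follow the diff, but the
    # surrounding control flow is an assumption based on the test above.
    selected_token_indices = []
    selected_token_start_idx = 0
    max_seq_len = max(prompt_lens) if prompt_lens else 1
    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        prompt_len = prompt_lens[i]  # per-prompt length the fix restores
        # The sampled next token comes from the prompt's last real token.
        selected_token_indices.append(selected_token_start_idx + prompt_len - 1)
        selected_token_start_idx += max_seq_len  # rows padded to max_seq_len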
