+# pylint: disable=protected-access
+import random
+import torch
+
+from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
+from vllm.worker.worker import Worker
+
+
+def test_worker_prepare_inputs_for_prompt():
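+    # Build a Worker with placeholder configs and set block_size directly;
+    # the full engine setup is not needed for this unit test.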
+    worker = Worker(None, None, None)
+    worker.block_size = 16
+    batch_size = random.randint(1, 256)
+    prompt_lens = []
+    seq_group_metadata_list = []
+    for i in range(batch_size):
+        # make sure all tokens fit into one block
+        prompt_len = i % (worker.block_size - 1) + 1
+        prompt_lens.append(prompt_len)
+        seq_data = list(range(prompt_len))
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData(seq_data)},
+                sampling_params=SamplingParams(temperature=0),
+                block_tables={0: [1]},
+            ))
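+    # Prompts are padded to the longest prompt in the batch, so in the flattened
+    # token buffer the last token of prompt i sits at i * max_seq_len + prompt_len - 1.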
+    expected_selected_token_indices = []
+    selected_token_start_idx = 0
+    max_seq_len = max(prompt_lens)
+    for prompt_len in prompt_lens:
+        expected_selected_token_indices.append(selected_token_start_idx +
+                                               prompt_len - 1)
+        selected_token_start_idx += max_seq_len
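+    # Exercise the worker's private prompt-preparation path (the reason for the
+    # protected-access disable at the top of the file).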
+    input_tokens, input_positions, input_metadata = worker._prepare_inputs(
+        seq_group_metadata_list)
+    assert input_tokens.shape == input_positions.shape == (batch_size,
+                                                           max_seq_len)
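+    # Token ids were chosen as range(prompt_len), so they should mirror the positions.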
+    torch.testing.assert_close(input_tokens, input_positions)
+    actual = input_metadata.selected_token_indices
+    expected = torch.tensor(expected_selected_token_indices,
+                            device=actual.device,
+                            dtype=actual.dtype)
+    torch.testing.assert_close(actual, expected)