22import torch
33
44from vllm .sequence import SamplingParams , SequenceData , SequenceGroupMetadata
5- from vllm .worker .worker import Worker
5+ from vllm .worker .model_runner import ModelRunner
66
77
8- def test_worker_prepare_inputs_for_prompt ():
9- worker = Worker (None , None , None )
10- worker .block_size = 16
8+ def test_prepare_prompt ():
9+ model_runner = ModelRunner (None , None , None )
10+ model_runner .set_block_size (16 )
11+
1112 batch_size = random .randint (1 , 256 )
1213 prompt_lens = []
1314 seq_group_metadata_list = []
1415 for i in range (batch_size ):
1516 # make sure all tokens fit into one block
16- prompt_len = i % (worker .block_size - 1 ) + 1
17+ prompt_len = i % (model_runner .block_size - 1 ) + 1
1718 prompt_lens .append (prompt_len )
1819 seq_data = list (range (prompt_len ))
1920 seq_group_metadata_list .append (
@@ -24,19 +25,23 @@ def test_worker_prepare_inputs_for_prompt():
2425 sampling_params = SamplingParams (temperature = 0 ),
2526 block_tables = {0 : [1 ]},
2627 ))
28+
2729 expected_selected_token_indices = []
2830 selected_token_start_idx = 0
2931 max_seq_len = max (prompt_lens )
3032 for prompt_len in prompt_lens :
3133 expected_selected_token_indices .append (selected_token_start_idx +
3234 prompt_len - 1 )
3335 selected_token_start_idx += max_seq_len
34- input_tokens , input_positions , input_metadata = worker . _prepare_inputs (
36+ input_tokens , input_positions , _ = model_runner . _prepare_prompt (
3537 seq_group_metadata_list )
36- assert input_tokens .shape == input_positions .shape == (batch_size ,
37- max_seq_len )
38+ sampling_metadata = model_runner ._prepare_sample (seq_group_metadata_list ,
39+ prompt_lens )
40+ assert input_tokens .shape == (batch_size , max_seq_len )
41+ assert input_positions .shape == (batch_size , max_seq_len )
3842 torch .testing .assert_close (input_tokens , input_positions )
39- actual = input_metadata .selected_token_indices
43+
44+ actual = sampling_metadata .selected_token_indices
4045 expected = torch .tensor (expected_selected_token_indices ,
4146 device = actual .device ,
4247 dtype = actual .dtype )
0 commit comments