import random

import torch

from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
def test_prepare_prompt():
    """Smoke-test ModelRunner's prompt-phase input preparation.

    Builds a randomly sized batch of prompt sequence groups whose token ids
    equal their positions, runs ``_prepare_prompt`` and ``_prepare_sample``,
    and checks that:
      * the padded token / position tensors are (batch_size, max_seq_len);
      * tokens equal positions element-wise (by construction of the inputs);
      * ``selected_token_indices`` points at the last real token of each
        padded prompt row.
    """
    model_runner = ModelRunner(None, None, None)
    model_runner.set_block_size(16)

    batch_size = random.randint(1, 256)
    prompt_lens = []
    seq_group_metadata_list = []
    for i in range(batch_size):
        # make sure all tokens fit into one block
        prompt_len = i % (model_runner.block_size - 1) + 1
        prompt_lens.append(prompt_len)
        # Token ids are chosen equal to their positions: [0, 1, ..., len-1].
        seq_data = list(range(prompt_len))
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                # NOTE(review): request_id / is_prompt / seq_data were elided
                # by the diff hunk in the source; reconstructed from the
                # surrounding context (seq_data and the SequenceData import
                # are otherwise unused) — confirm against the repository.
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData(seq_data)},
                sampling_params=SamplingParams(temperature=0),
                block_tables={0: [1]},
            ))

    # Each prompt occupies one max_seq_len-wide padded row; sampling should
    # select the last real token of each row.
    expected_selected_token_indices = []
    selected_token_start_idx = 0
    max_seq_len = max(prompt_lens)
    for prompt_len in prompt_lens:
        expected_selected_token_indices.append(selected_token_start_idx +
                                               prompt_len - 1)
        selected_token_start_idx += max_seq_len

    input_tokens, input_positions, _ = model_runner._prepare_prompt(
        seq_group_metadata_list)
    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    assert input_tokens.shape == (batch_size, max_seq_len)
    assert input_positions.shape == (batch_size, max_seq_len)
    # Tokens were constructed to equal their positions, so the two padded
    # tensors must match element-wise.
    torch.testing.assert_close(input_tokens, input_positions)

    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(expected_selected_token_indices,
                            device=actual.device,
                            dtype=actual.dtype)
    # Bug fix: the original built `expected` but never compared it, so the
    # test's main check silently did not run.
    torch.testing.assert_close(actual, expected)