@@ -1,10 +1,11 @@
 import random
-from typing import Tuple
+from typing import Tuple, List
 from unittest.mock import patch
 
 import pytest
 import torch
 from transformers import GenerationConfig, GenerationMixin
+from typing import Optional
 
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed
@@ -46,15 +47,13 @@ def _prepare_test(
 ]
 
 
-@pytest.mark.parametrize("seed", RANDOM_SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_sampler_all_greedy(seed: int, device: str):
-    set_random_seed(seed)
-    torch.set_default_device(device)
-    batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
-        batch_size)
-
+def _do_sample(
+    batch_size: int,
+    input_tensor: torch.Tensor,
+    sampler: MockLogitsSampler,
+    model_runner: ModelRunner,
+    sampling_params: SamplingParams,
+):
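+    # Helper shared by the tests below: builds a prompt-phase batch of
+    # identical sequence groups and runs the sampler over input_tensor.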
     seq_group_metadata_list = []
     prompt_lens = []
     for i in range(batch_size):
@@ -63,17 +62,31 @@ def test_sampler_all_greedy(seed: int, device: str):
                 request_id=f"test_{i}",
                 is_prompt=True,
                 seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(temperature=0, ),
+                sampling_params=sampling_params,
                 block_tables={0: [1]},
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                      prompt_lens,
                                                      subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
+    return sampler(embedding=None,
+                   hidden_states=input_tensor,
+                   sampling_metadata=sampling_metadata)
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_greedy(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
+    sampling_params = SamplingParams(temperature=0)
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
@@ -94,35 +107,72 @@ def test_sampler_all_random(seed: int, device: str):
     for i in range(batch_size):
         fake_logits[i, i] = 1e2
 
-    seq_group_metadata_list = []
-    prompt_lens = []
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
+
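+    # Token i dominates row i (logit 1e2), so every sampled token should be i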
+    for i, sequence_output in enumerate(sampler_output):
+        for nth_output in sequence_output.samples:
+            assert nth_output.output_token == i
+
+    del model_runner
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
     for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=1.0,
-                    n=random.randint(1, 10),
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+        fake_logits[i, i] = 1e2
+
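+    # Same boosted logits as above, but each request now carries an
+    # explicit seed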
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                model_runner, sampling_params)
 
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
 
     del model_runner
 
 
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_all_random_seed_deterministic(seed: int, device: str):
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
+        batch_size)
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                      model_runner, sampling_params)
+
+    second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
+                                       model_runner, sampling_params)
+
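+    # A second run with identical seed and inputs must reproduce the
+    # samples exactly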
+    assert first_sampler_output == second_sampler_output
+
+    del model_runner
+
+
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_all_beam(seed: int, device: str):
@@ -131,29 +181,13 @@ def test_sampler_all_beam(seed: int, device: str):
     batch_size = random.randint(1, 256)
     input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
 
-    seq_group_metadata_list = []
-    prompt_lens = []
-    for i in range(batch_size):
-        seq_group_metadata_list.append(
-            SequenceGroupMetadata(
-                request_id=f"test_{i}",
-                is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
-                sampling_params=SamplingParams(
-                    temperature=0,
-                    best_of=2,
-                    use_beam_search=True,
-                ),
-                block_tables={0: [1]},
-            ))
-        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
-
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler(embedding=None,
-            hidden_states=input_tensor,
-            sampling_metadata=sampling_metadata)
+    sampling_params = SamplingParams(
+        temperature=0,
+        best_of=2,
+        use_beam_search=True,
+    )
+    _do_sample(batch_size, input_tensor, sampler, model_runner,
+               sampling_params)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
@@ -171,14 +205,15 @@ def test_sampler_mixed(seed: int, device: str):
         batch_size)
 
     seq_group_metadata_list = []
-    expected_tokens = []
+    expected_tokens: List[Optional[List[int]]] = []
     prompt_lens = []
     for i in range(batch_size):
-        n = 1
-        sampling_type = random.randint(0, 2)
+        expected: Optional[List[int]] = None
+        sampling_type = random.randint(0, 3)
         if sampling_type == 0:
             sampling_params = SamplingParams(temperature=0)
-        elif sampling_type == 1:
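+            # Greedy: the expected output is the argmax token of row i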
+            expected = [torch.argmax(fake_logits[i], dim=-1).item()]
+        elif sampling_type in (1, 2):
             n = random.randint(1, 10)
             sampling_params = SamplingParams(
                 temperature=random.random() + 0.1,
@@ -187,13 +222,17 @@ def test_sampler_mixed(seed: int, device: str):
                 n=n,
                 presence_penalty=random.randint(0, 1),
             )
+            if sampling_type == 2:
+                sampling_params.seed = random.randint(0, 10000)
+            else:
+                for idx in range(n):
+                    fake_logits[i, i + idx] = 1e2
+                expected = list(range(i, i + n))
         else:
             sampling_params = SamplingParams(temperature=0,
                                              use_beam_search=True,
                                              best_of=2)
-        for idx in range(n):
-            fake_logits[i, i + idx] = 1e2
-            expected_tokens.append(i + idx)
+        expected_tokens.append(expected)
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
@@ -204,17 +243,50 @@ def test_sampler_mixed(seed: int, device: str):
             ))
         prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
-    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens,
-                                                     subquery_lens=prompt_lens)
-    sampler_output = sampler(embedding=None,
-                             hidden_states=input_tensor,
-                             sampling_metadata=sampling_metadata)
-    for i, sequence_output in enumerate(sampler_output):
-        if seq_group_metadata_list[i].sampling_params.use_beam_search:
-            continue
-        for nth_output in sequence_output.samples:
-            assert nth_output.output_token in expected_tokens
+    def test_sampling(model_runner: ModelRunner):
+        sampling_metadata = model_runner._prepare_sample(
+            seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
+        sampler_output = sampler(embedding=None,
+                                 hidden_states=input_tensor,
+                                 sampling_metadata=sampling_metadata)
+
+        for i, (sequence_output, metadata) in enumerate(
+                zip(sampler_output, seq_group_metadata_list)):
+            if metadata.sampling_params.use_beam_search:
+                continue
+
+            if metadata.sampling_params.seed is not None \
+                    and expected_tokens[i] is None:
+                # Record the seeded random result to compare against the
+                # second invocation
+                expected_tokens[i] = [
+                    nth_output.output_token
+                    for nth_output in sequence_output.samples
+                ]
+                continue
+
+            for n, nth_output in enumerate(sequence_output.samples):
+                if (metadata.sampling_params.temperature == 0
+                        or metadata.sampling_params.seed is not None):
+                    # Ensure exact matches for greedy or random with seed
+                    assert nth_output.output_token == expected_tokens[i][n]
+                else:
+                    # For non-seeded random, check that one of the
+                    # high-logit tokens was chosen
+                    assert nth_output.output_token in expected_tokens[i]
+
+    # Test batch
+    test_sampling(model_runner)
+
+    # Shuffle the batch and resample
+    target_index = list(range(batch_size))
+    for list_to_shuffle in (target_index, seq_group_metadata_list,
+                            expected_tokens, prompt_lens):
+        random.Random(seed).shuffle(list_to_shuffle)
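+    # Note: a fresh random.Random(seed) is created for each list, so every
+    # list receives the same permutation and they stay aligned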
+    target_index = torch.tensor(target_index)
+    input_tensor.data = input_tensor.index_select(0, target_index)
+    fake_logits.data = fake_logits.index_select(0, target_index)
+
+    # This time, results of seeded random samples will be compared with
+    # the corresponding sample in the pre-shuffled batch
+    test_sampling(model_runner)
 
     del model_runner
 
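
For reference, a minimal standalone sketch of the property the new seeded tests assert (illustrative only, not vLLM's sampler internals): a torch.Generator freshly seeded with the same value makes multinomial sampling over identical logits exactly repeatable.

    import torch
    from typing import List

    def sample_with_seed(logits: torch.Tensor, seed: int, n: int) -> List[int]:
        # A generator rebuilt from the same seed reproduces the same draws
        # over the same probability distribution.
        generator = torch.Generator(device=logits.device).manual_seed(seed)
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, n, replacement=True,
                                 generator=generator).tolist()

    logits = torch.randn(32)
    assert (sample_with_seed(logits, seed=1234, n=4) ==
            sample_with_seed(logits, seed=1234, n=4))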