@@ -37,13 +37,20 @@ class TestAdditionalOutputs:
3737 _sampling_parameters = {"temperature" : "0" , "top_p" : "1" }
3838 _prompt = "In this example,"
3939
40+ def _get_sampling_parameters (self , logprobs = None ):
41+ sampling_parameters = self ._sampling_parameters .copy ()
42+ if logprobs is not None :
43+ sampling_parameters ["logprobs" ] = logprobs
44+ return sampling_parameters
45+
4046 def _get_inputs (
4147 self ,
4248 prompt ,
4349 stream = True ,
4450 sampling_parameters = None ,
4551 return_finish_reason = None ,
4652 return_cumulative_logprob = None ,
53+ return_logprobs = None ,
4754 return_num_input_tokens = None ,
4855 return_num_output_tokens = None ,
4956 ):
@@ -77,6 +84,10 @@ def _get_inputs(
7784 np .array ([return_cumulative_logprob ], dtype = bool )
7885 )
7986
87+ if return_logprobs is not None :
88+ inputs .append (grpcclient .InferInput ("return_logprobs" , [1 ], "BOOL" ))
89+ inputs [- 1 ].set_data_from_numpy (np .array ([return_logprobs ], dtype = bool ))
90+
8091 if return_num_input_tokens is not None :
8192 inputs .append (grpcclient .InferInput ("return_num_input_tokens" , [1 ], "BOOL" ))
8293 inputs [- 1 ].set_data_from_numpy (
@@ -96,12 +107,12 @@ def _get_inputs(
96107 def _callback (self , result , error ):
97108 self ._responses .append ({"result" : result , "error" : error })
98109
    def _llm_infer(self, inputs, sampling_parameters):
        """Run one streaming inference and collect every response.

        Responses (result/error pairs) are accumulated into self._responses
        by self._callback; at least one response must arrive before returning.
        """
        self._responses = []
        with grpcclient.InferenceServerClient(self._grpc_url) as client:
            client.start_stream(self._callback)
            # sampling_parameters are forwarded as request parameters so the
            # backend sees the same settings the assertions were built from.
            client.async_stream_infer(
                self._model_name, inputs=inputs, parameters=sampling_parameters
            )
            client.stop_stream()
        assert len(self._responses) > 0
@@ -142,6 +153,51 @@ def _assert_cumulative_logprob(self, return_cumulative_logprob):
142153 assert cumulative_logprob != prev_cumulative_logprob
143154 prev_cumulative_logprob = cumulative_logprob
144155
156+ def _assert_logprobs (
157+ self , stream , sampling_parameters , return_logprobs , return_num_output_tokens
158+ ):
159+ for response in self ._responses :
160+ result , error = response ["result" ], response ["error" ]
161+ assert error is None
162+ logprobs_np = result .as_numpy (name = "logprobs" )
163+ if return_logprobs is None or return_logprobs == False :
164+ assert logprobs_np is None
165+ continue
166+ logprobs = json .loads (logprobs_np [0 ].decode ("utf-8" ))
167+ if "logprobs" not in sampling_parameters :
168+ assert logprobs is None
169+ continue
170+ assert isinstance (logprobs , list )
171+ assert len (logprobs ) >= 1
172+ if return_num_output_tokens == True :
173+ num_output_tokens = result .as_numpy (name = "num_output_tokens" )[0 ].astype (
174+ int
175+ )
176+ assert len (logprobs ) == num_output_tokens
177+ text_output_logprobs = ""
178+ for logprobs_d in logprobs :
179+ assert isinstance (logprobs_d , dict )
180+ assert len (logprobs_d ) >= 1
181+ assert len (logprobs_d ) <= sampling_parameters ["logprobs" ] + 1
182+ rank_one_found = False
183+ for token_id , logprob_d in logprobs_d .items ():
184+ assert isinstance (token_id , str )
185+ assert len (logprob_d ) == 3
186+ assert isinstance (logprob_d ["logprob" ], float )
187+ assert isinstance (logprob_d ["rank" ], int )
188+ assert isinstance (logprob_d ["decoded_token" ], str )
189+ if logprob_d ["rank" ] == 1 :
190+ assert not rank_one_found
191+ rank_one_found = True
192+ text_output_logprobs += logprob_d ["decoded_token" ]
193+ assert rank_one_found
194+ text_output = result .as_numpy (name = "text_output" )[0 ].decode ("utf-8" )
195+ if not stream :
196+ # given exclude_input_in_output is not set, prepend_input is True if not
197+ # streaming and False if streaming
198+ text_output_logprobs = self ._prompt + text_output_logprobs
199+ assert text_output_logprobs == text_output
200+
145201 def _assert_num_input_tokens (self , return_num_input_tokens ):
146202 for response in self ._responses :
147203 result , error = response ["result" ], response ["error" ]
@@ -163,50 +219,42 @@ def _assert_num_output_tokens(self, return_num_output_tokens):
163219 assert num_output_tokens_np is None
164220 continue
165221 num_output_tokens = num_output_tokens_np [0 ].astype (int )
166- # TODO: vLLM may return token ids identical to the previous one when
167- # streaming, for example:
168- #
169- # prev: None
170- # curr: text=' the', token_ids=array('l', [5])
171- #
172- # prev: text=' the', token_ids=array('l', [5, 1385])
173- # curr: text=' the term', token_ids=array('l', [5, 1385])
174- #
175- # prev: text=' the term', token_ids=array('l', [5, 1385, 44])
176- # curr: text=' the term', token_ids=array('l', [5, 1385, 44])
177- #
178- # prev: text=' the term', token_ids=array('l', [5, 1385, 44, 48])
179- # curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48])
180- #
181- # If this is no longer the case in a future release, change the assert
182- # to assert num_output_tokens > 0.
183- assert num_output_tokens >= 0
222+ assert num_output_tokens > 0
184223
    @pytest.mark.parametrize("stream", [True, False])
    @pytest.mark.parametrize("return_finish_reason", [None, True, False])
    @pytest.mark.parametrize("return_cumulative_logprob", [None, True, False])
    @pytest.mark.parametrize("logprobs", [None, 0, 2])
    @pytest.mark.parametrize("return_logprobs", [None, True, False])
    @pytest.mark.parametrize("return_num_input_tokens", [None, True, False])
    @pytest.mark.parametrize("return_num_output_tokens", [None, True, False])
    def test_additional_outputs(
        self,
        stream,
        return_finish_reason,
        return_cumulative_logprob,
        logprobs,
        return_logprobs,
        return_num_input_tokens,
        return_num_output_tokens,
    ):
        """End-to-end check of every optional output flag combination.

        Builds the request inputs, runs a (streaming or non-streaming)
        inference against the live server, then delegates to the per-output
        assertion helpers. NOTE(review): the full parametrize cross product is
        2*3*3*3*3*3*3 = 1458 cases per run.
        """
        # The same sampling parameters must be used for the request and for
        # the logprobs assertions, so build them once here.
        sampling_parameters = self._get_sampling_parameters(logprobs=logprobs)
        inputs = self._get_inputs(
            self._prompt,
            stream=stream,
            sampling_parameters=sampling_parameters,
            return_finish_reason=return_finish_reason,
            return_cumulative_logprob=return_cumulative_logprob,
            return_logprobs=return_logprobs,
            return_num_input_tokens=return_num_input_tokens,
            return_num_output_tokens=return_num_output_tokens,
        )
        self._llm_infer(inputs, sampling_parameters)
        self._assert_text_output_valid()
        self._assert_finish_reason(return_finish_reason)
        self._assert_cumulative_logprob(return_cumulative_logprob)
        self._assert_logprobs(
            stream, sampling_parameters, return_logprobs, return_num_output_tokens
        )
        self._assert_num_input_tokens(return_num_input_tokens)
        self._assert_num_output_tokens(return_num_output_tokens)
0 commit comments