@@ -37,13 +37,21 @@ class TestAdditionalOutputs:
3737    _sampling_parameters  =  {"temperature" : "0" , "top_p" : "1" }
3838    _prompt  =  "In this example," 
3939
40+     def  _get_sampling_parameters (self , logprobs = None ):
41+         sampling_parameters  =  self ._sampling_parameters .copy ()
42+         if  logprobs  is  not   None :
43+             sampling_parameters ["logprobs" ] =  logprobs 
44+         return  sampling_parameters 
45+ 
4046    def  _get_inputs (
4147        self ,
4248        prompt ,
4349        stream = True ,
4450        sampling_parameters = None ,
4551        return_finish_reason = None ,
4652        return_cumulative_logprob = None ,
53+         return_logprobs = None ,
54+         return_num_input_tokens = None ,
4755        return_num_output_tokens = None ,
4856    ):
4957        inputs  =  []
@@ -76,6 +84,16 @@ def _get_inputs(
7684                np .array ([return_cumulative_logprob ], dtype = bool )
7785            )
7886
87+         if  return_logprobs  is  not   None :
88+             inputs .append (grpcclient .InferInput ("return_logprobs" , [1 ], "BOOL" ))
89+             inputs [- 1 ].set_data_from_numpy (np .array ([return_logprobs ], dtype = bool ))
90+ 
91+         if  return_num_input_tokens  is  not   None :
92+             inputs .append (grpcclient .InferInput ("return_num_input_tokens" , [1 ], "BOOL" ))
93+             inputs [- 1 ].set_data_from_numpy (
94+                 np .array ([return_num_input_tokens ], dtype = bool )
95+             )
96+ 
7997        if  return_num_output_tokens  is  not   None :
8098            inputs .append (
8199                grpcclient .InferInput ("return_num_output_tokens" , [1 ], "BOOL" )
@@ -89,12 +107,12 @@ def _get_inputs(
89107    def  _callback (self , result , error ):
90108        self ._responses .append ({"result" : result , "error" : error })
91109
92-     def  _llm_infer (self , inputs ):
110+     def  _llm_infer (self , inputs ,  sampling_parameters ):
93111        self ._responses  =  []
94112        with  grpcclient .InferenceServerClient (self ._grpc_url ) as  client :
95113            client .start_stream (self ._callback )
96114            client .async_stream_infer (
97-                 self ._model_name , inputs = inputs , parameters = self . _sampling_parameters 
115+                 self ._model_name , inputs = inputs , parameters = sampling_parameters 
98116            )
99117            client .stop_stream ()
100118        assert  len (self ._responses ) >  0 
@@ -135,6 +153,63 @@ def _assert_cumulative_logprob(self, return_cumulative_logprob):
135153            assert  cumulative_logprob  !=  prev_cumulative_logprob 
136154            prev_cumulative_logprob  =  cumulative_logprob 
137155
156+     def  _assert_logprobs (
157+         self , stream , sampling_parameters , return_logprobs , return_num_output_tokens 
158+     ):
159+         for  response  in  self ._responses :
160+             result , error  =  response ["result" ], response ["error" ]
161+             assert  error  is  None 
162+             logprobs_np  =  result .as_numpy (name = "logprobs" )
163+             if  return_logprobs  is  None  or  return_logprobs  ==  False :
164+                 assert  logprobs_np  is  None 
165+                 continue 
166+             logprobs  =  json .loads (logprobs_np [0 ].decode ("utf-8" ))
167+             if  "logprobs"  not  in   sampling_parameters :
168+                 assert  logprobs  is  None 
169+                 continue 
170+             assert  isinstance (logprobs , list )
171+             assert  len (logprobs ) >=  1 
172+             if  return_num_output_tokens  ==  True :
173+                 num_output_tokens  =  result .as_numpy (name = "num_output_tokens" )[0 ].astype (
174+                     int 
175+                 )
176+                 assert  len (logprobs ) ==  num_output_tokens 
177+             text_output_logprobs  =  "" 
178+             for  logprobs_d  in  logprobs :
179+                 assert  isinstance (logprobs_d , dict )
180+                 assert  len (logprobs_d ) >=  1 
181+                 assert  len (logprobs_d ) <=  sampling_parameters ["logprobs" ] +  1 
182+                 rank_one_found  =  False 
183+                 for  token_id , logprob_d  in  logprobs_d .items ():
184+                     assert  isinstance (token_id , str )
185+                     assert  len (logprob_d ) ==  3 
186+                     assert  isinstance (logprob_d ["logprob" ], float )
187+                     assert  isinstance (logprob_d ["rank" ], int )
188+                     assert  isinstance (logprob_d ["decoded_token" ], str )
189+                     if  logprob_d ["rank" ] ==  1 :
190+                         assert  not  rank_one_found 
191+                         rank_one_found  =  True 
192+                         text_output_logprobs  +=  logprob_d ["decoded_token" ]
193+                 assert  rank_one_found 
194+             text_output  =  result .as_numpy (name = "text_output" )[0 ].decode ("utf-8" )
195+             if  not  stream :
196+                 # given exclude_input_in_output is not set, prepend_input is True if not 
197+                 # streaming and False if streaming 
198+                 text_output_logprobs  =  self ._prompt  +  text_output_logprobs 
199+             assert  text_output_logprobs  ==  text_output 
200+ 
201+     def  _assert_num_input_tokens (self , return_num_input_tokens ):
202+         for  response  in  self ._responses :
203+             result , error  =  response ["result" ], response ["error" ]
204+             assert  error  is  None 
205+             num_input_tokens_np  =  result .as_numpy (name = "num_input_tokens" )
206+             if  return_num_input_tokens  is  None  or  return_num_input_tokens  ==  False :
207+                 assert  num_input_tokens_np  is  None 
208+                 continue 
209+             num_input_tokens  =  num_input_tokens_np .astype (int )
210+             assert  num_input_tokens  >  0 
211+             assert  num_input_tokens  <=  len (self ._prompt )
212+ 
138213    def  _assert_num_output_tokens (self , return_num_output_tokens ):
139214        for  response  in  self ._responses :
140215            result , error  =  response ["result" ], response ["error" ]
@@ -144,46 +219,42 @@ def _assert_num_output_tokens(self, return_num_output_tokens):
144219                assert  num_output_tokens_np  is  None 
145220                continue 
146221            num_output_tokens  =  num_output_tokens_np [0 ].astype (int )
147-             # TODO: vLLM may return token ids identical to the previous one when 
148-             #       streaming, for example: 
149-             # 
150-             #       prev: None 
151-             #       curr: text=' the', token_ids=array('l', [5]) 
152-             # 
153-             #       prev: text=' the', token_ids=array('l', [5, 1385]) 
154-             #       curr: text=' the term', token_ids=array('l', [5, 1385]) 
155-             # 
156-             #       prev: text=' the term', token_ids=array('l', [5, 1385, 44]) 
157-             #       curr: text=' the term', token_ids=array('l', [5, 1385, 44]) 
158-             # 
159-             #       prev: text=' the term', token_ids=array('l', [5, 1385, 44, 48]) 
160-             #       curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48]) 
161-             # 
162-             #       If this is no longer the case in a future release, change the assert 
163-             #       to assert num_output_tokens > 0. 
164-             assert  num_output_tokens  >=  0 
222+             assert  num_output_tokens  >  0 
165223
166224    @pytest .mark .parametrize ("stream" , [True , False ]) 
167225    @pytest .mark .parametrize ("return_finish_reason" , [None , True , False ]) 
168226    @pytest .mark .parametrize ("return_cumulative_logprob" , [None , True , False ]) 
227+     @pytest .mark .parametrize ("logprobs" , [None , 0 , 2 ]) 
228+     @pytest .mark .parametrize ("return_logprobs" , [None , True , False ]) 
229+     @pytest .mark .parametrize ("return_num_input_tokens" , [None , True , False ]) 
169230    @pytest .mark .parametrize ("return_num_output_tokens" , [None , True , False ]) 
170231    def  test_additional_outputs (
171232        self ,
172233        stream ,
173234        return_finish_reason ,
174235        return_cumulative_logprob ,
236+         logprobs ,
237+         return_logprobs ,
238+         return_num_input_tokens ,
175239        return_num_output_tokens ,
176240    ):
241+         sampling_parameters  =  self ._get_sampling_parameters (logprobs = logprobs )
177242        inputs  =  self ._get_inputs (
178243            self ._prompt ,
179244            stream = stream ,
180-             sampling_parameters = self . _sampling_parameters ,
245+             sampling_parameters = sampling_parameters ,
181246            return_finish_reason = return_finish_reason ,
182247            return_cumulative_logprob = return_cumulative_logprob ,
248+             return_logprobs = return_logprobs ,
249+             return_num_input_tokens = return_num_input_tokens ,
183250            return_num_output_tokens = return_num_output_tokens ,
184251        )
185-         self ._llm_infer (inputs )
252+         self ._llm_infer (inputs ,  sampling_parameters )
186253        self ._assert_text_output_valid ()
187254        self ._assert_finish_reason (return_finish_reason )
188255        self ._assert_cumulative_logprob (return_cumulative_logprob )
256+         self ._assert_logprobs (
257+             stream , sampling_parameters , return_logprobs , return_num_output_tokens 
258+         )
259+         self ._assert_num_input_tokens (return_num_input_tokens )
189260        self ._assert_num_output_tokens (return_num_output_tokens )
0 commit comments