2525# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626
2727import json
28- import unittest
2928
3029import numpy as np
30+ import pytest
3131import tritonclient .grpc as grpcclient
3232
3333
34- class InferTest ( unittest . TestCase ) :
34+ class TestAdditionalOutputs :
3535 _grpc_url = "localhost:8001"
3636 _model_name = "vllm_opt"
3737 _sampling_parameters = {"temperature" : "0" , "top_p" : "1" }
@@ -93,51 +93,51 @@ def _llm_infer(self, inputs):
9393 self ._model_name , inputs = inputs , parameters = self ._sampling_parameters
9494 )
9595 client .stop_stream ()
96- self . assertGreater ( len (self ._responses ), 0 )
96+ assert len (self ._responses ) > 0
9797
9898 def _assert_text_output_valid (self ):
9999 text_output = ""
100100 for response in self ._responses :
101101 result , error = response ["result" ], response ["error" ]
102- self . assertIsNone ( error )
102+ assert error is None
103103 text_output += result .as_numpy (name = "text_output" )[0 ].decode ("utf-8" )
104- self . assertGreater ( len (text_output ), 0 , "output is empty" )
105- self . assertGreater ( text_output .count (" " ), 4 , "output is not a sentence" )
104+ assert len (text_output ) > 0 , "output is empty"
105+ assert text_output .count (" " ) > 4 , "output is not a sentence"
106106
107107 def _assert_finish_reason (self , output_finish_reason ):
108108 for i in range (len (self ._responses )):
109109 result , error = self ._responses [i ]["result" ], self ._responses [i ]["error" ]
110- self . assertIsNone ( error )
110+ assert error is None
111111 finish_reason_np = result .as_numpy (name = "finish_reason" )
112112 if output_finish_reason is None or output_finish_reason == False :
113- self . assertIsNone ( finish_reason_np )
113+ assert finish_reason_np is None
114114 continue
115115 finish_reason = finish_reason_np [0 ].decode ("utf-8" )
116116 if i < len (self ._responses ) - 1 :
117- self . assertEqual ( finish_reason , "None" )
117+ assert finish_reason == "None"
118118 else :
119- self . assertEqual ( finish_reason , "length" )
119+ assert finish_reason == "length"
120120
121121 def _assert_cumulative_logprob (self , output_cumulative_logprob ):
122122 prev_cumulative_logprob = 0.0
123123 for response in self ._responses :
124124 result , error = response ["result" ], response ["error" ]
125- self . assertIsNone ( error )
125+ assert error is None
126126 cumulative_logprob_np = result .as_numpy (name = "cumulative_logprob" )
127127 if output_cumulative_logprob is None or output_cumulative_logprob == False :
128- self . assertIsNone ( cumulative_logprob_np )
128+ assert cumulative_logprob_np is None
129129 continue
130130 cumulative_logprob = cumulative_logprob_np [0 ].astype (float )
131- self . assertNotEqual ( cumulative_logprob , prev_cumulative_logprob )
131+ assert cumulative_logprob != prev_cumulative_logprob
132132 prev_cumulative_logprob = cumulative_logprob
133133
134134 def _assert_num_token_ids (self , output_num_token_ids ):
135135 for response in self ._responses :
136136 result , error = response ["result" ], response ["error" ]
137- self . assertIsNone ( error )
137+ assert error is None
138138 num_token_ids_np = result .as_numpy (name = "num_token_ids" )
139139 if output_num_token_ids is None or output_num_token_ids == False :
140- self . assertIsNone ( num_token_ids_np )
140+ assert num_token_ids_np is None
141141 continue
142142 num_token_ids = num_token_ids_np [0 ].astype (int )
143143 # TODO: vLLM may return token ids identical to the previous one when
@@ -156,10 +156,14 @@ def _assert_num_token_ids(self, output_num_token_ids):
156156 # curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48])
157157 #
158158 # If this is no longer the case in a future release, change the assert
159- # to assertGreater().
160- self .assertGreaterEqual (num_token_ids , 0 )
161-
162- def _assert_additional_outputs_valid (
159+ # to assert num_token_ids > 0.
160+ assert num_token_ids >= 0
161+
162+ @pytest .mark .parametrize ("stream" , [True , False ])
163+ @pytest .mark .parametrize ("output_finish_reason" , [None , True , False ])
164+ @pytest .mark .parametrize ("output_cumulative_logprob" , [None , True , False ])
165+ @pytest .mark .parametrize ("output_num_token_ids" , [None , True , False ])
166+ def test_additional_outputs (
163167 self ,
164168 stream ,
165169 output_finish_reason ,
@@ -179,20 +183,3 @@ def _assert_additional_outputs_valid(
179183 self ._assert_finish_reason (output_finish_reason )
180184 self ._assert_cumulative_logprob (output_cumulative_logprob )
181185 self ._assert_num_token_ids (output_num_token_ids )
182-
183- def test_additional_outputs (self ):
184- for stream in [True , False ]:
185- choices = [None , False , True ]
186- for output_finish_reason in choices :
187- for output_cumulative_logprob in choices :
188- for output_num_token_ids in choices :
189- self ._assert_additional_outputs_valid (
190- stream ,
191- output_finish_reason ,
192- output_cumulative_logprob ,
193- output_num_token_ids ,
194- )
195-
196-
197- if __name__ == "__main__" :
198- unittest .main ()
0 commit comments