@@ -42,9 +42,9 @@ def _get_inputs(
4242 prompt ,
4343 stream = True ,
4444 sampling_parameters = None ,
45- output_finish_reason = None ,
46- output_cumulative_logprob = None ,
47- output_num_token_ids = None ,
45+ return_finish_reason = None ,
46+ return_cumulative_logprob = None ,
47+ return_num_token_ids = None ,
4848 ):
4949 inputs = []
5050
@@ -64,21 +64,21 @@ def _get_inputs(
6464 )
6565 )
6666
67- if output_finish_reason is not None :
68- inputs .append (grpcclient .InferInput ("output_finish_reason " , [1 ], "BOOL" ))
69- inputs [- 1 ].set_data_from_numpy (np .array ([output_finish_reason ], dtype = bool ))
67+ if return_finish_reason is not None :
68+ inputs .append (grpcclient .InferInput ("return_finish_reason " , [1 ], "BOOL" ))
69+ inputs [- 1 ].set_data_from_numpy (np .array ([return_finish_reason ], dtype = bool ))
7070
71- if output_cumulative_logprob is not None :
71+ if return_cumulative_logprob is not None :
7272 inputs .append (
73- grpcclient .InferInput ("output_cumulative_logprob " , [1 ], "BOOL" )
73+ grpcclient .InferInput ("return_cumulative_logprob " , [1 ], "BOOL" )
7474 )
7575 inputs [- 1 ].set_data_from_numpy (
76- np .array ([output_cumulative_logprob ], dtype = bool )
76+ np .array ([return_cumulative_logprob ], dtype = bool )
7777 )
7878
79- if output_num_token_ids is not None :
80- inputs .append (grpcclient .InferInput ("output_num_token_ids " , [1 ], "BOOL" ))
81- inputs [- 1 ].set_data_from_numpy (np .array ([output_num_token_ids ], dtype = bool ))
79+ if return_num_token_ids is not None :
80+ inputs .append (grpcclient .InferInput ("return_num_token_ids " , [1 ], "BOOL" ))
81+ inputs [- 1 ].set_data_from_numpy (np .array ([return_num_token_ids ], dtype = bool ))
8282
8383 return inputs
8484
@@ -104,12 +104,12 @@ def _assert_text_output_valid(self):
104104 assert len (text_output ) > 0 , "output is empty"
105105 assert text_output .count (" " ) > 4 , "output is not a sentence"
106106
107- def _assert_finish_reason (self , output_finish_reason ):
107+ def _assert_finish_reason (self , return_finish_reason ):
108108 for i in range (len (self ._responses )):
109109 result , error = self ._responses [i ]["result" ], self ._responses [i ]["error" ]
110110 assert error is None
111111 finish_reason_np = result .as_numpy (name = "finish_reason" )
112- if output_finish_reason is None or output_finish_reason == False :
112+ if return_finish_reason is None or return_finish_reason == False :
113113 assert finish_reason_np is None
114114 continue
115115 finish_reason = finish_reason_np [0 ].decode ("utf-8" )
@@ -118,25 +118,25 @@ def _assert_finish_reason(self, output_finish_reason):
118118 else :
119119 assert finish_reason == "length"
120120
121- def _assert_cumulative_logprob (self , output_cumulative_logprob ):
121+ def _assert_cumulative_logprob (self , return_cumulative_logprob ):
122122 prev_cumulative_logprob = 0.0
123123 for response in self ._responses :
124124 result , error = response ["result" ], response ["error" ]
125125 assert error is None
126126 cumulative_logprob_np = result .as_numpy (name = "cumulative_logprob" )
127- if output_cumulative_logprob is None or output_cumulative_logprob == False :
127+ if return_cumulative_logprob is None or return_cumulative_logprob == False :
128128 assert cumulative_logprob_np is None
129129 continue
130130 cumulative_logprob = cumulative_logprob_np [0 ].astype (float )
131131 assert cumulative_logprob != prev_cumulative_logprob
132132 prev_cumulative_logprob = cumulative_logprob
133133
134- def _assert_num_token_ids (self , output_num_token_ids ):
134+ def _assert_num_token_ids (self , return_num_token_ids ):
135135 for response in self ._responses :
136136 result , error = response ["result" ], response ["error" ]
137137 assert error is None
138138 num_token_ids_np = result .as_numpy (name = "num_token_ids" )
139- if output_num_token_ids is None or output_num_token_ids == False :
139+ if return_num_token_ids is None or return_num_token_ids == False :
140140 assert num_token_ids_np is None
141141 continue
142142 num_token_ids = num_token_ids_np [0 ].astype (int )
@@ -160,26 +160,26 @@ def _assert_num_token_ids(self, output_num_token_ids):
160160 assert num_token_ids >= 0
161161
162162 @pytest .mark .parametrize ("stream" , [True , False ])
163- @pytest .mark .parametrize ("output_finish_reason " , [None , True , False ])
164- @pytest .mark .parametrize ("output_cumulative_logprob " , [None , True , False ])
165- @pytest .mark .parametrize ("output_num_token_ids " , [None , True , False ])
163+ @pytest .mark .parametrize ("return_finish_reason " , [None , True , False ])
164+ @pytest .mark .parametrize ("return_cumulative_logprob " , [None , True , False ])
165+ @pytest .mark .parametrize ("return_num_token_ids " , [None , True , False ])
166166 def test_additional_outputs (
167167 self ,
168168 stream ,
169- output_finish_reason ,
170- output_cumulative_logprob ,
171- output_num_token_ids ,
169+ return_finish_reason ,
170+ return_cumulative_logprob ,
171+ return_num_token_ids ,
172172 ):
173173 inputs = self ._get_inputs (
174174 self ._prompt ,
175175 stream = stream ,
176176 sampling_parameters = self ._sampling_parameters ,
177- output_finish_reason = output_finish_reason ,
178- output_cumulative_logprob = output_cumulative_logprob ,
179- output_num_token_ids = output_num_token_ids ,
177+ return_finish_reason = return_finish_reason ,
178+ return_cumulative_logprob = return_cumulative_logprob ,
179+ return_num_token_ids = return_num_token_ids ,
180180 )
181181 self ._llm_infer (inputs )
182182 self ._assert_text_output_valid ()
183- self ._assert_finish_reason (output_finish_reason )
184- self ._assert_cumulative_logprob (output_cumulative_logprob )
185- self ._assert_num_token_ids (output_num_token_ids )
183+ self ._assert_finish_reason (return_finish_reason )
184+ self ._assert_cumulative_logprob (return_cumulative_logprob )
185+ self ._assert_num_token_ids (return_num_token_ids )
0 commit comments