@@ -44,6 +44,7 @@ def _get_inputs(
4444 sampling_parameters = None ,
4545 return_finish_reason = None ,
4646 return_cumulative_logprob = None ,
47+ return_num_input_tokens = None ,
4748 return_num_output_tokens = None ,
4849 ):
4950 inputs = []
@@ -76,6 +77,12 @@ def _get_inputs(
7677 np .array ([return_cumulative_logprob ], dtype = bool )
7778 )
7879
80+ if return_num_input_tokens is not None :
81+ inputs .append (grpcclient .InferInput ("return_num_input_tokens" , [1 ], "BOOL" ))
82+ inputs [- 1 ].set_data_from_numpy (
83+ np .array ([return_num_input_tokens ], dtype = bool )
84+ )
85+
7986 if return_num_output_tokens is not None :
8087 inputs .append (
8188 grpcclient .InferInput ("return_num_output_tokens" , [1 ], "BOOL" )
@@ -135,6 +142,18 @@ def _assert_cumulative_logprob(self, return_cumulative_logprob):
135142 assert cumulative_logprob != prev_cumulative_logprob
136143 prev_cumulative_logprob = cumulative_logprob
137144
145+ def _assert_num_input_tokens (self , return_num_input_tokens ):
146+ for response in self ._responses :
147+ result , error = response ["result" ], response ["error" ]
148+ assert error is None
149+ num_input_tokens_np = result .as_numpy (name = "num_input_tokens" )
150+ if return_num_input_tokens is None or return_num_input_tokens == False :
151+ assert num_input_tokens_np is None
152+ continue
153+ num_input_tokens = num_input_tokens_np .astype (int )
154+ assert num_input_tokens > 0
155+ assert num_input_tokens <= len (self ._prompt )
156+
138157 def _assert_num_output_tokens (self , return_num_output_tokens ):
139158 for response in self ._responses :
140159 result , error = response ["result" ], response ["error" ]
@@ -166,12 +185,14 @@ def _assert_num_output_tokens(self, return_num_output_tokens):
166185 @pytest .mark .parametrize ("stream" , [True , False ])
167186 @pytest .mark .parametrize ("return_finish_reason" , [None , True , False ])
168187 @pytest .mark .parametrize ("return_cumulative_logprob" , [None , True , False ])
188+ @pytest .mark .parametrize ("return_num_input_tokens" , [None , True , False ])
169189 @pytest .mark .parametrize ("return_num_output_tokens" , [None , True , False ])
170190 def test_additional_outputs (
171191 self ,
172192 stream ,
173193 return_finish_reason ,
174194 return_cumulative_logprob ,
195+ return_num_input_tokens ,
175196 return_num_output_tokens ,
176197 ):
177198 inputs = self ._get_inputs (
@@ -180,10 +201,12 @@ def test_additional_outputs(
180201 sampling_parameters = self ._sampling_parameters ,
181202 return_finish_reason = return_finish_reason ,
182203 return_cumulative_logprob = return_cumulative_logprob ,
204+ return_num_input_tokens = return_num_input_tokens ,
183205 return_num_output_tokens = return_num_output_tokens ,
184206 )
185207 self ._llm_infer (inputs )
186208 self ._assert_text_output_valid ()
187209 self ._assert_finish_reason (return_finish_reason )
188210 self ._assert_cumulative_logprob (return_cumulative_logprob )
211+ self ._assert_num_input_tokens (return_num_input_tokens )
189212 self ._assert_num_output_tokens (return_num_output_tokens )
0 commit comments