@@ -66,7 +66,11 @@ def auto_complete_config(auto_complete_model_config):
                 "optional": True,
             },
         ]
-        outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}]
+        outputs = [
+            {"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]},
+            {"name": "input_tokens", "data_type": "TYPE_INT32", "dims": [-1]},
+            {"name": "output_tokens", "data_type": "TYPE_INT32", "dims": [-1]},
+        ]
 
         # Store the model configuration as a dictionary.
         config = auto_complete_model_config.as_dict()
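For orientation: later in `auto_complete_config`, the upstream vLLM backend `model.py` registers each declared tensor with Triton so a `config.pbtxt` can be auto-generated. A minimal sketch of that step, assuming the `inputs` and `outputs` lists above (this loop is not part of the diff):

```python
# Sketch: declare the tensors on the auto-complete config object that
# Triton passes into auto_complete_config(). Assumes the `inputs` and
# `outputs` lists defined in the hunk above.
for input in inputs:
    auto_complete_model_config.add_input(input)
for output in outputs:
    auto_complete_model_config.add_output(output)
```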
@@ -151,6 +155,15 @@ def initialize(self, args):
         )
         self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
 
+        output_tokens_config = pb_utils.get_output_config_by_name(
+            self.model_config, "output_tokens"
+        )
+        self.output_tokens_dtype = pb_utils.triton_string_to_numpy(output_tokens_config["data_type"])
+        input_tokens_config = pb_utils.get_output_config_by_name(
+            self.model_config, "input_tokens"
+        )
+        self.input_tokens_dtype = pb_utils.triton_string_to_numpy(input_tokens_config["data_type"])
+
         # Counter to keep track of ongoing request counts
         self.ongoing_request_count = 0
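For reference, `pb_utils.get_output_config_by_name` returns the matching output entry from the config dict, and `pb_utils.triton_string_to_numpy` maps a Triton type string to a NumPy dtype. A rough sketch of what the new lookups resolve to, given the declarations in the first hunk (illustrative values, not captured output):

```python
# Approximate shape of the entry returned for "output_tokens", assuming the
# auto-completed config from the first hunk:
output_tokens_config = {"name": "output_tokens", "data_type": "TYPE_INT32", "dims": [-1]}

# triton_string_to_numpy("TYPE_INT32") resolves to numpy.int32, so both
# token-count tensors are built as 32-bit integer arrays in create_response().
```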
@@ -246,10 +259,18 @@ def create_response(self, vllm_output, prepend_input):
         text_outputs = [
             (prompt + output.text).encode("utf-8") for output in vllm_output.outputs
         ]
+        output_tokens = sum(len(output.token_ids) for output in vllm_output.outputs)
         triton_output_tensor = pb_utils.Tensor(
-            "text_output", np.asarray(text_outputs, dtype=self.output_dtype)
+            "text_output", np.asarray(text_outputs, dtype=self.output_dtype),
         )
-        return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor])
+        # Wrap the scalar counts in lists so the tensors are 1-D, matching dims [-1].
+        triton_tokens_tensor = pb_utils.Tensor(
+            "output_tokens", np.asarray([output_tokens], dtype=self.output_tokens_dtype),
+        )
+        triton_input_tokens_tensor = pb_utils.Tensor(
+            "input_tokens", np.asarray([len(vllm_output.prompt_token_ids)], dtype=self.input_tokens_dtype),
+        )
+        return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor, triton_tokens_tensor, triton_input_tokens_tensor])
 
     def create_stream_response(self, vllm_output, previous_outputs_lengths):
         """
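To exercise the new outputs end to end, a client can request `input_tokens` and `output_tokens` alongside `text_output`. A minimal sketch with `tritonclient.grpc`; the server URL and model name are assumptions, and `text_input` follows the vLLM backend's usual input naming rather than anything shown in this diff:

```python
import numpy as np
import tritonclient.grpc as grpcclient

# Assumed server address and model name; adjust for your deployment.
client = grpcclient.InferenceServerClient(url="localhost:8001")

prompt = np.array([b"What is Triton Inference Server?"], dtype=np.object_)
text_input = grpcclient.InferInput("text_input", prompt.shape, "BYTES")
text_input.set_data_from_numpy(prompt)

outputs = [
    grpcclient.InferRequestedOutput("text_output"),
    grpcclient.InferRequestedOutput("input_tokens"),
    grpcclient.InferRequestedOutput("output_tokens"),
]

result = client.infer(model_name="vllm_model", inputs=[text_input], outputs=outputs)
print(result.as_numpy("text_output"))
print("prompt tokens:", result.as_numpy("input_tokens"))
print("completion tokens:", result.as_numpy("output_tokens"))
```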