
Commit d17512b

feat: Add support for usage in the OpenAI frontend vLLM backend (#8264)
1 parent 2e8de23 commit d17512b

File tree

8 files changed: +573 -28 lines changed


python/openai/README.md

Lines changed: 15 additions & 7 deletions
@@ -98,7 +98,7 @@ curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/
 
 ```json
 {
-  "id": "cmpl-6930b296-7ef8-11ef-bdd1-107c6149ca79",
+  "id": "cmpl-0242093d-51ae-11f0-b339-e7480668bfbe",
   "choices": [
     {
       "finish_reason": "stop",
@@ -113,11 +113,15 @@ curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/
       "logprobs": null
     }
   ],
-  "created": 1727679085,
+  "created": 1750846825,
   "model": "llama-3.1-8b-instruct",
   "system_fingerprint": null,
   "object": "chat.completion",
-  "usage": null
+  "usage": {
+    "completion_tokens": 7,
+    "prompt_tokens": 42,
+    "total_tokens": 49
+  }
 }
 ```
 
@@ -138,20 +142,24 @@ curl -s http://localhost:9000/v1/completions -H 'Content-Type: application/json'
 
 ```json
 {
-  "id": "cmpl-d51df75c-7ef8-11ef-bdd1-107c6149ca79",
+  "id": "cmpl-58fba3a0-51ae-11f0-859d-e7480668bfbe",
   "choices": [
     {
       "finish_reason": "stop",
       "index": 0,
      "logprobs": null,
-      "text": " a field of computer science that focuses on developing algorithms that allow computers to learn from"
+      "text": " an amazing field that can truly understand the hidden patterns that exist in the data,"
     }
   ],
-  "created": 1727679266,
+  "created": 1750846970,
   "model": "llama-3.1-8b-instruct",
   "system_fingerprint": null,
   "object": "text_completion",
-  "usage": null
+  "usage": {
+    "completion_tokens": 16,
+    "prompt_tokens": 4,
+    "total_tokens": 20
+  }
 }
 ```
 
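As a minimal sketch of what the updated README examples show, the newly populated `usage` block can also be read through the `openai` Python client instead of curl. Assumptions not taken from this commit: the frontend is serving `llama-3.1-8b-instruct` at `http://localhost:9000/v1` (as in the README's curl commands) and the `openai` client package is installed.

```python
# Sketch: read the usage field that this commit populates for the vLLM backend.
# base_url, model name, and the api_key placeholder are assumptions carried over
# from the README examples; they are not defined by this commit.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:9000/v1", api_key="unused")

response = client.chat.completions.create(
    model="llama-3.1-8b-instruct",
    messages=[{"role": "user", "content": "What is machine learning?"}],
    max_tokens=16,
)

# Previously this was None; with this commit the vLLM backend returns counts,
# e.g. prompt_tokens=42, completion_tokens=7, total_tokens=49 as in the README.
print(response.usage)
```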

python/openai/openai_frontend/engine/triton_engine.py

Lines changed: 86 additions & 4 deletions
@@ -51,7 +51,9 @@
     _create_trtllm_inference_request,
     _create_vllm_inference_request,
     _get_output,
+    _get_usage_from_response,
     _get_vllm_lora_names,
+    _StreamingUsageAccumulator,
     _validate_triton_responses_non_streaming,
 )
 from schemas.openai import (
@@ -65,6 +67,7 @@
     ChatCompletionStreamResponseDelta,
     ChatCompletionToolChoiceOption1,
     Choice,
+    CompletionUsage,
     CreateChatCompletionRequest,
     CreateChatCompletionResponse,
     CreateChatCompletionStreamResponse,
@@ -229,6 +232,8 @@ async def chat(
             backend=metadata.backend,
         )
 
+        usage = _get_usage_from_response(response, metadata.backend)
+
         return CreateChatCompletionResponse(
             id=request_id,
             choices=[
@@ -243,6 +248,7 @@ async def chat(
             model=request.model,
             system_fingerprint=None,
             object=ObjectType.chat_completion,
+            usage=usage,
         )
 
     def _get_chat_completion_response_message(
@@ -319,7 +325,7 @@ async def completion(
         created = int(time.time())
         if request.stream:
             return self._streaming_completion_iterator(
-                request_id, created, request.model, responses
+                request_id, created, request, responses, metadata.backend
             )
 
         # Response validation with decoupled models in mind
@@ -328,6 +334,8 @@
         response = responses[0]
         text = _get_output(response)
 
+        usage = _get_usage_from_response(response, metadata.backend)
+
         choice = Choice(
            finish_reason=FinishReason.stop,
            index=0,
@@ -341,6 +349,7 @@
             object=ObjectType.text_completion,
             created=created,
             model=request.model,
+            usage=usage,
         )
 
         # TODO: This behavior should be tested further
@@ -421,6 +430,7 @@ def _get_streaming_chat_response_chunk(
         request_id: str,
         created: int,
         model: str,
+        usage: Optional[CompletionUsage] = None,
     ) -> CreateChatCompletionStreamResponse:
         return CreateChatCompletionStreamResponse(
             id=request_id,
@@ -429,6 +439,7 @@
             model=model,
             system_fingerprint=None,
             object=ObjectType.chat_completion_chunk,
+            usage=usage,
         )
 
     def _get_first_streaming_chat_response(
@@ -444,7 +455,7 @@ def _get_first_streaming_chat_response(
             finish_reason=None,
         )
         chunk = self._get_streaming_chat_response_chunk(
-            choice, request_id, created, model
+            choice, request_id, created, model, usage=None
         )
         return chunk
 
@@ -470,6 +481,8 @@ async def _streaming_chat_iterator(
         )
 
         previous_text = ""
+        include_usage = request.stream_options and request.stream_options.include_usage
+        usage_accumulator = _StreamingUsageAccumulator(backend)
 
         chunk = self._get_first_streaming_chat_response(
             request_id, created, model, role
@@ -478,6 +491,8 @@
 
         async for response in responses:
             delta_text = _get_output(response)
+            if include_usage:
+                usage_accumulator.update(response)
 
             (
                 response_delta,
@@ -512,10 +527,25 @@ async def _streaming_chat_iterator(
             )
 
             chunk = self._get_streaming_chat_response_chunk(
-                choice, request_id, created, model
+                choice, request_id, created, model, usage=None
             )
             yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
 
+        # Send the final usage chunk if requested via stream_options.
+        if include_usage:
+            usage_payload = usage_accumulator.get_final_usage()
+            if usage_payload:
+                final_usage_chunk = CreateChatCompletionStreamResponse(
+                    id=request_id,
+                    choices=[],
+                    created=created,
+                    model=model,
+                    system_fingerprint=None,
+                    object=ObjectType.chat_completion_chunk,
+                    usage=usage_payload,
+                )
+                yield f"data: {final_usage_chunk.model_dump_json(exclude_unset=True)}\n\n"
+
         yield "data: [DONE]\n\n"
 
     def _get_streaming_response_delta(
@@ -662,6 +692,18 @@ def _validate_chat_request(
 
         self._verify_chat_tool_call_settings(request=request)
 
+        if request.stream_options and not request.stream:
+            raise Exception("`stream_options` can only be used when `stream` is True")
+
+        if (
+            request.stream_options
+            and request.stream_options.include_usage
+            and metadata.backend != "vllm"
+        ):
+            raise Exception(
+                "`stream_options.include_usage` is currently only supported for the vLLM backend"
+            )
+
     def _verify_chat_tool_call_settings(self, request: CreateChatCompletionRequest):
         if (
             request.tool_choice
@@ -698,9 +740,21 @@ def _verify_chat_tool_call_settings(self, request: CreateChatCompletionRequest):
             )
 
     async def _streaming_completion_iterator(
-        self, request_id: str, created: int, model: str, responses: AsyncIterable
+        self,
+        request_id: str,
+        created: int,
+        request: CreateCompletionRequest,
+        responses: AsyncIterable,
+        backend: str,
     ) -> AsyncIterator[str]:
+        model = request.model
+        include_usage = request.stream_options and request.stream_options.include_usage
+        usage_accumulator = _StreamingUsageAccumulator(backend)
+
         async for response in responses:
+            if include_usage:
+                usage_accumulator.update(response)
+
             text = _get_output(response)
             choice = Choice(
                 finish_reason=FinishReason.stop if response.final else None,
@@ -715,10 +769,26 @@ async def _streaming_completion_iterator(
                 object=ObjectType.text_completion,
                 created=created,
                 model=model,
+                usage=None,
             )
 
             yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
 
+        # Send the final usage chunk if requested via stream_options.
+        if include_usage:
+            usage_payload = usage_accumulator.get_final_usage()
+            if usage_payload:
+                final_usage_chunk = CreateCompletionResponse(
+                    id=request_id,
+                    choices=[],
+                    system_fingerprint=None,
+                    object=ObjectType.text_completion,
+                    created=created,
+                    model=model,
+                    usage=usage_payload,
+                )
+                yield f"data: {final_usage_chunk.model_dump_json(exclude_unset=True)}\n\n"
+
         yield "data: [DONE]\n\n"
 
     def _validate_completion_request(
@@ -771,6 +841,18 @@ def _validate_completion_request(
         if request.logit_bias is not None or request.logprobs is not None:
             raise Exception("logit bias and log probs not supported")
 
+        if request.stream_options and not request.stream:
+            raise Exception("`stream_options` can only be used when `stream` is True")
+
+        if (
+            request.stream_options
+            and request.stream_options.include_usage
+            and metadata.backend != "vllm"
+        ):
+            raise Exception(
+                "`stream_options.include_usage` is currently only supported for the vLLM backend"
+            )
+
     def _should_stream_with_auto_tool_parsing(
         self, request: CreateChatCompletionRequest
     ):
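The streaming changes above mean that, when `stream_options.include_usage` is set (vLLM backend only, per the new validation), the frontend emits one extra trailing chunk whose `choices` list is empty and whose `usage` field carries the accumulated totals, before the final `[DONE]`. A hedged client-side sketch of consuming that behavior with the `openai` Python client follows; the endpoint and model name are assumptions carried over from the README examples, not part of this commit.

```python
# Sketch: consume the final usage-only streaming chunk added by this commit.
# Assumes a recent openai client that accepts stream_options, and the same
# local endpoint/model as the README examples.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:9000/v1", api_key="unused")

stream = client.chat.completions.create(
    model="llama-3.1-8b-instruct",
    messages=[{"role": "user", "content": "What is machine learning?"}],
    stream=True,
    stream_options={"include_usage": True},  # vLLM backend only, per the validation above
)

for chunk in stream:
    if chunk.choices:
        # Regular content chunks: usage is left unset (None).
        print(chunk.choices[0].delta.content or "", end="")
    elif chunk.usage:
        # Trailing chunk with empty choices: accumulated token counts.
        print("\n", chunk.usage)
```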

python/openai/openai_frontend/engine/utils/triton.py

Lines changed: 83 additions & 1 deletion
@@ -27,7 +27,7 @@
 import json
 import os
 import re
-from dataclasses import asdict
+from dataclasses import asdict, dataclass, field
 from typing import Iterable, List, Optional, Union
 
 import numpy as np
@@ -36,6 +36,7 @@
 from schemas.openai import (
     ChatCompletionNamedToolChoice,
     ChatCompletionToolChoiceOption1,
+    CompletionUsage,
     CreateChatCompletionRequest,
     CreateCompletionRequest,
 )
@@ -121,6 +122,8 @@ def _create_vllm_inference_request(
     # Pass sampling_parameters as serialized JSON string input to support List
     # fields like 'stop' that aren't supported by TRITONSERVER_Parameters yet.
     inputs["sampling_parameters"] = [sampling_parameters]
+    inputs["return_num_input_tokens"] = np.bool_([True])
+    inputs["return_num_output_tokens"] = np.bool_([True])
     return model.create_request(inputs=inputs)
 
 
@@ -221,6 +224,85 @@ def _to_string(tensor: tritonserver.Tensor) -> str:
     return _construct_string_from_pointer(tensor.data_ptr + 4, tensor.size - 4)
 
 
+@dataclass
+class _StreamingUsageAccumulator:
+    """Helper class to accumulate token usage from a streaming response."""
+
+    backend: str
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    _prompt_tokens_set: bool = field(init=False, default=False)
+
+    def update(self, response: tritonserver.InferenceResponse):
+        """Extracts usage from a response and updates the token counts."""
+        usage = _get_usage_from_response(response, self.backend)
+        if usage:
+            # The prompt_tokens is received with every chunk but should only be set once.
+            if not self._prompt_tokens_set:
+                self.prompt_tokens = usage.prompt_tokens
+                self._prompt_tokens_set = True
+            self.completion_tokens += usage.completion_tokens
+
+    def get_final_usage(self) -> Optional[CompletionUsage]:
+        """
+        Returns the final populated CompletionUsage object if any tokens were tracked.
+        """
+        # If _prompt_tokens_set is True, it means we have received and processed
+        # at least one valid usage payload.
+        if self._prompt_tokens_set:
+            return CompletionUsage(
+                prompt_tokens=self.prompt_tokens,
+                completion_tokens=self.completion_tokens,
+                total_tokens=self.prompt_tokens + self.completion_tokens,
+            )
+        return None
+
+
+def _get_usage_from_response(
+    response: tritonserver._api._response.InferenceResponse,
+    backend: str,
+) -> Optional[CompletionUsage]:
+    """
+    Extracts token usage statistics from a Triton inference response.
+    """
+    # TODO: Remove this check once TRT-LLM backend supports both "num_input_tokens"
+    # and "num_output_tokens", and also update the test cases accordingly.
+    if backend != "vllm":
+        return None
+
+    prompt_tokens = None
+    completion_tokens = None
+
+    if (
+        "num_input_tokens" in response.outputs
+        and "num_output_tokens" in response.outputs
+    ):
+        input_token_tensor = response.outputs["num_input_tokens"]
+        output_token_tensor = response.outputs["num_output_tokens"]
+
+        if input_token_tensor.data_type == tritonserver.DataType.UINT32:
+            prompt_tokens_ptr = ctypes.cast(
+                input_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_uint32)
+            )
+            prompt_tokens = prompt_tokens_ptr[0]
+
+        if output_token_tensor.data_type == tritonserver.DataType.UINT32:
+            completion_tokens_ptr = ctypes.cast(
+                output_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_uint32)
+            )
+            completion_tokens = completion_tokens_ptr[0]
+
+    if prompt_tokens is not None and completion_tokens is not None:
+        total_tokens = prompt_tokens + completion_tokens
+        return CompletionUsage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
+        )
+
+    return None
+
+
 # TODO: Use tritonserver.InferenceResponse when support is published
 def _get_output(response: tritonserver._api._response.InferenceResponse) -> str:
     if "text_output" in response.outputs:
