Commit 30e20cb

fix: Support max_completion_tokens option in OpenAI frontend (#8226)
1 parent 4184494 commit 30e20cb

File tree: 9 files changed, +145 −71 lines changed

python/openai/README.md

Lines changed: 4 additions & 29 deletions

@@ -216,7 +216,7 @@ completion = client.chat.completions.create(
         },
         {"role": "user", "content": "What are LLMs?"},
     ],
-    max_tokens=256,
+    max_completion_tokens=256,
 )
 
 print(completion.choices[0].message.content)

@@ -487,7 +487,7 @@ messages = [
 ]
 
 tool_calls = client.chat.completions.create(
-    messages=messages, model=model, tools=tools, max_tokens=128
+    messages=messages, model=model, tools=tools, max_completion_tokens=128
 )
 function_name = tool_calls.choices[0].message.tool_calls[0].function.name
 function_arguments = tool_calls.choices[0].message.tool_calls[0].function.arguments

@@ -504,31 +504,6 @@ function arguments: {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
 tool calling result: The weather in Dallas, Texas is 85 degrees fahrenheit. It is partly cloudly, with highs in the 90's.
 ```
 
-<!-- TODO: Remove this warning when the openai api supports the max_completion_tokens instead of max_tokens -->
-> [!WARNING]
-> When using LangChain to call the `v1/chat/completions` endpoint, you might encounter an exception related to `max_completion_tokens` if you have specified `max_tokens` in the request.
->
-> Example: `openai.BadRequestError: Error code: 400 - {'object': 'error', 'message': "[{'type': 'extra_forbidden', 'loc': ('body', 'max_completion_tokens'), 'msg': 'Extra inputs are not permitted', 'input': 800}]", 'type': 'BadRequestError', 'param': None, 'code': 400}`
->
-> This issue is due to an incompatibility between Triton's OpenAI API frontend and the latest OpenAI API. We are actively working to address this gap. A workaround is adding the `max_tokens` into the `model_kwargs` of the LangChain OpenAI request.
->
-> Example:
-```python
-from langchain.llms import OpenAI
-
-llm = OpenAI(
-    model_name="llama-3.1-8b-instruct",
-    temperature=0.0,
-    model_kwargs={
-        "max_tokens": 4096
-    }
-)
-
-response = llm("Write a short poem about a sunset.")
-print(response)
-
-```
-
 #### Named Tool Calling
 
 The OpenAI frontend supports named function calling, utilizing guided decoding in the vLLM and TensorRT-LLM backends. Users can specify one of the tools in `tool_choice` to force the model to select a specific tool for function calling.

@@ -639,12 +614,12 @@ messages = [
 ]
 
 tool_calls = client.chat.completions.create(
-    messages=messages, model=model, tools=tools, tool_choice=tool_choice, max_tokens=128
+    messages=messages, model=model, tools=tools, tool_choice=tool_choice, max_completion_tokens=128
 )
 function_name = tool_calls.choices[0].message.tool_calls[0].function.name
 function_arguments = tool_calls.choices[0].message.tool_calls[0].function.arguments
 
-print(f"function name: "{function_name}")
+print(f"function name: {function_name}")
 print(f"function arguments: {function_arguments}")
 print(f"tool calling result: {available_tools[function_name](**json.loads(function_arguments))}")
 ```
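
For readers following along outside the diff, here is a minimal end-to-end sketch of the updated usage. It assumes the OpenAI-compatible frontend is already running at `http://localhost:9000/v1`, serving a model named `llama-3.1-8b-instruct`, and not validating API keys; all three are illustrative assumptions, not taken from this commit. The LangChain warning removed above is obsolete for the same reason this sketch works: the frontend now accepts `max_completion_tokens` directly.

```python
# Minimal sketch of the documented usage after this change (not code from the commit).
# Assumptions: frontend reachable at localhost:9000, model name "llama-3.1-8b-instruct",
# and no API-key validation server-side, so a placeholder key is fine.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:9000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="llama-3.1-8b-instruct",
    messages=[{"role": "user", "content": "What are LLMs?"}],
    max_completion_tokens=256,  # preferred field; max_tokens still works but is deprecated
)
print(completion.choices[0].message.content)
```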

python/openai/openai_frontend/engine/triton_engine.py

Lines changed: 10 additions & 2 deletions

@@ -102,6 +102,7 @@ def __init__(
         self,
         server: tritonserver.Server,
         tokenizer: str,
+        default_max_tokens: int,
         backend: Optional[str] = None,
         lora_separator: Optional[str] = None,
         tool_call_parser: Optional[str] = None,

@@ -113,6 +114,7 @@ def __init__(
         # TODO: Reconsider name of "backend" vs. something like "request_format"
         self.backend = backend
         self.lora_separator = lora_separator
+        self.default_max_tokens = default_max_tokens
 
         # NOTE: Creation time and model metadata will be static at startup for
         # now, and won't account for dynamically loading/unloading models.

@@ -184,7 +186,9 @@ async def chat(
 
         # Convert to Triton request format and perform inference
         responses = metadata.model.async_infer(
-            metadata.request_converter(metadata.model, prompt, request, lora_name)
+            metadata.request_converter(
+                metadata.model, prompt, request, lora_name, self.default_max_tokens
+            )
         )
 
         # Prepare and send responses back to client in OpenAI format

@@ -302,7 +306,11 @@ async def completion(
         # Convert to Triton request format and perform inference
         responses = metadata.model.async_infer(
             metadata.request_converter(
-                metadata.model, request.prompt, request, lora_name
+                metadata.model,
+                request.prompt,
+                request,
+                lora_name,
+                self.default_max_tokens,
             )
         )
 

python/openai/openai_frontend/engine/utils/triton.py

Lines changed: 38 additions & 1 deletion

@@ -46,6 +46,7 @@ def _create_vllm_inference_request(
     prompt,
     request: CreateChatCompletionRequest | CreateCompletionRequest,
     lora_name: str | None,
+    default_max_tokens: int,
 ):
     inputs = {}
     # Exclude non-sampling parameters so they aren't passed to vLLM

@@ -67,6 +68,9 @@ def _create_vllm_inference_request(
         "function_call",
         "functions",
         "suffix",
+        "max_completion_tokens",
+        # will be handled explicitly
+        "max_tokens",
     }
 
     # NOTE: The exclude_none is important, as internals may not support

@@ -75,6 +79,23 @@ def _create_vllm_inference_request(
         exclude=excludes,
         exclude_none=True,
     )
+
+    # Indicates CreateChatCompletionRequest
+    if hasattr(request, "max_completion_tokens"):
+        if request.max_completion_tokens is not None:
+            sampling_parameters["max_tokens"] = request.max_completion_tokens
+        # Fallback to deprecated request.max_tokens
+        elif request.max_tokens is not None:
+            sampling_parameters["max_tokens"] = request.max_tokens
+        # If neither is set, use a default value for max_tokens
+        else:
+            sampling_parameters["max_tokens"] = default_max_tokens
+    # Indicates CreateCompletionRequest
+    elif request.max_tokens is not None:
+        sampling_parameters["max_tokens"] = request.max_tokens
+    else:
+        sampling_parameters["max_tokens"] = default_max_tokens
+
     if lora_name is not None:
         sampling_parameters["lora_name"] = lora_name
     sampling_parameters = json.dumps(sampling_parameters)

@@ -108,15 +129,31 @@ def _create_trtllm_inference_request(
     prompt,
     request: CreateChatCompletionRequest | CreateCompletionRequest,
     lora_name: str | None,
+    default_max_tokens: int,
 ):
     if lora_name is not None:
         raise Exception("LoRA selection is currently not supported for TRT-LLM backend")
 
     inputs = {}
     inputs["text_input"] = [[prompt]]
     inputs["stream"] = np.bool_([[request.stream]])
-    if request.max_tokens:
+
+    # Indicates CreateChatCompletionRequest
+    if hasattr(request, "max_completion_tokens"):
+        if request.max_completion_tokens is not None:
+            inputs["max_tokens"] = np.int32([[request.max_completion_tokens]])
+        # Fallback to deprecated request.max_tokens
+        elif request.max_tokens is not None:
+            inputs["max_tokens"] = np.int32([[request.max_tokens]])
+        # If neither is set, use a default value for max_tokens
+        else:
+            inputs["max_tokens"] = np.int32([[default_max_tokens]])
+    # Indicates CreateCompletionRequest
+    elif request.max_tokens is not None:
         inputs["max_tokens"] = np.int32([[request.max_tokens]])
+    else:
+        inputs["max_tokens"] = np.int32([[default_max_tokens]])
+
     if request.stop:
         if isinstance(request.stop, str):
            request.stop = [request.stop]
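
Both converters apply the same resolution order: `max_completion_tokens` (present only on chat requests) wins, then the deprecated `max_tokens`, then the server-side default. The following is a standalone sketch of that precedence, not code from the commit; `resolve_max_tokens` and `_Req` are made-up names for illustration.

```python
def resolve_max_tokens(request, default_max_tokens: int) -> int:
    """Mirror the precedence used above; only chat requests carry max_completion_tokens."""
    if getattr(request, "max_completion_tokens", None) is not None:
        return request.max_completion_tokens
    if getattr(request, "max_tokens", None) is not None:
        return request.max_tokens
    return default_max_tokens


class _Req:  # made-up stand-in for a parsed chat request
    max_completion_tokens = None
    max_tokens = 128


print(resolve_max_tokens(_Req(), default_max_tokens=16))  # -> 128 (deprecated field honored)
```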

python/openai/openai_frontend/main.py

Lines changed: 8 additions & 0 deletions

@@ -143,6 +143,13 @@ def parse_args():
         help="The path to the custom Jinja chat template file. This is useful if you'd like to use a different chat template than the one provided by the model.",
     )
 
+    triton_group.add_argument(
+        "--default-max-tokens",
+        type=int,
+        default=16,
+        help="The default maximum number of tokens to generate if not specified in the request. The default is 16.",
+    )
+
     # OpenAI-Compatible Frontend (FastAPI)
     openai_group = parser.add_argument_group("Triton OpenAI-Compatible Frontend")
     openai_group.add_argument(

@@ -199,6 +206,7 @@ def main():
         lora_separator=args.lora_separator,
         tool_call_parser=args.tool_call_parser,
         chat_template=args.chat_template,
+        default_max_tokens=args.default_max_tokens,
     )
 
     # Attach TritonLLMEngine as the backbone for inference and model management
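
In practice this makes the fallback applied by the converters configurable at launch time, for example by passing `--default-max-tokens 64` to the frontend's `main.py` alongside its other launch options (the value 64 is illustrative). The flag's default of 16 matches the value that was previously hard-coded as the schema default for `max_tokens`.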

python/openai/openai_frontend/schemas/openai.py

Lines changed: 9 additions & 13 deletions

@@ -103,7 +103,7 @@ class CreateCompletionRequest(BaseModel):
         description="Include the log probabilities on the `logprobs` most likely output tokens, as well the chosen tokens. For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response.\n\nThe maximum value for `logprobs` is 5.\n",
     )
     max_tokens: Optional[conint(ge=0)] = Field(
-        16,
+        None,
         description="The maximum number of [tokens](/tokenizer) that can be generated in the completion.\n\nThe token count of your prompt plus `max_tokens` cannot exceed the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n",
         examples=[16],
     )

@@ -526,14 +526,6 @@ class Logprobs2(BaseModel):
     )
 
 
-class ChatCompletionFinishReason(Enum):
-    stop = "stop"
-    length = "length"
-    tool_calls = "tool_calls"
-    content_filter = "content_filter"
-    function_call = "function_call"
-
-
 class ChatCompletionStreamingResponseChoice(BaseModel):
     delta: ChatCompletionStreamResponseDelta
     logprobs: Optional[Logprobs2] = Field(

@@ -850,11 +842,15 @@ class CreateChatCompletionRequest(BaseModel):
         None,
         description="An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to `true` if this parameter is used.",
     )
-    # TODO: Consider new max_completion_tokens field in the future: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_completion_tokens
-    max_tokens: Optional[conint(ge=0)] = Field(
-        16,
+    max_completion_tokens: Optional[conint(ge=0)] = Field(
+        None,
         description="The maximum number of [tokens](/tokenizer) that can be generated in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n",
     )
+    # TODO: Remove support for max_tokens field in the future: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_completion_tokens
+    max_tokens: Optional[conint(ge=0)] = Field(
+        None,
+        description="DEPRECATED: Use `max_completion_tokens` instead. The maximum number of [tokens](/tokenizer) that can be generated in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n",
+    )
     # TODO: Extension, flesh out description and defaults
     min_tokens: Optional[conint(ge=0)] = Field(
         None,

@@ -871,7 +867,7 @@ class CreateChatCompletionRequest(BaseModel):
     )
     response_format: Optional[ResponseFormat] = Field(
         None,
-        description='An object specifying the format that the model must output. Compatible with [GPT-4 Turbo](/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n',
+        description='An object specifying the format that the model must output. Compatible with [GPT-4 Turbo](/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_completion_tokens` or the conversation exceeded the max context length.\n',
     )
     seed: Optional[conint(ge=-9223372036854775808, le=9223372036854775807)] = Field(
         None,
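
The net effect of the schema change is that neither field carries an implicit default of 16 any more: an unset value reaches the engine as `None`, and the `--default-max-tokens` fallback is applied server-side. Below is a minimal, self-contained sketch of the new field shapes; the model name `ChatRequestSketch` is made up for illustration and is not the project's actual schema class.

```python
from typing import Optional

from pydantic import BaseModel, Field, conint


class ChatRequestSketch(BaseModel):
    # Preferred field, mirroring the new schema: optional, non-negative, no implicit default.
    max_completion_tokens: Optional[conint(ge=0)] = Field(None)
    # Deprecated alias kept for backward compatibility.
    max_tokens: Optional[conint(ge=0)] = Field(None)


print(ChatRequestSketch().max_completion_tokens)  # None -> server-side default applies
print(ChatRequestSketch(max_completion_tokens=256).max_completion_tokens)  # 256
```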

python/openai/tests/test_chat_completions.py

Lines changed: 48 additions & 10 deletions

@@ -123,6 +123,7 @@ def test_chat_completions_user_prompt_dict(self, client, model: str):
        [
            ("temperature", 0.7),
            ("max_tokens", 10),
+            ("max_completion_tokens", 10),
            ("top_p", 0.9),
            ("frequency_penalty", 0.5),
            ("presence_penalty", 0.2),

@@ -172,6 +173,7 @@ def test_chat_completions_sampling_parameters(
            ("temperature", 2.1),
            ("temperature", -0.1),
            ("max_tokens", -1),
+            ("max_completion_tokens", -1),
            ("top_p", 1.1),
            ("frequency_penalty", 3),
            ("frequency_penalty", -3),

@@ -199,14 +201,21 @@ def test_chat_completions_invalid_sampling_parameters(
         assert response.status_code == 422
 
     # Simple tests to verify max_tokens roughly behaves as expected
+    @pytest.mark.parametrize(
+        "max_tokens_key",
+        [
+            "max_tokens",
+            "max_completion_tokens",
+        ],
+    )
     def test_chat_completions_max_tokens(
-        self, client, model: str, messages: List[dict]
+        self, client, max_tokens_key, model: str, messages: List[dict]
     ):
         responses = []
-        payload = {"model": model, "messages": messages, "max_tokens": 1}
+        payload = {"model": model, "messages": messages}
 
-        # Send two requests with max_tokens = 1 to check their similarity
-        payload["max_tokens"] = 1
+        # Send two requests with max_tokens/max_completion_tokens = 1 to check their similarity
+        payload[max_tokens_key] = 1
         responses.append(
             client.post(
                 "/v1/chat/completions",

@@ -219,8 +228,8 @@ def test_chat_completions_max_tokens(
                 json=payload,
             )
         )
-        # Send one requests with larger max_tokens to check its dis-similarity
-        payload["max_tokens"] = 100
+        # Send one requests with larger max_tokens/max_completion_tokens to check its dis-similarity
+        payload[max_tokens_key] = 100
         responses.append(
             client.post(
                 "/v1/chat/completions",

@@ -245,6 +254,30 @@ def test_chat_completions_max_tokens(
         assert len(response1_text) == len(response2_text) == 1
         assert len(response3_text) > len(response1_text)
 
+    def test_chat_completions_max_completion_tokens_precedence(
+        self, client, model: str, messages: List[dict]
+    ):
+        payload = {
+            "model": model,
+            "messages": messages,
+            "max_tokens": 50,  # Higher value for max_tokens
+            "max_completion_tokens": 1,  # Lower, expected to take precedence
+        }
+
+        response = client.post(
+            "/v1/chat/completions",
+            json=payload,
+        )
+
+        print("Response:", response.json())
+        assert response.status_code == 200
+
+        response_text_words = (
+            response.json()["choices"][0]["message"]["content"].strip().split()
+        )
+        # Check if the number of words is around max_completion_tokens
+        assert len(response_text_words) == 1
+
     @pytest.mark.parametrize(
         "temperature",
         [0.0, 1.0],

@@ -260,7 +293,7 @@ def test_chat_completions_temperature_vllm(
         payload = {
             "model": model,
             "messages": messages,
-            "max_tokens": 256,
+            "max_completion_tokens": 256,
             "temperature": temperature,
         }
 

@@ -321,7 +354,7 @@ def test_chat_completions_temperature_tensorrtllm(
             "model": model,
             "messages": messages,
             # Increase token length to allow more room for variability
-            "max_tokens": 200,
+            "max_completion_tokens": 200,
             "temperature": 0.0,
             # TRT-LLM requires certain settings of `top_k` / `top_p` to
             # respect changes in `temperature`

@@ -376,7 +409,7 @@ def test_chat_completions_seed(self, client, model: str, messages: List[dict]):
             "model": model,
             "messages": messages,
             # Increase token length to allow more room for variability
-            "max_tokens": 200,
+            "max_completion_tokens": 200,
             "seed": 1,
         }
         payload2 = copy.deepcopy(payload1)

@@ -559,7 +592,12 @@ def test_chat_completions_custom_tokenizer(
 
         responses = []
         with TestClient(app_local) as client_local, TestClient(app_hf) as client_hf:
-            payload = {"model": model, "messages": messages, "temperature": 0}
+            payload = {
+                "model": model,
+                "messages": messages,
+                "temperature": 0,
+                "seed": 0,
+            }
             responses.append(client_local.post("/v1/chat/completions", json=payload))
             responses.append(client_hf.post("/v1/chat/completions", json=payload))
 
