
Commit 121dcdc

Configurable max_tokens/max_completion_tokens key (#399)
## Summary

Makes the `max_tokens` request key configurable through an environment variable per endpoint type. Defaults to `max_tokens` for legacy `completions` and `max_completion_tokens` for `chat/completions`.

## Details

- Add the `GUIDELLM__OPENAI__MAX_OUTPUT_KEY` config option, a dict mapping from route name -> output tokens key. Default is `{"text_completions": "max_tokens", "chat_completions": "max_completion_tokens"}`.

## Test Plan

-

## Related Issues

- Closes #395
- Closes #269
- Related #210

---

- [x] "I certify that all code in this PR is my own, except as noted below."

## Use of AI

- [ ] Includes AI-assisted code completion
- [ ] Includes code generated by an AI application
- [ ] Includes AI-generated tests (NOTE: AI written tests should have a docstring that includes `## WRITTEN BY AI ##`)

---

Signed-off-by: Tyler Michael Smith <[email protected]>
Signed-off-by: Samuel Monson <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
1 parent a24a22d commit 121dcdc
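
A minimal usage sketch of the option described above, assuming the settings loader parses dict-valued fields from a JSON string the way pydantic-settings typically does (that parsing behavior is an assumption here, not stated in the commit):

```python
import json
import os

# Hypothetical override: send the legacy "max_tokens" key for both endpoint
# types, e.g. for servers that reject "max_completion_tokens". Assumes the
# settings loader parses dict-valued fields from a JSON string.
os.environ["GUIDELLM__OPENAI__MAX_OUTPUT_KEY"] = json.dumps(
    {
        "text_completions": "max_tokens",
        "chat_completions": "max_tokens",
    }
)
```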

File tree: 3 files changed (+14, -9 lines)

- src/guidellm/backend/openai.py
- src/guidellm/config.py
- tests/unit/conftest.py

src/guidellm/backend/openai.py

Lines changed: 9 additions & 6 deletions
@@ -31,10 +31,11 @@
 TEXT_COMPLETIONS_PATH = "/v1/completions"
 CHAT_COMPLETIONS_PATH = "/v1/chat/completions"
 
-EndpointType = Literal["chat_completions", "models", "text_completions"]
-CHAT_COMPLETIONS: EndpointType = "chat_completions"
+CompletionEndpointType = Literal["text_completions", "chat_completions"]
+EndpointType = Union[Literal["models"], CompletionEndpointType]
+CHAT_COMPLETIONS: CompletionEndpointType = "chat_completions"
 MODELS: EndpointType = "models"
-TEXT_COMPLETIONS: EndpointType = "text_completions"
+TEXT_COMPLETIONS: CompletionEndpointType = "text_completions"
 
 
 @Backend.register("openai_http")
@@ -447,7 +448,7 @@ def _extra_body(self, endpoint_type: EndpointType) -> dict[str, Any]:
 
     def _completions_payload(
         self,
-        endpoint_type: EndpointType,
+        endpoint_type: CompletionEndpointType,
         orig_kwargs: Optional[dict],
         max_output_tokens: Optional[int],
         **kwargs,
@@ -467,8 +468,10 @@ def _completions_payload(
                 self.__class__.__name__,
                 max_output_tokens or self.max_output_tokens,
             )
-            payload["max_tokens"] = max_output_tokens or self.max_output_tokens
-            payload["max_completion_tokens"] = payload["max_tokens"]
+            max_output_key = settings.openai.max_output_key.get(
+                endpoint_type, "max_tokens"
+            )
+            payload[max_output_key] = max_output_tokens or self.max_output_tokens
 
             if max_output_tokens:
                 # only set stop and ignore_eos if max_output_tokens set at request level

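A self-contained sketch of the payload-key selection in the hunk above, with the settings mapping inlined as a plain dict so it runs without guidellm; the function and constant names are illustrative, only the two endpoint-type literals and the fallback to `"max_tokens"` come from the diff:

```python
from typing import Literal, Optional

CompletionEndpointType = Literal["text_completions", "chat_completions"]

# Mirrors the default mapping added to OpenAISettings in config.py below.
MAX_OUTPUT_KEY: dict[CompletionEndpointType, str] = {
    "text_completions": "max_tokens",
    "chat_completions": "max_completion_tokens",
}


def completions_payload(
    endpoint_type: CompletionEndpointType,
    max_output_tokens: Optional[int],
    default_max_output_tokens: int = 16384,
) -> dict:
    payload: dict = {}
    # Per-endpoint lookup with the same fallback the diff uses:
    # .get(endpoint_type, "max_tokens").
    key = MAX_OUTPUT_KEY.get(endpoint_type, "max_tokens")
    payload[key] = max_output_tokens or default_max_output_tokens
    return payload


# Chat requests now carry max_completion_tokens; legacy completions keep max_tokens.
assert completions_payload("chat_completions", 256) == {"max_completion_tokens": 256}
assert completions_payload("text_completions", None) == {"max_tokens": 16384}
```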
src/guidellm/config.py

Lines changed: 4 additions & 0 deletions
@@ -88,6 +88,10 @@ class OpenAISettings(BaseModel):
     base_url: str = "http://localhost:8000"
     max_output_tokens: int = 16384
     verify: bool = True
+    max_output_key: dict[Literal["text_completions", "chat_completions"], str] = {
+        "text_completions": "max_tokens",
+        "chat_completions": "max_completion_tokens",
+    }
 
 
 class ReportGenerationSettings(BaseModel):

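A minimal sketch of how the `Literal`-keyed dict behaves under validation, assuming pydantic v2; `DemoOpenAISettings` is an illustrative stand-in, not guidellm's actual settings class:

```python
from typing import Literal

from pydantic import BaseModel, ValidationError


class DemoOpenAISettings(BaseModel):
    # Same field shape and default as the OpenAISettings addition above.
    max_output_key: dict[Literal["text_completions", "chat_completions"], str] = {
        "text_completions": "max_tokens",
        "chat_completions": "max_completion_tokens",
    }


# The Literal-typed keys reject route names outside the two completion endpoints.
try:
    DemoOpenAISettings(max_output_key={"models": "max_tokens"})
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # e.g. "literal_error"
```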
tests/unit/conftest.py

Lines changed: 1 addition & 3 deletions
@@ -132,7 +132,6 @@ async def _mock_completions_response(request) -> AsyncIterable[str]:
     assert payload["stream_options"] == {"include_usage": True}
     assert payload["prompt"] is not None
     assert len(payload["prompt"]) > 0
-    assert payload["max_completion_tokens"] > 0
     assert payload["max_tokens"] > 0
 
     return httpx.Response(  # type: ignore
@@ -141,7 +140,7 @@ async def _mock_completions_response(request) -> AsyncIterable[str]:
             type_="text",
             prompt=payload["prompt"],
             output_token_count=(
-                payload["max_completion_tokens"]
+                payload["max_tokens"]
                 if payload.get("ignore_eos", False)
                 else None
             ),
@@ -162,7 +161,6 @@ async def _mock_chat_completions_response(request):
     assert payload["messages"] is not None
     assert len(payload["messages"]) > 0
     assert payload["max_completion_tokens"] > 0
-    assert payload["max_tokens"] > 0
 
     return httpx.Response(  # type: ignore
         200,
