Skip to content

Commit 6b7f10c

Browse files
committed
Set max_tokens key name based on completion endpoint
Signed-off-by: Samuel Monson <[email protected]>
1 parent 7a0b160 commit 6b7f10c

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

src/guidellm/backend/openai.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -31,10 +31,11 @@
3131
TEXT_COMPLETIONS_PATH = "/v1/completions"
3232
CHAT_COMPLETIONS_PATH = "/v1/chat/completions"
3333

34-
EndpointType = Literal["chat_completions", "models", "text_completions"]
35-
CHAT_COMPLETIONS: EndpointType = "chat_completions"
34+
CompletionEndpointType = Literal["text_completions", "chat_completions"]
35+
EndpointType = Union[Literal["models"], CompletionEndpointType]
36+
CHAT_COMPLETIONS: CompletionEndpointType = "chat_completions"
3637
MODELS: EndpointType = "models"
37-
TEXT_COMPLETIONS: EndpointType = "text_completions"
38+
TEXT_COMPLETIONS: CompletionEndpointType = "text_completions"
3839

3940

4041
@Backend.register("openai_http")
@@ -447,7 +448,7 @@ def _extra_body(self, endpoint_type: EndpointType) -> dict[str, Any]:
447448

448449
def _completions_payload(
449450
self,
450-
endpoint_type: EndpointType,
451+
endpoint_type: CompletionEndpointType,
451452
orig_kwargs: Optional[dict],
452453
max_output_tokens: Optional[int],
453454
**kwargs,
@@ -467,7 +468,10 @@ def _completions_payload(
467468
self.__class__.__name__,
468469
max_output_tokens or self.max_output_tokens,
469470
)
470-
payload["max_tokens"] = max_output_tokens or self.max_output_tokens
471+
max_output_key = settings.openai.max_output_key.get(
472+
endpoint_type, "max_tokens"
473+
)
474+
payload[max_output_key] = max_output_tokens or self.max_output_tokens
471475

472476
if max_output_tokens:
473477
# only set stop and ignore_eos if max_output_tokens set at request level

src/guidellm/config.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,10 @@ class OpenAISettings(BaseModel):
8888
base_url: str = "http://localhost:8000"
8989
max_output_tokens: int = 16384
9090
verify: bool = True
91+
max_output_key: dict[Literal["text_completions", "chat_completions"], str] = {
92+
"text_completions": "max_tokens",
93+
"chat_completions": "max_completion_tokens",
94+
}
9195

9296

9397
class ReportGenerationSettings(BaseModel):

0 commit comments

Comments (0)