diff --git a/src/guidellm/backend/openai.py b/src/guidellm/backend/openai.py
index 680578cc..a99962e9 100644
--- a/src/guidellm/backend/openai.py
+++ b/src/guidellm/backend/openai.py
@@ -31,10 +31,11 @@
 TEXT_COMPLETIONS_PATH = "/v1/completions"
 CHAT_COMPLETIONS_PATH = "/v1/chat/completions"
 
-EndpointType = Literal["chat_completions", "models", "text_completions"]
-CHAT_COMPLETIONS: EndpointType = "chat_completions"
+CompletionEndpointType = Literal["text_completions", "chat_completions"]
+EndpointType = Union[Literal["models"], CompletionEndpointType]
+CHAT_COMPLETIONS: CompletionEndpointType = "chat_completions"
 MODELS: EndpointType = "models"
-TEXT_COMPLETIONS: EndpointType = "text_completions"
+TEXT_COMPLETIONS: CompletionEndpointType = "text_completions"
 
 
 @Backend.register("openai_http")
@@ -447,7 +448,7 @@ def _extra_body(self, endpoint_type: EndpointType) -> dict[str, Any]:
 
     def _completions_payload(
         self,
-        endpoint_type: EndpointType,
+        endpoint_type: CompletionEndpointType,
         orig_kwargs: Optional[dict],
         max_output_tokens: Optional[int],
         **kwargs,
@@ -467,8 +468,10 @@ def _completions_payload(
                 self.__class__.__name__,
                 max_output_tokens or self.max_output_tokens,
             )
-            payload["max_tokens"] = max_output_tokens or self.max_output_tokens
-            payload["max_completion_tokens"] = payload["max_tokens"]
+            max_output_key = settings.openai.max_output_key.get(
+                endpoint_type, "max_tokens"
+            )
+            payload[max_output_key] = max_output_tokens or self.max_output_tokens
 
         if max_output_tokens:
             # only set stop and ignore_eos if max_output_tokens set at request level
diff --git a/src/guidellm/config.py b/src/guidellm/config.py
index d9dcef23..0b195c9f 100644
--- a/src/guidellm/config.py
+++ b/src/guidellm/config.py
@@ -88,6 +88,10 @@ class OpenAISettings(BaseModel):
     base_url: str = "http://localhost:8000"
     max_output_tokens: int = 16384
     verify: bool = True
+    max_output_key: dict[Literal["text_completions", "chat_completions"], str] = {
+        "text_completions": "max_tokens",
+        "chat_completions": "max_completion_tokens",
+    }
 
 
 class ReportGenerationSettings(BaseModel):
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index a0457b6f..6ac25c4d 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -132,7 +132,6 @@ async def _mock_completions_response(request) -> AsyncIterable[str]:
     assert payload["stream_options"] == {"include_usage": True}
     assert payload["prompt"] is not None
     assert len(payload["prompt"]) > 0
-    assert payload["max_completion_tokens"] > 0
     assert payload["max_tokens"] > 0
 
     return httpx.Response(  # type: ignore
@@ -141,7 +140,7 @@
             type_="text",
             prompt=payload["prompt"],
             output_token_count=(
-                payload["max_completion_tokens"]
+                payload["max_tokens"]
                 if payload.get("ignore_eos", False)
                 else None
             ),
@@ -162,7 +161,6 @@ async def _mock_chat_completions_response(request):
     assert payload["messages"] is not None
     assert len(payload["messages"]) > 0
     assert payload["max_completion_tokens"] > 0
-    assert payload["max_tokens"] > 0
 
     return httpx.Response(  # type: ignore
         200,
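
Note: the sketch below is not part of the diff; it only illustrates the intended behavior of the new config-driven key selection, assuming `settings.openai.max_output_key` carries the defaults added in `src/guidellm/config.py`. The `pick_max_output_key` helper is hypothetical, shown purely to mirror the lookup in `_completions_payload`.

```python
from typing import Literal

CompletionEndpointType = Literal["text_completions", "chat_completions"]

# Defaults mirroring OpenAISettings.max_output_key in src/guidellm/config.py
MAX_OUTPUT_KEY: dict[CompletionEndpointType, str] = {
    "text_completions": "max_tokens",
    "chat_completions": "max_completion_tokens",
}


def pick_max_output_key(endpoint_type: CompletionEndpointType) -> str:
    # Hypothetical helper: fall back to "max_tokens" for unmapped endpoint
    # types, matching the .get(endpoint_type, "max_tokens") call in the diff.
    return MAX_OUTPUT_KEY.get(endpoint_type, "max_tokens")


# Text completions keep the legacy key...
text_payload: dict[str, object] = {"prompt": "hello"}
text_payload[pick_max_output_key("text_completions")] = 128
assert "max_tokens" in text_payload

# ...while chat completions now send only max_completion_tokens.
chat_payload: dict[str, object] = {"messages": [{"role": "user", "content": "hi"}]}
chat_payload[pick_max_output_key("chat_completions")] = 128
assert "max_completion_tokens" in chat_payload and "max_tokens" not in chat_payload
```

Since `max_output_key` lives on `OpenAISettings`, it should be overridable through the usual settings mechanism if a server expects a different key for either endpoint.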