src/guidellm/backend/openai.py (15 changes: 9 additions & 6 deletions)

@@ -31,10 +31,11 @@
 TEXT_COMPLETIONS_PATH = "/v1/completions"
 CHAT_COMPLETIONS_PATH = "/v1/chat/completions"

-EndpointType = Literal["chat_completions", "models", "text_completions"]
-CHAT_COMPLETIONS: EndpointType = "chat_completions"
+CompletionEndpointType = Literal["text_completions", "chat_completions"]
+EndpointType = Union[Literal["models"], CompletionEndpointType]
+CHAT_COMPLETIONS: CompletionEndpointType = "chat_completions"
 MODELS: EndpointType = "models"
-TEXT_COMPLETIONS: EndpointType = "text_completions"
+TEXT_COMPLETIONS: CompletionEndpointType = "text_completions"


 @Backend.register("openai_http")
@@ -447,7 +448,7 @@ def _extra_body(self, endpoint_type: EndpointType) -> dict[str, Any]:

     def _completions_payload(
         self,
-        endpoint_type: EndpointType,
+        endpoint_type: CompletionEndpointType,
         orig_kwargs: Optional[dict],
         max_output_tokens: Optional[int],
         **kwargs,
@@ -467,8 +468,10 @@ def _completions_payload(
             self.__class__.__name__,
             max_output_tokens or self.max_output_tokens,
         )
-        payload["max_tokens"] = max_output_tokens or self.max_output_tokens
-        payload["max_completion_tokens"] = payload["max_tokens"]
+        max_output_key = settings.openai.max_output_key.get(
+            endpoint_type, "max_tokens"
+        )
+        payload[max_output_key] = max_output_tokens or self.max_output_tokens

         if max_output_tokens:
             # only set stop and ignore_eos if max_output_tokens set at request level
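With this change the backend sends a single output-token field chosen per endpoint type, instead of always setting both `max_tokens` and `max_completion_tokens`. A minimal sketch of the lookup, assuming a plain dict stand-in for `settings.openai.max_output_key`; `build_payload` and the example values are illustrative, not guidellm's actual API:

```python
from typing import Optional

# Stand-in for the new OpenAISettings.max_output_key default (see the config.py diff).
MAX_OUTPUT_KEY = {
    "text_completions": "max_tokens",
    "chat_completions": "max_completion_tokens",
}


def build_payload(
    endpoint_type: str,
    max_output_tokens: Optional[int],
    default_max_output_tokens: int,
) -> dict:
    """Illustrative helper: pick the output-token key the way the diff does."""
    payload: dict = {}
    # Endpoint types not in the mapping fall back to "max_tokens",
    # matching the .get(..., "max_tokens") default in the diff.
    key = MAX_OUTPUT_KEY.get(endpoint_type, "max_tokens")
    payload[key] = max_output_tokens or default_max_output_tokens
    return payload


print(build_payload("text_completions", None, 16384))
# {'max_tokens': 16384}
print(build_payload("chat_completions", 128, 16384))
# {'max_completion_tokens': 128}
```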
src/guidellm/config.py (4 changes: 4 additions & 0 deletions)

@@ -88,6 +88,10 @@ class OpenAISettings(BaseModel):
     base_url: str = "http://localhost:8000"
     max_output_tokens: int = 16384
     verify: bool = True
+    max_output_key: dict[Literal["text_completions", "chat_completions"], str] = {
+        "text_completions": "max_tokens",
+        "chat_completions": "max_completion_tokens",
+    }


 class ReportGenerationSettings(BaseModel):
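Since `max_output_key` is an ordinary pydantic field on `OpenAISettings`, it can be overridden like any other setting. A hedged sketch, re-declaring a minimal standalone model with the same field so the example is self-contained; the override scenario (a server that only accepts `max_tokens` on chat completions) is hypothetical:

```python
from typing import Literal

from pydantic import BaseModel


class OpenAISettings(BaseModel):
    # Minimal copy of the fields shown in the diff.
    base_url: str = "http://localhost:8000"
    max_output_tokens: int = 16384
    verify: bool = True
    max_output_key: dict[Literal["text_completions", "chat_completions"], str] = {
        "text_completions": "max_tokens",
        "chat_completions": "max_completion_tokens",
    }


# Defaults: chat completions use the newer "max_completion_tokens" key.
settings = OpenAISettings()
assert settings.max_output_key["chat_completions"] == "max_completion_tokens"

# Hypothetical override for a backend that still expects "max_tokens" everywhere.
legacy = OpenAISettings(
    max_output_key={
        "text_completions": "max_tokens",
        "chat_completions": "max_tokens",
    }
)
assert legacy.max_output_key["chat_completions"] == "max_tokens"
```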
tests/unit/conftest.py (4 changes: 1 addition & 3 deletions)

@@ -132,7 +132,6 @@ async def _mock_completions_response(request) -> AsyncIterable[str]:
     assert payload["stream_options"] == {"include_usage": True}
     assert payload["prompt"] is not None
     assert len(payload["prompt"]) > 0
-    assert payload["max_completion_tokens"] > 0
     assert payload["max_tokens"] > 0

     return httpx.Response(  # type: ignore
@@ -141,7 +140,7 @@ async def _mock_completions_response(request) -> AsyncIterable[str]:
             type_="text",
             prompt=payload["prompt"],
             output_token_count=(
-                payload["max_completion_tokens"]
+                payload["max_tokens"]
                 if payload.get("ignore_eos", False)
                 else None
             ),
@@ -162,7 +161,6 @@ async def _mock_chat_completions_response(request):
     assert payload["messages"] is not None
     assert len(payload["messages"]) > 0
     assert payload["max_completion_tokens"] > 0
-    assert payload["max_tokens"] > 0

     return httpx.Response(  # type: ignore
         200,