Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions src/guidellm/backend/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(
follow_redirects: Optional[bool] = None,
max_output_tokens: Optional[int] = None,
extra_query: Optional[dict] = None,
extra_body: Optional[dict] = None,
):
super().__init__(type_="openai_http")
self._target = target or settings.openai.base_url
Expand Down Expand Up @@ -120,6 +121,7 @@ def __init__(
else settings.openai.max_output_tokens
)
self.extra_query = extra_query
self.extra_body = extra_body
self._async_client: Optional[httpx.AsyncClient] = None

@property
Expand Down Expand Up @@ -242,7 +244,9 @@ async def text_completions( # type: ignore[override]

headers = self._headers()
params = self._params(TEXT_COMPLETIONS)
body = self._body(TEXT_COMPLETIONS)
payload = self._completions_payload(
body=body,
orig_kwargs=kwargs,
max_output_tokens=output_token_count,
prompt=prompt,
Expand Down Expand Up @@ -317,10 +321,12 @@ async def chat_completions( # type: ignore[override]
logger.debug("{} invocation with args: {}", self.__class__.__name__, locals())
headers = self._headers()
params = self._params(CHAT_COMPLETIONS)
body = self._body(CHAT_COMPLETIONS)
messages = (
content if raw_content else self._create_chat_messages(content=content)
)
payload = self._completions_payload(
body=body,
orig_kwargs=kwargs,
max_output_tokens=output_token_count,
messages=messages,
Expand Down Expand Up @@ -396,10 +402,28 @@ def _params(self, endpoint_type: EndpointType) -> dict[str, str]:

return self.extra_query

def _body(self, endpoint_type: EndpointType) -> dict:
    """
    Resolve the extra request-body parameters to merge into the payload
    for the given endpoint.

    ``extra_body`` is interpreted in one of two shapes:

    * a mapping keyed by endpoint type (any of the chat-completions,
      models, or text-completions keys present) — the entry for
      ``endpoint_type`` is returned, defaulting to an empty dict when
      that endpoint has no entry;
    * otherwise a flat dict of parameters applied to every endpoint,
      returned as-is.

    :param endpoint_type: the endpoint the request body is being built for.
    :return: the extra body parameters for this endpoint (possibly empty).
        NOTE(review): the return value may be the shared ``self.extra_body``
        object itself, not a copy — callers appear to ``.update()`` onto it
        downstream (see ``_completions_payload``); confirm mutation of the
        shared dict across requests is intended.
    """
    # No extra body configured at all -> contribute nothing.
    if self.extra_body is None:
        return {}

    # Presence of any endpoint-type key marks extra_body as a
    # per-endpoint mapping rather than a flat parameter dict.
    if (
        CHAT_COMPLETIONS in self.extra_body
        or MODELS in self.extra_body
        or TEXT_COMPLETIONS in self.extra_body
    ):
        return self.extra_body.get(endpoint_type, {})

    # Flat dict: the same parameters apply to every endpoint.
    return self.extra_body

def _completions_payload(
self, orig_kwargs: Optional[dict], max_output_tokens: Optional[int], **kwargs
self,
body: Optional[dict],
orig_kwargs: Optional[dict],
max_output_tokens: Optional[int],
**kwargs,
) -> dict:
payload = orig_kwargs or {}
payload = body or {}
payload.update(orig_kwargs or {})
payload.update(kwargs)
payload["model"] = self.model
payload["stream"] = True
Expand Down
4 changes: 3 additions & 1 deletion src/guidellm/dataset/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,9 @@ def _create_prompt(self, prompt_tokens: int, start_index: int) -> str:
class SyntheticDatasetCreator(DatasetCreator):
@classmethod
def is_supported(
cls, data: Any, data_args: Optional[dict[str, Any]] # noqa: ARG003
cls,
data: Any,
data_args: Optional[dict[str, Any]], # noqa: ARG003
) -> bool:
if (
isinstance(data, Path)
Expand Down
Loading