diff --git a/src/guidellm/backend/openai.py b/src/guidellm/backend/openai.py
index b5efc847..4eb6ae00 100644
--- a/src/guidellm/backend/openai.py
+++ b/src/guidellm/backend/openai.py
@@ -70,6 +70,14 @@ class OpenAIHTTPBackend(Backend):
         the values of these keys will be used as the parameters for the respective
         endpoint.
         If not provided, no extra query parameters are added.
+    :param extra_body: Body parameters to include in requests to the OpenAI server.
+        If "chat_completions", "models", or "text_completions" are included as keys,
+        the values of these keys will be included in the body for the respective
+        endpoint.
+        If not provided, no extra body parameters are added.
+    :param remove_from_body: Parameters that should be removed from the body of each
+        request.
+        If not provided, no parameters are removed from the body.
     """
 
     def __init__(
@@ -85,6 +93,7 @@ def __init__(
         max_output_tokens: Optional[int] = None,
         extra_query: Optional[dict] = None,
         extra_body: Optional[dict] = None,
+        remove_from_body: Optional[list[str]] = None,
     ):
         super().__init__(type_="openai_http")
         self._target = target or settings.openai.base_url
@@ -122,6 +131,7 @@ def __init__(
         )
         self.extra_query = extra_query
         self.extra_body = extra_body
+        self.remove_from_body = remove_from_body
         self._async_client: Optional[httpx.AsyncClient] = None
 
     @property
@@ -253,9 +263,8 @@ async def text_completions(  # type: ignore[override]
 
         headers = self._headers()
         params = self._params(TEXT_COMPLETIONS)
-        body = self._body(TEXT_COMPLETIONS)
         payload = self._completions_payload(
-            body=body,
+            endpoint_type=TEXT_COMPLETIONS,
             orig_kwargs=kwargs,
             max_output_tokens=output_token_count,
             prompt=prompt,
@@ -330,12 +339,11 @@ async def chat_completions(  # type: ignore[override]
         logger.debug("{} invocation with args: {}", self.__class__.__name__, locals())
         headers = self._headers()
         params = self._params(CHAT_COMPLETIONS)
-        body = self._body(CHAT_COMPLETIONS)
         messages = (
             content if raw_content else self._create_chat_messages(content=content)
         )
         payload = self._completions_payload(
-            body=body,
+            endpoint_type=CHAT_COMPLETIONS,
             orig_kwargs=kwargs,
             max_output_tokens=output_token_count,
             messages=messages,
@@ -411,7 +419,7 @@ def _params(self, endpoint_type: EndpointType) -> dict[str, str]:
 
         return self.extra_query
 
-    def _body(self, endpoint_type: EndpointType) -> dict[str, str]:
+    def _extra_body(self, endpoint_type: EndpointType) -> dict[str, Any]:
         if self.extra_body is None:
             return {}
 
@@ -426,12 +434,12 @@ def _body(self, endpoint_type: EndpointType) -> dict[str, str]:
 
     def _completions_payload(
         self,
-        body: Optional[dict],
+        endpoint_type: EndpointType,
         orig_kwargs: Optional[dict],
         max_output_tokens: Optional[int],
         **kwargs,
     ) -> dict:
-        payload = body or {}
+        payload = self._extra_body(endpoint_type)
         payload.update(orig_kwargs or {})
         payload.update(kwargs)
         payload["model"] = self.model
@@ -455,6 +463,10 @@ def _completions_payload(
                 payload["stop"] = None
                 payload["ignore_eos"] = True
 
+        if self.remove_from_body:
+            for key in self.remove_from_body:
+                payload.pop(key, None)
+
         return payload
 
     @staticmethod
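
A minimal usage sketch of the two parameters this change adds, assuming the OpenAIHTTPBackend constructor shown above. The import path follows the file under change; the target, model, and body values below are illustrative placeholders, not values taken from this diff.

    from guidellm.backend.openai import OpenAIHTTPBackend

    # Illustrative values only: target, model, and the body keys are placeholders;
    # extra_body and remove_from_body are the parameters introduced by this change.
    backend = OpenAIHTTPBackend(
        target="http://localhost:8000/v1",
        model="example-model",
        # Scoping under "chat_completions" sends these fields only with chat
        # completion requests, per the docstring added above.
        extra_body={"chat_completions": {"temperature": 0.0}},
        # Strip request-body fields that a given OpenAI-compatible server rejects.
        remove_from_body=["max_completion_tokens"],
    )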