Skip to content

Commit 8b5ec4f

Browse files
committed
Add remove_from_body to OpenAIHTTPBackend
1 parent d401ebd commit 8b5ec4f

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

src/guidellm/backend/openai.py

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -70,6 +70,14 @@ class OpenAIHTTPBackend(Backend):
7070
the values of these keys will be used as the parameters for the respective
7171
endpoint.
7272
If not provided, no extra query parameters are added.
73+
:param extra_body: Body parameters to include in requests to the OpenAI server.
74+
If "chat_completions", "models", or "text_completions" are included as keys,
75+
the values of these keys will be included in the body for the respective
76+
endpoint.
77+
If not provided, no extra body parameters are added.
78+
:param remove_from_body: Parameters that should be removed from the body of each
79+
request.
80+
If not provided, no parameters are removed from the body.
7381
"""
7482

7583
def __init__(
@@ -85,6 +93,7 @@ def __init__(
8593
max_output_tokens: Optional[int] = None,
8694
extra_query: Optional[dict] = None,
8795
extra_body: Optional[dict] = None,
96+
remove_from_body: Optional[list[str]] = None,
8897
):
8998
super().__init__(type_="openai_http")
9099
self._target = target or settings.openai.base_url
@@ -122,6 +131,7 @@ def __init__(
122131
)
123132
self.extra_query = extra_query
124133
self.extra_body = extra_body
134+
self.remove_from_body = remove_from_body
125135
self._async_client: Optional[httpx.AsyncClient] = None
126136

127137
@property
@@ -244,9 +254,8 @@ async def text_completions( # type: ignore[override]
244254

245255
headers = self._headers()
246256
params = self._params(TEXT_COMPLETIONS)
247-
body = self._body(TEXT_COMPLETIONS)
248257
payload = self._completions_payload(
249-
body=body,
258+
endpoint_type=TEXT_COMPLETIONS,
250259
orig_kwargs=kwargs,
251260
max_output_tokens=output_token_count,
252261
prompt=prompt,
@@ -321,12 +330,11 @@ async def chat_completions( # type: ignore[override]
321330
logger.debug("{} invocation with args: {}", self.__class__.__name__, locals())
322331
headers = self._headers()
323332
params = self._params(CHAT_COMPLETIONS)
324-
body = self._body(CHAT_COMPLETIONS)
325333
messages = (
326334
content if raw_content else self._create_chat_messages(content=content)
327335
)
328336
payload = self._completions_payload(
329-
body=body,
337+
endpoint_type=CHAT_COMPLETIONS,
330338
orig_kwargs=kwargs,
331339
max_output_tokens=output_token_count,
332340
messages=messages,
@@ -402,7 +410,7 @@ def _params(self, endpoint_type: EndpointType) -> dict[str, str]:
402410

403411
return self.extra_query
404412

405-
def _body(self, endpoint_type: EndpointType) -> dict[str, str]:
413+
def _extra_body(self, endpoint_type: EndpointType) -> dict[str, str]:
406414
if self.extra_body is None:
407415
return {}
408416

@@ -417,12 +425,12 @@ def _body(self, endpoint_type: EndpointType) -> dict[str, str]:
417425

418426
def _completions_payload(
419427
self,
420-
body: Optional[dict],
428+
endpoint_type: EndpointType,
421429
orig_kwargs: Optional[dict],
422430
max_output_tokens: Optional[int],
423431
**kwargs,
424432
) -> dict:
425-
payload = body or {}
433+
payload = self._extra_body(endpoint_type)
426434
payload.update(orig_kwargs or {})
427435
payload.update(kwargs)
428436
payload["model"] = self.model
@@ -446,6 +454,10 @@ def _completions_payload(
446454
payload["stop"] = None
447455
payload["ignore_eos"] = True
448456

457+
if self.remove_from_body:
458+
for key in self.remove_from_body:
459+
payload.pop(key, None)
460+
449461
return payload
450462

451463
@staticmethod

0 commit comments

Comments
 (0)