
Commit 133c2ec

Split formatter across multiple stages

Signed-off-by: Samuel Monson <[email protected]>

1 parent: 094fc93

9 files changed (+455, -445 lines)


src/guidellm/backends/openai.py (217 additions, 13 deletions)
@@ -12,14 +12,19 @@

 import asyncio
 import time
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Callable
 from typing import Any

 import httpx

 from guidellm.backends.backend import Backend
 from guidellm.backends.response_handlers import GenerationResponseHandlerFactory
-from guidellm.schemas import GenerationRequest, GenerationResponse, RequestInfo
+from guidellm.schemas import (
+    GenerationRequest,
+    GenerationRequestArguments,
+    GenerationResponse,
+    RequestInfo,
+)

 __all__ = ["OpenAIHTTPBackend"]

@@ -59,6 +64,10 @@ def __init__(
         follow_redirects: bool = True,
         verify: bool = False,
         validate_backend: bool | str | dict[str, Any] = True,
+        stream: bool = True,
+        extras: dict[str, Any] | GenerationRequestArguments | None = None,
+        max_tokens: int | None = None,
+        max_completion_tokens: int | None = None,
     ):
         """
         Initialize OpenAI HTTP backend with server configuration.
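The four new constructor parameters move per-request formatting policy onto the backend itself. A minimal construction sketch, assuming guidellm is installed; the target URL, model name, and extras shape are illustrative values, not taken from this diff (the diff's use of self.model suggests the constructor also accepts a model argument):

from guidellm.backends.openai import OpenAIHTTPBackend

# Hypothetical values throughout; only the parameter names come from the diff.
backend = OpenAIHTTPBackend(
    target="http://localhost:8000/v1",  # assumed OpenAI-compatible endpoint
    model="my-model",                   # assumed model identifier
    stream=False,                       # turn off SSE streaming for all requests
    extras={"body": {"temperature": 0.0}},  # combined into every request's arguments
    max_completion_tokens=256,          # fallback output cap when the dataset sets none
)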
@@ -96,11 +105,28 @@ def __init__(
         self.validate_backend: dict[str, Any] | None = self._resolve_validate_kwargs(
             validate_backend
         )
+        self.stream: bool = stream
+        self.extras = (
+            GenerationRequestArguments(**extras)
+            if extras and isinstance(extras, dict)
+            else extras
+        )
+        self.max_tokens: int | None = max_tokens or max_completion_tokens

         # Runtime state
         self._in_process = False
         self._async_client: httpx.AsyncClient | None = None

+        # TODO: Find a better way to register formatters
+        self.request_formatters: dict[
+            str, Callable[[GenerationRequest], GenerationRequestArguments]
+        ] = {
+            "text_completions": self.formatter_text_completions,
+            "chat_completions": self.formatter_chat_completions,
+            "audio_transcriptions": self.formatter_audio_transcriptions,
+            "audio_translations": self.formatter_audio_transcriptions,
+        }
+
     @property
     def info(self) -> dict[str, Any]:
         """
@@ -227,31 +253,35 @@ async def resolve( # type: ignore[override]
         if history is not None:
             raise NotImplementedError("Multi-turn requests not yet supported")

+        arguments: GenerationRequestArguments = self.request_formatters[
+            request.request_type
+        ](request)
+
         if (request_path := self.api_routes.get(request.request_type)) is None:
             raise ValueError(f"Unsupported request type '{request.request_type}'")

         request_url = f"{self.target}/{request_path}"
         request_files = (
             {
                 key: tuple(value) if isinstance(value, list) else value
-                for key, value in request.arguments.files.items()
+                for key, value in arguments.files.items()
             }
-            if request.arguments.files
+            if arguments.files
             else None
         )
-        request_json = request.arguments.body if not request_files else None
-        request_data = request.arguments.body if request_files else None
+        request_json = arguments.body if not request_files else None
+        request_data = arguments.body if request_files else None
         response_handler = GenerationResponseHandlerFactory.create(
             request.request_type, handler_overrides=self.response_handlers
         )

-        if not request.arguments.stream:
+        if not arguments.stream:
             request_info.timings.request_start = time.time()
             response = await self._async_client.request(
-                request.arguments.method or "POST",
+                arguments.method or "POST",
                 request_url,
-                params=request.arguments.params,
-                headers=request.arguments.headers,
+                params=arguments.params,
+                headers=arguments.headers,
                 json=request_json,
                 data=request_data,
                 files=request_files,
@@ -266,10 +296,10 @@ async def resolve( # type: ignore[override]
             request_info.timings.request_start = time.time()

             async with self._async_client.stream(
-                request.arguments.method or "POST",
+                arguments.method or "POST",
                 request_url,
-                params=request.arguments.params,
-                headers=request.arguments.headers,
+                params=arguments.params,
+                headers=arguments.headers,
                 json=request_json,
                 data=request_data,
                 files=request_files,
@@ -332,3 +362,177 @@ def _resolve_validate_kwargs(
             validate_kwargs["method"] = "GET"

         return validate_kwargs
+
+    def formatter_text_completions(
+        self, data: GenerationRequest
+    ) -> GenerationRequestArguments:
+        arguments: GenerationRequestArguments = GenerationRequestArguments()
+        arguments.body = {}  # The type checker works better setting this field here
+
+        # Add model
+        if self.model is not None:
+            arguments.body["model"] = self.model
+
+        # Configure streaming
+        if self.stream:
+            arguments.stream = True
+            arguments.body["stream"] = True
+            arguments.body["stream_options"] = {"include_usage": True}
+
+        # Handle output tokens
+        if data.output_metrics.text_tokens:
+            arguments.body["max_tokens"] = data.output_metrics.text_tokens
+            arguments.body["stop"] = None
+            arguments.body["ignore_eos"] = True
+        elif self.max_tokens is not None:
+            arguments.body["max_tokens"] = self.max_tokens
+
+        # Apply extra arguments
+        if self.extras:
+            arguments.model_combine(self.extras)
+
+        # Build prompt
+        prefix = "".join(pre for pre in data.columns.get("prefix_column", []) if pre)
+        text = "".join(txt for txt in data.columns.get("text_column", []) if txt)
+        if prefix or text:
+            prompt = prefix + text
+            arguments.body["prompt"] = prompt
+
+        return arguments
+
+    def formatter_chat_completions(  # noqa: C901, PLR0912, PLR0915
+        self, data: GenerationRequest
+    ) -> GenerationRequestArguments:
+        arguments = GenerationRequestArguments()
+        arguments.body = {}  # The type checker works best with body assigned here
+
+        # Add model
+        if self.model is not None:
+            arguments.body["model"] = self.model
+
+        # Configure streaming
+        if self.stream:
+            arguments.stream = True
+            arguments.body["stream"] = True
+            arguments.body["stream_options"] = {"include_usage": True}
+
+        # Handle output tokens
+        if data.output_metrics.text_tokens:
+            arguments.body.update(
+                {
+                    "max_completion_tokens": data.output_metrics.text_tokens,
+                    "stop": None,
+                    "ignore_eos": True,
+                }
+            )
+        elif self.max_tokens is not None:
+            arguments.body["max_completion_tokens"] = self.max_tokens
+
+        # Apply extra arguments
+        if self.extras:
+            arguments.model_combine(self.extras)
+
+        # Build messages
+        arguments.body["messages"] = []
+
+        for prefix in data.columns.get("prefix_column", []):
+            if not prefix:
+                continue
+
+            arguments.body["messages"].append({"role": "system", "content": prefix})
+
+        for text in data.columns.get("text_column", []):
+            if not text:
+                continue
+
+            arguments.body["messages"].append(
+                {"role": "user", "content": [{"type": "text", "text": text}]}
+            )
+
+        for image in data.columns.get("image_column", []):
+            if not image:
+                continue
+
+            arguments.body["messages"].append(
+                {
+                    "role": "user",
+                    "content": [{"type": "image_url", "image_url": image.get("image")}],
+                }
+            )
+
+        for video in data.columns.get("video_column", []):
+            if not video:
+                continue
+
+            arguments.body["messages"].append(
+                {
+                    "role": "user",
+                    "content": [{"type": "video_url", "video_url": video.get("video")}],
+                }
+            )
+
+        for audio in data.columns.get("audio_column", []):
+            if not audio:
+                continue
+
+            arguments.body["messages"].append(
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "input_audio",
+                            "input_audio": {
+                                "data": audio.get("audio"),
+                                "format": audio.get("format"),
+                            },
+                        }
+                    ],
+                }
+            )
+
+        return arguments
+
+    def formatter_audio_transcriptions(  # noqa: C901
+        self, data: GenerationRequest
+    ) -> GenerationRequestArguments:
+        arguments = GenerationRequestArguments(files={})
+        arguments.body = {}
+
+        # Add model
+        if self.model is not None:
+            arguments.body["model"] = self.model
+
+        # Configure streaming
+        if self.stream:
+            arguments.stream = True
+            arguments.body["stream"] = True
+            arguments.body["stream_options"] = {"include_usage": True}
+
+        # Apply extra arguments
+        if self.extras:
+            arguments.model_combine(self.extras)
+
+        # Build audio input
+        audio_columns = data.columns.get("audio_column", [])
+        if len(audio_columns) != 1:
+            raise ValueError(
+                f"GenerativeAudioTranscriptionRequestFormatter expects exactly "
+                f"one audio column, but got {len(audio_columns)}."
+            )
+
+        arguments.files = {
+            "file": (
+                audio_columns[0].get("file_name", "audio_input"),
+                audio_columns[0].get("audio"),
+                audio_columns[0].get("mimetype"),
+            )
+        }
+
+        # Build prompt
+        prefix = "".join(pre for pre in data.columns.get("prefix_column", []) if pre)
+        text = "".join(txt for txt in data.columns.get("text_column", []) if txt)
+        if prefix or text:
+            prompt = prefix + text
+            arguments.body["prompt"] = prompt
+
+        return arguments
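To make the chat formatter's output concrete: for a request carrying one system prefix, one user text, and a 128-token output target against a streaming backend, the method above would assemble a body along these lines (the model name and message text are invented for illustration):

# Illustrative result of formatter_chat_completions; values are hypothetical,
# field names follow the code above.
body = {
    "model": "my-model",
    "stream": True,
    "stream_options": {"include_usage": True},
    "max_completion_tokens": 128,
    "stop": None,
    "ignore_eos": True,
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
    ],
}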

src/guidellm/data/__init__.py (3 additions, 2 deletions)
@@ -4,12 +4,12 @@
     DatasetDeserializer,
     DatasetDeserializerFactory,
 )
+from .finalizers import DatasetFinalizer, FinalizerRegistry
 from .loaders import DataLoader, DatasetsIterator
 from .preprocessors import (
     DataDependentPreprocessor,
     DatasetPreprocessor,
     PreprocessorRegistry,
-    RequestFormatter,
 )
 from .processor import ProcessorFactory
 from .schemas import GenerativeDatasetColumnType

@@ -20,11 +20,12 @@
     "DataNotSupportedError",
     "DatasetDeserializer",
     "DatasetDeserializerFactory",
+    "DatasetFinalizer",
     "DatasetPreprocessor",
     "DatasetsIterator",
+    "FinalizerRegistry",
     "GenerativeDatasetColumnType",
     "GenerativeRequestCollator",
     "PreprocessorRegistry",
     "ProcessorFactory",
-    "RequestFormatter",
 ]
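For downstream code this means RequestFormatter is no longer importable from guidellm.data, while the finalizer names are new. A sketch of the changed import surface (assuming no other re-export paths, which this diff does not show):

# Before this commit (no longer valid):
# from guidellm.data import RequestFormatter

# After this commit:
from guidellm.data import DatasetFinalizer, FinalizerRegistry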
