Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 11 additions & 22 deletions src/guidellm/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,19 +178,13 @@ def benchmark():
# Data configuration
@click.option(
"--request-type",
default=BenchmarkGenerativeTextArgs.get_default("data_request_formatter"),
default=BenchmarkGenerativeTextArgs.get_default("request_type"),
type=click.Choice(list(get_literal_vals(GenerativeRequestType))),
help=(
f"Request type to create for each data sample. "
f"Options: {', '.join(get_literal_vals(GenerativeRequestType))}."
),
)
@click.option(
"--request-formatter-kwargs",
default=None,
callback=cli_tools.parse_json,
help="JSON string of arguments to pass to the request formatter.",
)
@click.option(
"--processor",
default=BenchmarkGenerativeTextArgs.get_default("processor"),
Expand Down Expand Up @@ -223,10 +217,17 @@ def benchmark():
),
)
@click.option(
"--data-column-mapper",
default=BenchmarkGenerativeTextArgs.get_default("data_column_mapper"),
"--data-preprocessors",
default=BenchmarkGenerativeTextArgs.get_default("data_preprocessors"),
callback=cli_tools.parse_json,
multiple=True,
help="JSON string of preprocessors to apply to the dataset.",
)
@click.option(
"--data-finalizer",
default=BenchmarkGenerativeTextArgs.get_default("data_finalizer"),
callback=cli_tools.parse_json,
help="JSON string of column mappings to apply to the dataset.",
help="JSON string of finalizer to convert dataset rows to requests.",
)
@click.option(
"--data-sampler",
Expand Down Expand Up @@ -386,18 +387,6 @@ def run(**kwargs):
# Only set CLI args that differ from click defaults
kwargs = cli_tools.set_if_not_default(click.get_current_context(), **kwargs)

# Handle remapping for request params
request_type = kwargs.pop("request_type", None)
request_formatter_kwargs = kwargs.pop("request_formatter_kwargs", None)
if request_type is not None:
kwargs["data_request_formatter"] = (
request_type
if not request_formatter_kwargs
else {"request_type": request_type, **request_formatter_kwargs}
)
elif request_formatter_kwargs is not None:
kwargs["data_request_formatter"] = request_formatter_kwargs

# Handle output path remapping
if (output_path := kwargs.pop("output_path", None)) is not None:
path = Path(output_path)
Expand Down
239 changes: 223 additions & 16 deletions src/guidellm/backends/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@

import asyncio
import time
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Callable
from typing import Any

import httpx

from guidellm.backends.backend import Backend
from guidellm.backends.response_handlers import GenerationResponseHandlerFactory
from guidellm.schemas import GenerationRequest, GenerationResponse, RequestInfo
from guidellm.schemas import (
GenerationRequest,
GenerationRequestArguments,
GenerationResponse,
RequestInfo,
)

__all__ = ["OpenAIHTTPBackend"]

Expand Down Expand Up @@ -59,6 +64,10 @@ def __init__(
follow_redirects: bool = True,
verify: bool = False,
validate_backend: bool | str | dict[str, Any] = True,
stream: bool = True,
extras: dict[str, Any] | GenerationRequestArguments | None = None,
max_tokens: int | None = None,
max_completion_tokens: int | None = None,
):
"""
Initialize OpenAI HTTP backend with server configuration.
Expand Down Expand Up @@ -96,11 +105,28 @@ def __init__(
self.validate_backend: dict[str, Any] | None = self._resolve_validate_kwargs(
validate_backend
)
self.stream: bool = stream
self.extras = (
GenerationRequestArguments(**extras)
if extras and isinstance(extras, dict)
else extras
)
self.max_tokens: int | None = max_tokens or max_completion_tokens

# Runtime state
self._in_process = False
self._async_client: httpx.AsyncClient | None = None

# TODO: Find a better way to register formatters
self.request_formatters: dict[
str, Callable[[GenerationRequest], GenerationRequestArguments]
] = {
"text_completions": self.formatter_text_completions,
"chat_completions": self.formatter_chat_completions,
"audio_transcriptions": self.formatter_audio_transcriptions,
"audio_translations": self.formatter_audio_transcriptions,
}

@property
def info(self) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -227,49 +253,56 @@ async def resolve( # type: ignore[override]
if history is not None:
raise NotImplementedError("Multi-turn requests not yet supported")

arguments: GenerationRequestArguments = self.request_formatters[
request.request_type
](request)

if (request_path := self.api_routes.get(request.request_type)) is None:
raise ValueError(f"Unsupported request type '{request.request_type}'")

request_url = f"{self.target}/{request_path}"
request_files = (
{
key: tuple(value) if isinstance(value, list) else value
for key, value in request.arguments.files.items()
for key, value in arguments.files.items()
}
if request.arguments.files
if arguments.files
else None
)
request_json = request.arguments.body if not request_files else None
request_data = request.arguments.body if request_files else None
request_json = arguments.body if not request_files else None
request_data = arguments.body if request_files else None
response_handler = GenerationResponseHandlerFactory.create(
request.request_type, handler_overrides=self.response_handlers
)

if not request.arguments.stream:
if not arguments.stream:
request_info.timings.request_start = time.time()
response = await self._async_client.request(
request.arguments.method or "POST",
arguments.method or "POST",
request_url,
params=request.arguments.params,
headers=request.arguments.headers,
params=arguments.params,
headers=arguments.headers,
json=request_json,
data=request_data,
files=request_files,
)
request_info.timings.request_end = time.time()
response.raise_for_status()
data = response.json()
yield response_handler.compile_non_streaming(request, data), request_info
yield (
response_handler.compile_non_streaming(request, arguments, data),
request_info,
)
return

try:
request_info.timings.request_start = time.time()

async with self._async_client.stream(
request.arguments.method or "POST",
arguments.method or "POST",
request_url,
params=request.arguments.params,
headers=request.arguments.headers,
params=arguments.params,
headers=arguments.headers,
json=request_json,
data=request_data,
files=request_files,
Expand Down Expand Up @@ -298,10 +331,10 @@ async def resolve( # type: ignore[override]
request_info.timings.token_iterations += iterations

request_info.timings.request_end = time.time()
yield response_handler.compile_streaming(request), request_info
yield response_handler.compile_streaming(request, arguments), request_info
except asyncio.CancelledError as err:
# Yield current result to store iterative results before propagating
yield response_handler.compile_streaming(request), request_info
yield response_handler.compile_streaming(request, arguments), request_info
raise err

def _resolve_validate_kwargs(
Expand Down Expand Up @@ -332,3 +365,177 @@ def _resolve_validate_kwargs(
validate_kwargs["method"] = "GET"

return validate_kwargs

def formatter_text_completions(
    self, data: GenerationRequest
) -> GenerationRequestArguments:
    """Convert a generation request into /v1/completions request arguments."""
    args = GenerationRequestArguments()
    args.body = {}  # assigned here so the type checker narrows the field

    # Target model, when the backend pins one
    if self.model is not None:
        args.body["model"] = self.model

    # Streaming mode with per-chunk usage reporting
    if self.stream:
        args.stream = True
        args.body.update({"stream": True, "stream_options": {"include_usage": True}})

    # Exact output length from the request wins; backend default is the fallback
    requested_tokens = data.output_metrics.text_tokens
    if requested_tokens:
        args.body.update(
            {"max_tokens": requested_tokens, "stop": None, "ignore_eos": True}
        )
    elif self.max_tokens is not None:
        args.body["max_tokens"] = self.max_tokens

    # Merge backend-level extra arguments on top
    if self.extras:
        args.model_combine(self.extras)

    # Prompt = concatenated non-empty prefix parts followed by text parts
    prompt = "".join(
        part
        for column in ("prefix_column", "text_column")
        for part in data.columns.get(column, [])
        if part
    )
    if prompt:
        args.body["prompt"] = prompt

    return args

def formatter_chat_completions(  # noqa: C901, PLR0912, PLR0915
    self, data: GenerationRequest
) -> GenerationRequestArguments:
    """Convert a generation request into /v1/chat/completions request arguments."""
    args = GenerationRequestArguments()
    args.body = {}  # assigned here so the type checker narrows the field

    # Target model, when the backend pins one
    if self.model is not None:
        args.body["model"] = self.model

    # Streaming mode with per-chunk usage reporting
    if self.stream:
        args.stream = True
        args.body.update({"stream": True, "stream_options": {"include_usage": True}})

    # Exact output length from the request wins; backend default is the fallback
    requested_tokens = data.output_metrics.text_tokens
    if requested_tokens:
        args.body.update(
            {
                "max_completion_tokens": requested_tokens,
                "stop": None,
                "ignore_eos": True,
            }
        )
    elif self.max_tokens is not None:
        args.body["max_completion_tokens"] = self.max_tokens

    # Merge backend-level extra arguments on top
    if self.extras:
        args.model_combine(self.extras)

    # Assemble messages: system prefixes first, then one user message per
    # non-empty text/image/video/audio column entry.
    messages: list[dict[str, Any]] = []
    args.body["messages"] = messages

    for prefix in data.columns.get("prefix_column", []):
        if prefix:
            messages.append({"role": "system", "content": prefix})

    for text in data.columns.get("text_column", []):
        if text:
            messages.append(
                {"role": "user", "content": [{"type": "text", "text": text}]}
            )

    for image in data.columns.get("image_column", []):
        if image:
            messages.append(
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": image.get("image")}
                    ],
                }
            )

    for video in data.columns.get("video_column", []):
        if video:
            messages.append(
                {
                    "role": "user",
                    "content": [
                        {"type": "video_url", "video_url": video.get("video")}
                    ],
                }
            )

    for audio in data.columns.get("audio_column", []):
        if audio:
            messages.append(
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "input_audio",
                            "input_audio": {
                                "data": audio.get("audio"),
                                "format": audio.get("format"),
                            },
                        }
                    ],
                }
            )

    return args

def formatter_audio_transcriptions(  # noqa: C901
    self, data: GenerationRequest
) -> GenerationRequestArguments:
    """
    Convert a generation request into audio transcription/translation
    request arguments (multipart upload with an optional text prompt).

    :param data: Generation request carrying the dataset columns.
    :return: Arguments with the audio file attached and body fields populated.
    :raises ValueError: If the request does not contain exactly one audio column.
    """
    arguments = GenerationRequestArguments(files={})
    arguments.body = {}  # assigned here so the type checker narrows the field

    # Target model, when the backend pins one
    if self.model is not None:
        arguments.body["model"] = self.model

    # Streaming mode with per-chunk usage reporting
    if self.stream:
        arguments.stream = True
        arguments.body["stream"] = True
        arguments.body["stream_options"] = {"include_usage": True}

    # Merge backend-level extra arguments on top
    if self.extras:
        arguments.model_combine(self.extras)

    # Build audio input; the endpoint accepts exactly one file per request.
    audio_columns = data.columns.get("audio_column", [])
    if len(audio_columns) != 1:
        # Fixed: previous message named the removed
        # GenerativeAudioTranscriptionRequestFormatter class.
        raise ValueError(
            f"Audio transcription/translation requests require exactly "
            f"one audio column, but got {len(audio_columns)}."
        )

    arguments.files = {
        "file": (
            audio_columns[0].get("file_name", "audio_input"),
            audio_columns[0].get("audio"),
            audio_columns[0].get("mimetype"),
        )
    }

    # Optional prompt: concatenated non-empty prefix parts then text parts
    prefix = "".join(pre for pre in data.columns.get("prefix_column", []) if pre)
    text = "".join(txt for txt in data.columns.get("text_column", []) if txt)
    if prefix or text:
        arguments.body["prompt"] = prefix + text

    return arguments
Loading
Loading