diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py index 2b52bbc5b..e4ef58557 100644 --- a/src/guidellm/__main__.py +++ b/src/guidellm/__main__.py @@ -178,19 +178,13 @@ def benchmark(): # Data configuration @click.option( "--request-type", - default=BenchmarkGenerativeTextArgs.get_default("data_request_formatter"), + default=BenchmarkGenerativeTextArgs.get_default("request_type"), type=click.Choice(list(get_literal_vals(GenerativeRequestType))), help=( f"Request type to create for each data sample. " f"Options: {', '.join(get_literal_vals(GenerativeRequestType))}." ), ) -@click.option( - "--request-formatter-kwargs", - default=None, - callback=cli_tools.parse_json, - help="JSON string of arguments to pass to the request formatter.", -) @click.option( "--processor", default=BenchmarkGenerativeTextArgs.get_default("processor"), @@ -223,10 +217,17 @@ def benchmark(): ), ) @click.option( - "--data-column-mapper", - default=BenchmarkGenerativeTextArgs.get_default("data_column_mapper"), + "--data-preprocessors", + default=BenchmarkGenerativeTextArgs.get_default("data_preprocessors"), + callback=cli_tools.parse_json, + multiple=True, + help="JSON string of a preprocessor to apply to the dataset; may be repeated.", +) +@click.option( + "--data-finalizer", + default=BenchmarkGenerativeTextArgs.get_default("data_finalizer"), callback=cli_tools.parse_json, - help="JSON string of column mappings to apply to the dataset.", + help="JSON string of the finalizer used to convert dataset rows into requests.", ) @click.option( "--data-sampler", @@ -386,18 +387,6 @@ def run(**kwargs): # Only set CLI args that differ from click defaults kwargs = cli_tools.set_if_not_default(click.get_current_context(), **kwargs) - # Handle remapping for request params - request_type = kwargs.pop("request_type", None) - request_formatter_kwargs = kwargs.pop("request_formatter_kwargs", None) - if request_type is not None: - kwargs["data_request_formatter"] = ( - request_type - if not request_formatter_kwargs - else {"request_type": request_type, **request_formatter_kwargs} - ) - elif request_formatter_kwargs is not None: - kwargs["data_request_formatter"] = request_formatter_kwargs - # Handle output path remapping if (output_path := kwargs.pop("output_path", None)) is not None: path = Path(output_path) diff --git a/src/guidellm/backends/openai.py b/src/guidellm/backends/openai.py index 57e2d95a6..03c319769 100644 --- a/src/guidellm/backends/openai.py +++ b/src/guidellm/backends/openai.py @@ -12,14 +12,19 @@ import asyncio import time -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Callable from typing import Any import httpx from guidellm.backends.backend import Backend from guidellm.backends.response_handlers import GenerationResponseHandlerFactory -from guidellm.schemas import GenerationRequest, GenerationResponse, RequestInfo +from guidellm.schemas import ( + GenerationRequest, + GenerationRequestArguments, + GenerationResponse, + RequestInfo, +) __all__ = ["OpenAIHTTPBackend"] @@ -59,6 +64,10 @@ def __init__( follow_redirects: bool = True, verify: bool = False, validate_backend: bool | str | dict[str, Any] = True, + stream: bool = True, + extras: dict[str, Any] | GenerationRequestArguments | None = None, + max_tokens: int | None = None, + max_completion_tokens: int | None = None, ): """ Initialize OpenAI HTTP backend with server configuration.
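The per-request options that previously lived on the request formatters (streaming, extra arguments, output-token caps) are now backend-level settings. A minimal construction sketch, assuming the pre-existing `target` and `model` constructor parameters and placeholder values throughout:

    from guidellm.backends.openai import OpenAIHTTPBackend

    backend = OpenAIHTTPBackend(
        target="http://localhost:8000",  # placeholder server address
        model="example-model",  # placeholder model id
        stream=True,  # stream tokens and ask the server for usage stats
        max_tokens=256,  # fallback cap when a sample carries no output token count
        extras={"headers": {"X-Example": "demo"}},  # dict is coerced to GenerationRequestArguments
    )

`max_completion_tokens` is accepted as an alias and collapsed into `max_tokens` by the constructor, mirroring the two spellings the OpenAI-style APIs use.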
@@ -96,11 +105,28 @@ def __init__( self.validate_backend: dict[str, Any] | None = self._resolve_validate_kwargs( validate_backend ) + self.stream: bool = stream + self.extras = ( + GenerationRequestArguments(**extras) + if extras and isinstance(extras, dict) + else extras + ) + self.max_tokens: int | None = max_tokens or max_completion_tokens # Runtime state self._in_process = False self._async_client: httpx.AsyncClient | None = None + # TODO: Find a better way to register formatters + self.request_formatters: dict[ + str, Callable[[GenerationRequest], GenerationRequestArguments] + ] = { + "text_completions": self.formatter_text_completions, + "chat_completions": self.formatter_chat_completions, + "audio_transcriptions": self.formatter_audio_transcriptions, + "audio_translations": self.formatter_audio_transcriptions, + } + @property def info(self) -> dict[str, Any]: """ @@ -227,6 +253,10 @@ async def resolve( # type: ignore[override] if history is not None: raise NotImplementedError("Multi-turn requests not yet supported") + arguments: GenerationRequestArguments = self.request_formatters[ + request.request_type + ](request) + if (request_path := self.api_routes.get(request.request_type)) is None: raise ValueError(f"Unsupported request type '{request.request_type}'") @@ -234,24 +264,24 @@ async def resolve( # type: ignore[override] request_files = ( { key: tuple(value) if isinstance(value, list) else value - for key, value in request.arguments.files.items() + for key, value in arguments.files.items() } - if request.arguments.files + if arguments.files else None ) - request_json = request.arguments.body if not request_files else None - request_data = request.arguments.body if request_files else None + request_json = arguments.body if not request_files else None + request_data = arguments.body if request_files else None response_handler = GenerationResponseHandlerFactory.create( request.request_type, handler_overrides=self.response_handlers ) - if not request.arguments.stream: + if not arguments.stream: request_info.timings.request_start = time.time() response = await self._async_client.request( - request.arguments.method or "POST", + arguments.method or "POST", request_url, - params=request.arguments.params, - headers=request.arguments.headers, + params=arguments.params, + headers=arguments.headers, json=request_json, data=request_data, files=request_files, @@ -259,17 +289,20 @@ async def resolve( # type: ignore[override] request_info.timings.request_end = time.time() response.raise_for_status() data = response.json() - yield response_handler.compile_non_streaming(request, data), request_info + yield ( + response_handler.compile_non_streaming(request, arguments, data), + request_info, + ) return try: request_info.timings.request_start = time.time() async with self._async_client.stream( - request.arguments.method or "POST", + arguments.method or "POST", request_url, - params=request.arguments.params, - headers=request.arguments.headers, + params=arguments.params, + headers=arguments.headers, json=request_json, data=request_data, files=request_files, @@ -298,10 +331,10 @@ async def resolve( # type: ignore[override] request_info.timings.token_iterations += iterations request_info.timings.request_end = time.time() - yield response_handler.compile_streaming(request), request_info + yield response_handler.compile_streaming(request, arguments), request_info except asyncio.CancelledError as err: # Yield current result to store iterative results before propagating - yield 
response_handler.compile_streaming(request), request_info + yield response_handler.compile_streaming(request, arguments), request_info raise err def _resolve_validate_kwargs( @@ -332,3 +365,177 @@ def _resolve_validate_kwargs( validate_kwargs["method"] = "GET" return validate_kwargs + + def formatter_text_completions( + self, data: GenerationRequest + ) -> GenerationRequestArguments: + arguments: GenerationRequestArguments = GenerationRequestArguments() + arguments.body = {} # The type checker works better setting this field here + + # Add model + if self.model is not None: + arguments.body["model"] = self.model + + # Configure streaming + if self.stream: + arguments.stream = True + arguments.body["stream"] = True + arguments.body["stream_options"] = {"include_usage": True} + + # Handle output tokens + if data.output_metrics.text_tokens: + arguments.body["max_tokens"] = data.output_metrics.text_tokens + arguments.body["stop"] = None + arguments.body["ignore_eos"] = True + elif self.max_tokens is not None: + arguments.body["max_tokens"] = self.max_tokens + + # Apply extra arguments + if self.extras: + arguments.model_combine(self.extras) + + # Build prompt + prefix = "".join(pre for pre in data.columns.get("prefix_column", []) if pre) + text = "".join(txt for txt in data.columns.get("text_column", []) if txt) + if prefix or text: + prompt = prefix + text + arguments.body["prompt"] = prompt + + return arguments + + def formatter_chat_completions( # noqa: C901, PLR0912, PLR0915 + self, data: GenerationRequest + ) -> GenerationRequestArguments: + arguments = GenerationRequestArguments() + arguments.body = {} # The type checker works best with body assigned here + + # Add model + if self.model is not None: + arguments.body["model"] = self.model + + # Configure streaming + if self.stream: + arguments.stream = True + arguments.body["stream"] = True + arguments.body["stream_options"] = {"include_usage": True} + + # Handle output tokens + if data.output_metrics.text_tokens: + arguments.body.update( + { + "max_completion_tokens": data.output_metrics.text_tokens, + "stop": None, + "ignore_eos": True, + } + ) + elif self.max_tokens is not None: + arguments.body["max_completion_tokens"] = self.max_tokens + + # Apply extra arguments + if self.extras: + arguments.model_combine(self.extras) + + # Build messages + arguments.body["messages"] = [] + + for prefix in data.columns.get("prefix_column", []): + if not prefix: + continue + + arguments.body["messages"].append({"role": "system", "content": prefix}) + + for text in data.columns.get("text_column", []): + if not text: + continue + + arguments.body["messages"].append( + {"role": "user", "content": [{"type": "text", "text": text}]} + ) + + for image in data.columns.get("image_column", []): + if not image: + continue + + arguments.body["messages"].append( + { + "role": "user", + "content": [{"type": "image_url", "image_url": image.get("image")}], + } + ) + + for video in data.columns.get("video_column", []): + if not video: + continue + + arguments.body["messages"].append( + { + "role": "user", + "content": [{"type": "video_url", "video_url": video.get("video")}], + } + ) + + for audio in data.columns.get("audio_column", []): + if not audio: + continue + + arguments.body["messages"].append( + { + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": { + "data": audio.get("audio"), + "format": audio.get("format"), + }, + } + ], + } + ) + + return arguments + + def formatter_audio_transcriptions( # noqa: C901 + self, data: 
GenerationRequest + ) -> GenerationRequestArguments: + arguments = GenerationRequestArguments(files={}) + arguments.body = {} + + # Add model + if self.model is not None: + arguments.body["model"] = self.model + + # Configure streaming + if self.stream: + arguments.stream = True + arguments.body["stream"] = True + arguments.body["stream_options"] = {"include_usage": True} + + # Apply extra arguments + if self.extras: + arguments.model_combine(self.extras) + + # Build audio input + audio_columns = data.columns.get("audio_column", []) + if len(audio_columns) != 1: + raise ValueError( + "formatter_audio_transcriptions expects exactly " + f"one audio column, but got {len(audio_columns)}." + ) + + arguments.files = { + "file": ( + audio_columns[0].get("file_name", "audio_input"), + audio_columns[0].get("audio"), + audio_columns[0].get("mimetype"), + ) + } + + # Build prompt + prefix = "".join(pre for pre in data.columns.get("prefix_column", []) if pre) + text = "".join(txt for txt in data.columns.get("text_column", []) if txt) + if prefix or text: + prompt = prefix + text + arguments.body["prompt"] = prompt + + return arguments diff --git a/src/guidellm/backends/response_handlers.py b/src/guidellm/backends/response_handlers.py index e8087e058..e4e2baa38 100644 --- a/src/guidellm/backends/response_handlers.py +++ b/src/guidellm/backends/response_handlers.py @@ -12,6 +12,7 @@ from typing import Any, Protocol, cast from guidellm.schemas import GenerationRequest, GenerationResponse, UsageMetrics +from guidellm.schemas.request import GenerationRequestArguments from guidellm.utils import RegistryMixin, json __all__ = [ @@ -33,7 +34,10 @@ class GenerationResponseHandler(Protocol): """ def compile_non_streaming( - self, request: GenerationRequest, response: Any + self, + request: GenerationRequest, + arguments: GenerationRequestArguments, + response: Any, ) -> GenerationResponse: """ Process a complete non-streaming API response. @@ -53,7 +57,9 @@ def add_streaming_line(self, line: str) -> int | None: """ ... - def compile_streaming(self, request: GenerationRequest) -> GenerationResponse: + def compile_streaming( + self, request: GenerationRequest, arguments: GenerationRequestArguments + ) -> GenerationResponse: """ Compile accumulated streaming data into a final response. @@ -127,7 +133,10 @@ def __init__(self): self.streaming_response_id: str | None = None def compile_non_streaming( - self, request: GenerationRequest, response: dict + self, + request: GenerationRequest, + arguments: GenerationRequestArguments, + response: dict, ) -> GenerationResponse: """ Process a complete text completion response. @@ -143,9 +152,7 @@ return GenerationResponse( request_id=request.request_id, - request_args=str( - request.arguments.model_dump() if request.arguments else None - ), + request_args=arguments.model_dump_json(), response_id=response.get("id"), # use vLLM ID if available text=text, input_metrics=input_metrics, @@ -181,7 +188,9 @@ def add_streaming_line(self, line: str) -> int | None: return 1 if updated else 0 - def compile_streaming(self, request: GenerationRequest) -> GenerationResponse: + def compile_streaming( + self, request: GenerationRequest, arguments: GenerationRequestArguments + ) -> GenerationResponse: """ Compile accumulated streaming text chunks into a final response.
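With `arguments` threaded through explicitly, a handler no longer reaches into `request.arguments` (which this diff removes); the backend builds the arguments via its formatter and hands them to `compile_non_streaming`/`compile_streaming`. A rough sketch of the streaming flow in isolation, with a hypothetical bare request and made-up SSE payloads:

    from guidellm.backends.response_handlers import GenerationResponseHandlerFactory
    from guidellm.schemas import GenerationRequest, GenerationRequestArguments

    request = GenerationRequest(request_type="text_completions")  # hypothetical minimal request
    arguments = GenerationRequestArguments(stream=True)  # normally built by the backend formatter

    handler = GenerationResponseHandlerFactory.create("text_completions")

    # Feed SSE lines as they arrive; the payload shape here is illustrative only.
    for line in ('data: {"choices": [{"text": "Hello"}]}', "data: [DONE]"):
        handler.add_streaming_line(line)

    response = handler.compile_streaming(request, arguments)
    print(response.request_args)  # JSON dump of the arguments that were actually sent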
@@ -193,9 +202,7 @@ def compile_streaming(self, request: GenerationRequest) -> GenerationResponse: return GenerationResponse( request_id=request.request_id, - request_args=str( - request.arguments.model_dump() if request.arguments else None - ), + request_args=arguments.model_dump_json(), response_id=self.streaming_response_id, # use vLLM ID if available text=text, input_metrics=input_metrics, @@ -290,7 +297,10 @@ class ChatCompletionsResponseHandler(TextCompletionsResponseHandler): """ def compile_non_streaming( - self, request: GenerationRequest, response: dict + self, + request: GenerationRequest, + arguments: GenerationRequestArguments, + response: dict, ) -> GenerationResponse: """ Process a complete chat completion response. @@ -309,9 +319,7 @@ def compile_non_streaming( return GenerationResponse( request_id=request.request_id, - request_args=str( - request.arguments.model_dump() if request.arguments else None - ), + request_args=arguments.model_dump_json(), response_id=response.get("id"), # use vLLM ID if available text=text, input_metrics=input_metrics, @@ -347,7 +355,9 @@ def add_streaming_line(self, line: str) -> int | None: return 1 if updated else 0 - def compile_streaming(self, request: GenerationRequest) -> GenerationResponse: + def compile_streaming( + self, request: GenerationRequest, arguments: GenerationRequestArguments + ) -> GenerationResponse: """ Compile accumulated streaming chat completion content into a final response. @@ -359,9 +369,7 @@ def compile_streaming(self, request: GenerationRequest) -> GenerationResponse: return GenerationResponse( request_id=request.request_id, - request_args=str( - request.arguments.model_dump() if request.arguments else None - ), + request_args=arguments.model_dump_json(), response_id=self.streaming_response_id, # use vLLM ID if available text=text, input_metrics=input_metrics, @@ -399,7 +407,10 @@ def __init__(self): self.streaming_response_id: str | None = None def compile_non_streaming( - self, request: GenerationRequest, response: dict + self, + request: GenerationRequest, + arguments: GenerationRequestArguments, + response: dict, ) -> GenerationResponse: """ Process a complete audio transcription or translation response. @@ -417,9 +428,7 @@ def compile_non_streaming( return GenerationResponse( request_id=request.request_id, - request_args=str( - request.arguments.model_dump() if request.arguments else None - ), + request_args=arguments.model_dump_json(), response_id=response.get("id"), # use vLLM ID if available text=text, input_metrics=input_metrics, @@ -457,7 +466,9 @@ def add_streaming_line(self, line: str) -> int | None: return 1 if updated else 0 - def compile_streaming(self, request: GenerationRequest) -> GenerationResponse: + def compile_streaming( + self, request: GenerationRequest, arguments: GenerationRequestArguments + ) -> GenerationResponse: """ Compile accumulated streaming audio text into a final response. 
@@ -469,9 +480,7 @@ def compile_streaming(self, request: GenerationRequest) -> GenerationResponse: return GenerationResponse( request_id=request.request_id, - request_args=str( - request.arguments.model_dump() if request.arguments else None - ), + request_args=arguments.model_dump_json(), response_id=self.streaming_response_id, text=text, input_metrics=input_metrics, diff --git a/src/guidellm/benchmark/entrypoints.py b/src/guidellm/benchmark/entrypoints.py index 5b57b22fe..feca09ac3 100644 --- a/src/guidellm/benchmark/entrypoints.py +++ b/src/guidellm/benchmark/entrypoints.py @@ -13,7 +13,7 @@ from collections.abc import Callable, Mapping, MutableMapping from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, TypeVar from torch.utils.data import Sampler from transformers import PreTrainedTokenizerBase @@ -36,21 +36,21 @@ from guidellm.benchmark.schemas.base import TransientPhaseConfig from guidellm.data import ( DataLoader, + DatasetFinalizer, DatasetPreprocessor, + FinalizerRegistry, GenerativeRequestCollator, PreprocessorRegistry, ProcessorFactory, - RequestFormatter, ) -from guidellm.data.preprocessors import GenerativeColumnMapper from guidellm.scheduler import ( ConstraintInitializer, NonDistributedEnvironment, StrategyType, ) from guidellm.schemas import GenerationRequest, GenerationResponse -from guidellm.settings import settings from guidellm.utils import Console, InfoMixin +from guidellm.utils.registry import RegistryMixin __all__ = [ "benchmark_generative_text", @@ -178,19 +178,66 @@ async def resolve_processor( return processor +BaseTypeT = TypeVar("BaseTypeT") + + +def resolve_item_from_registry( + base_type: type[BaseTypeT], + registry: type[RegistryMixin], + item: Any, + extras: dict[str, Any] | None = None, +) -> BaseTypeT: + """ + Resolve an item from a registry, instantiating it if necessary. + + :param base_type: The expected base type of the item + :param registry: The registry to use for resolving string identifiers + :param item: The item to resolve: an instance, a string identifier, or a dict with a 'type' key plus constructor kwargs + :param extras: Additional constructor kwargs applied when instantiating + :return: The resolved item as an instance of the base type + :raises ValueError: If the item cannot be resolved from the registry + :raises TypeError: If the resolved item is not of the expected base type + """ + if isinstance(item, base_type): + return item + else: + if isinstance(item, str): + item_type = item + kwargs = {} + else: + item_dict = dict(item) + item_type = item_dict.pop("type", None) + if item_type is None: + raise ValueError( + f"Item dictionary must contain a 'type' key to resolve from " + f"{registry.__name__}." + ) + kwargs = item_dict + + if (item_class := registry.get_registered_object(item_type)) is None: + raise ValueError( + f"Item type '{item_type}' is not registered in the " + f"{registry.__name__}." + ) + if not issubclass(item_class, base_type): + raise TypeError( + f"Resolved item type '{item_type}' is not a subclass of " + f"{base_type.__name__}."
+ ) + if extras: + kwargs.update(extras) + return item_class(**kwargs) + + async def resolve_request_loader( data: list[Any], model: str, + request_type: str, data_args: list[dict[str, Any]] | None, data_samples: int, processor: ProcessorInputT | None, processor_args: dict[str, Any] | None, - data_column_mapper: ( - DatasetPreprocessor - | dict[str, str | list[str]] - | Literal["generative_column_mapper"] - ), - data_request_formatter: (RequestFormatter | dict[str, str] | str), + data_preprocessors: list[DatasetPreprocessor | dict[str, str | list[str]] | str], + data_finalizer: (DatasetFinalizer | dict[str, Any] | str), data_collator: Callable | Literal["generative"] | None, data_sampler: Sampler[int] | Literal["shuffle"] | None, data_num_workers: int | None, @@ -232,54 +279,22 @@ async def resolve_request_loader( else None ) - data_column_mapper_instance: DatasetPreprocessor - if isinstance(data_column_mapper, DatasetPreprocessor): - data_column_mapper_instance = data_column_mapper - else: - column_mappings = ( - data_column_mapper if isinstance(data_column_mapper, dict) else None - ) - data_column_mapper_instance = GenerativeColumnMapper( - column_mappings=column_mappings # type: ignore[arg-type] - ) - - data_request_formatter_instance: RequestFormatter - if isinstance(data_request_formatter, RequestFormatter): - data_request_formatter_instance = data_request_formatter - else: - if isinstance(data_request_formatter, str): - request_type = data_request_formatter - formatter_kwargs: dict[str, Any] = {} - else: - # Extract request_type from formatter dictionary - formatter_dict = dict(data_request_formatter) - request_type = formatter_dict.pop("request_type", settings.preferred_route) - formatter_kwargs = formatter_dict - - if ( - formatter_class := PreprocessorRegistry.get_registered_object(request_type) - ) is None: - raise ValueError( - f"Request formatter '{request_type}' is not registered in the " - f"PreprocessorRegistry." - ) - if not issubclass(formatter_class, RequestFormatter): - raise TypeError( - f"Request formatter '{request_type}' is not a subclass of " - f"RequestFormatter." 
- ) - - data_request_formatter_instance = formatter_class( - model=model, - **formatter_kwargs, - ) - - # Cast to proper types for the DataLoader preprocessors list preprocessors_list: list[DatasetPreprocessor] = [ - data_column_mapper_instance, - data_request_formatter_instance, + resolve_item_from_registry( + DatasetPreprocessor, # type: ignore [type-abstract] + PreprocessorRegistry, + preprocessor, + ) + for preprocessor in data_preprocessors ] + finalizer_instance = resolve_item_from_registry( + DatasetFinalizer, # type: ignore [type-abstract] + FinalizerRegistry, + data_finalizer, + extras={"request_type": request_type}, + ) + request_loader: DataLoader[GenerationRequest] = DataLoader( data=data, data_args=data_args, @@ -289,6 +304,7 @@ async def resolve_request_loader( processor_args=processor_args, ), preprocessors=preprocessors_list, + finalizer=finalizer_instance, collator=( data_collator if callable(data_collator) else GenerativeRequestCollator() ), @@ -460,12 +476,13 @@ async def benchmark_generative_text( request_loader = await resolve_request_loader( data=args.data, model=model, + request_type=args.request_type, data_args=args.data_args, data_samples=args.data_samples, processor=processor, processor_args=args.processor_args, - data_column_mapper=args.data_column_mapper, - data_request_formatter=args.data_request_formatter, + data_preprocessors=args.data_preprocessors, + data_finalizer=args.data_finalizer, data_collator=args.data_collator, data_sampler=args.data_sampler, data_num_workers=args.data_num_workers, diff --git a/src/guidellm/benchmark/schemas/generative/accumulator.py b/src/guidellm/benchmark/schemas/generative/accumulator.py index a2c9f8948..57cc953bc 100644 --- a/src/guidellm/benchmark/schemas/generative/accumulator.py +++ b/src/guidellm/benchmark/schemas/generative/accumulator.py @@ -722,8 +722,9 @@ def compile_stats( ) if response is None: + # FIXME: request_args is wrong response = GenerationResponse( - request_id=info.request_id, request_args=str(first_request.arguments) + request_id=info.request_id, request_args=str(first_request.columns) ) return response.compile_stats( diff --git a/src/guidellm/benchmark/schemas/generative/entrypoints.py b/src/guidellm/benchmark/schemas/generative/entrypoints.py index a080daa03..1b10caeec 100644 --- a/src/guidellm/benchmark/schemas/generative/entrypoints.py +++ b/src/guidellm/benchmark/schemas/generative/entrypoints.py @@ -35,9 +35,10 @@ from guidellm.benchmark.profiles import Profile, ProfileType from guidellm.benchmark.scenarios import get_builtin_scenarios from guidellm.benchmark.schemas.base import TransientPhaseConfig -from guidellm.data import DatasetPreprocessor, RequestFormatter +from guidellm.data import DatasetFinalizer, DatasetPreprocessor from guidellm.scheduler import StrategyType from guidellm.schemas import StandardBaseModel +from guidellm.settings import settings __all__ = ["BenchmarkGenerativeTextArgs"] @@ -179,6 +180,13 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any: backend_kwargs: dict[str, Any] | None = Field( default=None, description="Additional backend configuration arguments" ) + request_type: str = Field( + default_factory=lambda: settings.preferred_route, + description=( + "Request type for backend operations;" + " shorthand for backend_kwargs['request_type']" + ), + ) model: str | None = Field(default=None, description="Model identifier for backend") # Data configuration processor: str | Path | PreTrainedTokenizerBase | None = Field( @@ -194,23 +202,21 @@ def 
get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any: data_samples: int = Field( default=-1, description="Number of samples to use from datasets (-1 for all)" ) - data_column_mapper: ( - DatasetPreprocessor - | dict[str, str | list[str]] - | Literal["generative_column_mapper"] - ) = Field( - default="generative_column_mapper", - description="Column mapping preprocessor for dataset fields", + # TODO: Make it easy to customize preprocessors without editing the full list + data_preprocessors: list[DatasetPreprocessor | dict[str, str | list[str]] | str] = ( + Field( + default_factory=lambda: [ # type: ignore [arg-type] + "generative_column_mapper", + "encode_audio", + "encode_image", + "encode_video", + ], + description="List of dataset preprocessors to apply in order", + ) ) - data_request_formatter: RequestFormatter | dict[str, Any] | str = Field( - default="chat_completions", - description="Request formatting preprocessor or template name", - validation_alias=AliasChoices( - "data_request_formatter", - "data-request-formatter", - "request_type", - "request-type", - ), + data_finalizer: DatasetFinalizer | str | dict[str, Any] = Field( + default="generative", + description="Finalizer for preparing data samples into requests", ) data_collator: Callable | Literal["generative"] | None = Field( default="generative", description="Data collator for batch processing" ) @@ -284,7 +290,7 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any: default=None, description="Maximum global error rate (0-1) before stopping" ) - @field_validator("data", "data_args", "rate", mode="wrap") + @field_validator("data", "data_args", "rate", "data_preprocessors", mode="wrap") @classmethod def single_to_list( cls, value: Any, handler: ValidatorFunctionWrapHandler ) -> @@ -323,28 +329,25 @@ def serialize_data_collator( """Serialize data_collator to string or None.""" return data_collator if isinstance(data_collator, str) else None - @field_serializer("data_column_mapper") - def serialize_data_column_mapper( + @field_serializer("data_preprocessors") + def serialize_data_preprocessors( self, - data_column_mapper: ( - DatasetPreprocessor - | dict[str, str | list[str]] - | Literal["generative_column_mapper"] - ), - ) -> dict | str: - """Serialize data_column_mapper to dict or string.""" - return data_column_mapper if isinstance(data_column_mapper, dict | str) else {} + data_preprocessors: list[ + DatasetPreprocessor | dict[str, str | list[str]] | str + ], + ) -> list[dict | str]: + """Serialize data_preprocessors to a list of dicts or strings.""" + return [ + (preprocessor if isinstance(preprocessor, dict | str) else {}) + for preprocessor in data_preprocessors + ] - @field_serializer("data_request_formatter") - def serialize_data_request_formatter( - self, data_request_formatter: RequestFormatter | dict[str, Any] | str - ) -> dict | str: - """Serialize data_request_formatter to dict or string.""" - return ( - data_request_formatter - if isinstance(data_request_formatter, dict | str) - else {} - ) + @field_serializer("data_finalizer") + def serialize_data_finalizer( + self, data_finalizer: DatasetFinalizer | dict[str, Any] | str + ) -> dict | str: + """Serialize data_finalizer to dict or string.""" + return data_finalizer if isinstance(data_finalizer, dict | str) else {} @field_serializer("data_sampler") def serialize_data_sampler( diff --git a/src/guidellm/data/__init__.py b/src/guidellm/data/__init__.py index 9adbd3c8d..eb69e1d83 100644 --- a/src/guidellm/data/__init__.py +++ b/src/guidellm/data/__init__.py @@ -4,12 +4,12 @@ DatasetDeserializer, DatasetDeserializerFactory, ) +from .finalizers import DatasetFinalizer, FinalizerRegistry from .loaders import DataLoader, DatasetsIterator from
.preprocessors import ( DataDependentPreprocessor, DatasetPreprocessor, PreprocessorRegistry, - RequestFormatter, ) from .processor import ProcessorFactory from .schemas import GenerativeDatasetColumnType @@ -20,11 +20,12 @@ "DataNotSupportedError", "DatasetDeserializer", "DatasetDeserializerFactory", + "DatasetFinalizer", "DatasetPreprocessor", "DatasetsIterator", + "FinalizerRegistry", "GenerativeDatasetColumnType", "GenerativeRequestCollator", "PreprocessorRegistry", "ProcessorFactory", - "RequestFormatter", ] diff --git a/src/guidellm/data/finalizers.py b/src/guidellm/data/finalizers.py new file mode 100644 index 000000000..37586ec10 --- /dev/null +++ b/src/guidellm/data/finalizers.py @@ -0,0 +1,121 @@ +from typing import Any, Protocol, TypeVar, runtime_checkable + +from guidellm.schemas.request import GenerationRequest, UsageMetrics +from guidellm.utils.registry import RegistryMixin + +DataT_co = TypeVar("DataT_co", covariant=True) + + +@runtime_checkable +class DatasetFinalizer(Protocol[DataT_co]): + """ + Protocol for finalizing dataset rows into a desired data type. + """ + + def __call__(self, item: dict[str, Any]) -> DataT_co: ... + + +class FinalizerRegistry(RegistryMixin[type[DatasetFinalizer]]): + pass + + +@FinalizerRegistry.register("generative") +class GenerativeRequestFinalizer(DatasetFinalizer[GenerationRequest]): + """ + Finalizer that converts dataset rows into GenerationRequest objects, + aggregating usage metrics from the provided columns. + """ + + def __init__(self, request_type: str) -> None: + # TODO: Drop this in favor of setting the request type on the backend + self.request_type = request_type + + def __call__( # noqa: C901, PLR0912 + self, columns: dict[str, Any] + ) -> GenerationRequest: + input_metrics = UsageMetrics() + output_metrics = UsageMetrics() + + # Sum prompt token column + if prompt_tokens := sum( + count for count in columns.get("prompt_tokens_count_column", []) if count + ): + input_metrics.text_tokens = prompt_tokens + + # Sum output token column + if output_tokens := sum( + count for count in columns.get("output_tokens_count_column", []) if count + ): + output_metrics.text_tokens = output_tokens + + # Count words in prefixes + for prefix in columns.get("prefix_column", []): + if not prefix: + continue + + input_metrics.add_text_metrics(prefix) + + # Count words in text prompts + for text in columns.get("text_column", []): + if not text: + continue + + input_metrics.add_text_metrics(text) + + # Count pixels and bytes in images + for image in columns.get("image_column", []): + if not image: + continue + + if (image_pixels := image.get("image_pixels")) is not None: + input_metrics.image_pixels = ( + input_metrics.image_pixels or 0 + ) + image_pixels + if (image_bytes := image.get("image_bytes")) is not None: + input_metrics.image_bytes = ( + input_metrics.image_bytes or 0 + ) + image_bytes + + # Count frames, seconds, and bytes in videos + for video in columns.get("video_column", []): + if not video: + continue + + if (video_frames := video.get("video_frames")) is not None: + input_metrics.video_frames = ( + input_metrics.video_frames or 0 + ) + video_frames + if (video_seconds := video.get("video_seconds")) is not None: + input_metrics.video_seconds = ( + input_metrics.video_seconds or 0.0 + ) + video_seconds + if (video_bytes := video.get("video_bytes")) is not None: + input_metrics.video_bytes = ( + input_metrics.video_bytes or 0 + ) + video_bytes + + # Count samples, seconds, and bytes in audio + for audio in columns.get("audio_column", []): + if not audio: +
continue + + if (audio_samples := audio.get("audio_samples")) is not None: + input_metrics.audio_samples = ( + input_metrics.audio_samples or 0 + ) + audio_samples + if (audio_seconds := audio.get("audio_seconds")) is not None: + input_metrics.audio_seconds = ( + input_metrics.audio_seconds or 0.0 + ) + audio_seconds + if (audio_bytes := audio.get("audio_bytes")) is not None: + input_metrics.audio_bytes = ( + input_metrics.audio_bytes or 0 + ) + audio_bytes + + # TODO: Filter columns to only those needed for the request + return GenerationRequest( + request_type=self.request_type, + columns=columns, + input_metrics=input_metrics, + output_metrics=output_metrics, + ) diff --git a/src/guidellm/data/loaders.py b/src/guidellm/data/loaders.py index e6393707e..a51d0c084 100644 --- a/src/guidellm/data/loaders.py +++ b/src/guidellm/data/loaders.py @@ -11,6 +11,7 @@ from transformers import PreTrainedTokenizerBase from guidellm.data.deserializers import DatasetDeserializerFactory +from guidellm.data.finalizers import DatasetFinalizer from guidellm.data.preprocessors import DataDependentPreprocessor, DatasetPreprocessor from guidellm.logger import logger from guidellm.utils import InfoMixin @@ -29,6 +30,7 @@ def __init__( data_samples: int, processor_factory: Callable[[], PreTrainedTokenizerBase], preprocessors: list[DatasetPreprocessor | DataDependentPreprocessor], + finalizer: DatasetFinalizer[DataT], random_seed: int, ): if not data or not isinstance(data, list): @@ -60,6 +62,7 @@ def __init__( datasets=self.datasets, data_args=data_args, ) + self.finalizer = finalizer self.precache: list[Any] | None = ( list(self.generator(data_samples)) if data_samples else None ) @@ -113,12 +116,11 @@ def generator( ): continue + # Apply preprocessors in sequence for preprocessor in self.preprocessors: - # This can assign a GenerationRequest, which would then be - # passed into the preprocessor, which is a type violation. - # This should be fixed at some point. 
- row = preprocessor(row) # type: ignore[assignment] - yield row # type: ignore[misc] + row = preprocessor(row) + + yield self.finalizer(row) except StopIteration: raise # Stop iteration when any dataset is exhausted except Exception as err: # noqa: BLE001 # Exception logged @@ -140,6 +142,7 @@ def __init__( data_samples: int, processor_factory: Callable[[], PreTrainedTokenizerBase], preprocessors: list[DatasetPreprocessor | DataDependentPreprocessor], + finalizer: DatasetFinalizer[DataT], collator: Callable, sampler: Sampler[int] | Literal["shuffle"] | None = None, num_workers: int | None = 1, @@ -152,6 +155,7 @@ def __init__( data_samples=data_samples, processor_factory=processor_factory, preprocessors=preprocessors, + finalizer=finalizer, random_seed=random_seed, ) self._info: dict[str, Any] = { @@ -161,6 +165,7 @@ def __init__( "preprocessors": [ preprocessor.__class__.__name__ for preprocessor in preprocessors ], + "finalizer": finalizer.__class__.__name__, "collator": collator.__class__.__name__, "sampler": str(sampler), "num_workers": num_workers, diff --git a/src/guidellm/data/preprocessors/__init__.py b/src/guidellm/data/preprocessors/__init__.py index 6d6e722d8..094118176 100644 --- a/src/guidellm/data/preprocessors/__init__.py +++ b/src/guidellm/data/preprocessors/__init__.py @@ -1,10 +1,4 @@ -from .formatters import ( - GenerativeAudioTranscriptionRequestFormatter, - GenerativeAudioTranslationRequestFormatter, - GenerativeChatCompletionsRequestFormatter, - GenerativeTextCompletionsRequestFormatter, - RequestFormatter, -) +from .encoders import AudioEncoder, ImageEncoder, PreprocessEncoder, VideoEncoder from .mappers import GenerativeColumnMapper from .preprocessor import ( DataDependentPreprocessor, @@ -13,15 +7,12 @@ ) __all__ = [ - "ColumnMapper", - "ColumnMapperRegistry", + "AudioEncoder", "DataDependentPreprocessor", "DatasetPreprocessor", - "GenerativeAudioTranscriptionRequestFormatter", - "GenerativeAudioTranslationRequestFormatter", - "GenerativeChatCompletionsRequestFormatter", "GenerativeColumnMapper", - "GenerativeTextCompletionsRequestFormatter", + "ImageEncoder", + "PreprocessEncoder", "PreprocessorRegistry", - "RequestFormatter", + "VideoEncoder", ] diff --git a/src/guidellm/data/preprocessors/encoders.py b/src/guidellm/data/preprocessors/encoders.py new file mode 100644 index 000000000..5cd230541 --- /dev/null +++ b/src/guidellm/data/preprocessors/encoders.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from typing import Any + +from guidellm.data.preprocessors.preprocessor import ( + DatasetPreprocessor, + PreprocessorRegistry, +) + +__all__ = ["AudioEncoder", "ImageEncoder", "PreprocessEncoder", "VideoEncoder"] + + +class PreprocessEncoder(DatasetPreprocessor): + @staticmethod + def encode_audio(*args, **kwargs): + from guidellm.extras.audio import encode_audio + + return encode_audio(*args, **kwargs) + + @staticmethod + def encode_image(*args, **kwargs): + from guidellm.extras.vision import encode_image + + return encode_image(*args, **kwargs) + + @staticmethod + def encode_video(*args, **kwargs): + from guidellm.extras.vision import encode_video + + return encode_video(*args, **kwargs) + + +@PreprocessorRegistry.register("encode_audio") +class AudioEncoder(PreprocessEncoder): + def __init__(self, encode_kwargs: dict[str, Any] | None = None) -> None: + self.encode_audio_kwargs = ( + encode_kwargs.get("audio", {}) if encode_kwargs else {} + ) + + def __call__(self, columns: dict[str, list[Any]]) -> dict[str, list[Any]]: + if 
columns.get("audio_column"): + encoded_audio = [] + for audio in columns["audio_column"]: + if not audio: + continue + + encoded_audio.append( + self.encode_audio(audio, b64encode=True, **self.encode_audio_kwargs) + ) + columns["audio_column"] = encoded_audio + + return columns + + +@PreprocessorRegistry.register("encode_image") +class ImageEncoder(PreprocessEncoder): + def __init__(self, encode_kwargs: dict[str, Any] | None = None) -> None: + self.encode_image_kwargs = ( + encode_kwargs.get("image", {}) if encode_kwargs else {} + ) + + def __call__(self, columns: dict[str, list[Any]]) -> dict[str, list[Any]]: + if columns.get("image_column"): + encoded_images = [] + for image in columns["image_column"]: + if not image: + continue + + encoded_images.append( + self.encode_image(image, **self.encode_image_kwargs) + ) + columns["image_column"] = encoded_images + + return columns + + +@PreprocessorRegistry.register("encode_video") +class VideoEncoder(PreprocessEncoder): + def __init__(self, encode_kwargs: dict[str, Any] | None = None) -> None: + self.encode_video_kwargs = ( + encode_kwargs.get("video", {}) if encode_kwargs else {} + ) + + def __call__(self, columns: dict[str, list[Any]]) -> dict[str, list[Any]]: + if columns.get("video_column"): + encoded_videos = [] + for video in columns["video_column"]: + if not video: + continue + + encoded_videos.append( + self.encode_video(video, **self.encode_video_kwargs) + ) + columns["video_column"] = encoded_videos + + return columns diff --git a/src/guidellm/data/preprocessors/formatters.py b/src/guidellm/data/preprocessors/formatters.py deleted file mode 100644 index 608128a64..000000000 --- a/src/guidellm/data/preprocessors/formatters.py +++ /dev/null @@ -1,404 +0,0 @@ -from __future__ import annotations - -from typing import Any - -from guidellm.data.preprocessors.preprocessor import ( - DatasetPreprocessor, - PreprocessorRegistry, -) -from guidellm.schemas import GenerationRequest, GenerationRequestArguments, UsageMetrics - -__all__ = [ - "GenerativeAudioTranscriptionRequestFormatter", - "GenerativeAudioTranslationRequestFormatter", - "GenerativeChatCompletionsRequestFormatter", - "GenerativeTextCompletionsRequestFormatter", - "RequestFormatter", -] - - -class RequestFormatter(DatasetPreprocessor): - def __init__(self, model: str, **_kwargs): - self.model = model - - @staticmethod - def encode_audio(*args, **kwargs): - from guidellm.extras.audio import encode_audio - - return encode_audio(*args, **kwargs) - - @staticmethod - def encode_image(*args, **kwargs): - from guidellm.extras.vision import encode_image - - return encode_image(*args, **kwargs) - - @staticmethod - def encode_video(*args, **kwargs): - from guidellm.extras.vision import encode_video - - return encode_video(*args, **kwargs) - - -@PreprocessorRegistry.register("text_completions") -class GenerativeTextCompletionsRequestFormatter(RequestFormatter): - def __init__( - self, - model: str, - extras: dict[str, Any] | GenerationRequestArguments | None = None, - stream: bool = True, - max_tokens: int | None = None, - max_completion_tokens: int | None = None, - ): - self.model: str = model - self.extras = ( - GenerationRequestArguments(**extras) - if extras and isinstance(extras, dict) - else extras - ) - self.stream: bool = stream - self.max_tokens: int | None = max_tokens or max_completion_tokens - - def __call__(self, columns: dict[str, list[Any]]) -> GenerationRequest: - """ - :param columns: A dict of GenerativeDatasetColumnType to Any - """ - arguments: GenerationRequestArguments = 
GenerationRequestArguments() - arguments.body = {} # The type checker works better setting this field here - input_metrics = UsageMetrics() - output_metrics = UsageMetrics() - - # Add model - if self.model is not None: - arguments.body["model"] = self.model - - # Configure streaming - if self.stream: - arguments.stream = True - arguments.body["stream"] = True - arguments.body["stream_options"] = {"include_usage": True} - - # Handle output tokens - if output_tokens := sum( - count for count in columns.get("output_tokens_count_column", []) if count - ): - output_metrics.text_tokens = output_tokens - arguments.body["max_tokens"] = output_tokens - arguments.body["stop"] = None - arguments.body["ignore_eos"] = True - elif self.max_tokens is not None: - arguments.body["max_tokens"] = self.max_tokens - - # Handle prompt tokens - if prompt_tokens := sum( - count for count in columns.get("prompt_tokens_count_column", []) if count - ): - input_metrics.text_tokens = prompt_tokens - - # Apply extra arguments - if self.extras: - arguments.model_combine(self.extras) - - # Build prompt - prefix = "".join(pre for pre in columns.get("prefix_column", []) if pre) - text = "".join(txt for txt in columns.get("text_column", []) if txt) - if prefix or text: - prompt = prefix + text - arguments.body["prompt"] = prompt - input_metrics.add_text_metrics(prompt) - - return GenerationRequest( - request_type="text_completions", - arguments=arguments, - input_metrics=input_metrics, - output_metrics=output_metrics, - ) - - -@PreprocessorRegistry.register("chat_completions") -class GenerativeChatCompletionsRequestFormatter(RequestFormatter): - def __init__( - self, - model: str, - extras: dict[str, Any] | GenerationRequestArguments | None = None, - stream: bool = True, - max_tokens: int | None = None, - max_completion_tokens: int | None = None, - encode_kwargs: dict[str, Any] | None = None, - ): - self.model = model - self.extras = ( - GenerationRequestArguments(**extras) - if extras and isinstance(extras, dict) - else extras - ) - self.stream = stream - self.max_completion_tokens = max_tokens or max_completion_tokens - self.encode_image_kwargs = ( - encode_kwargs.get("image", {}) if encode_kwargs else {} - ) - self.encode_video_kwargs = ( - encode_kwargs.get("video", {}) if encode_kwargs else {} - ) - self.encode_audio_kwargs = ( - encode_kwargs.get("audio", {}) if encode_kwargs else {} - ) - - def __call__( # noqa: C901, PLR0912, PLR0915 - self, columns: dict[str, list[Any]] - ) -> GenerationRequest: - """ - :param columns: A dict of GenerativeDatasetColumnType to Any - """ - arguments = GenerationRequestArguments() - arguments.body = {} # The type checker works best with body assigned here - input_metrics = UsageMetrics() - output_metrics = UsageMetrics() - - # Add model - if self.model is not None: - arguments.body["model"] = self.model - - # Configure streaming - if self.stream: - arguments.stream = True - arguments.body["stream"] = True - arguments.body["stream_options"] = {"include_usage": True} - - # Handle output tokens - if output_tokens := sum( - count for count in columns.get("output_tokens_count_column", []) if count - ): - output_metrics.text_tokens = output_tokens - arguments.body.update( - { - "max_completion_tokens": output_tokens, - "stop": None, - "ignore_eos": True, - } - ) - elif self.max_completion_tokens is not None: - arguments.body["max_completion_tokens"] = self.max_completion_tokens - - # Handle prompt tokens - if prompt_tokens := sum( - count for count in 
columns.get("prompt_tokens_count_column", []) if count - ): - input_metrics.text_tokens = prompt_tokens - - # Apply extra arguments - if self.extras: - arguments.model_combine(self.extras) - - # Build messages - arguments.body["messages"] = [] - - for prefix in columns.get("prefix_column", []): - if not prefix: - continue - - input_metrics.add_text_metrics(prefix) - arguments.body["messages"].append({"role": "system", "content": prefix}) - - for text in columns.get("text_column", []): - if not text: - continue - - input_metrics.add_text_metrics(text) - - arguments.body["messages"].append( - {"role": "user", "content": [{"type": "text", "text": text}]} - ) - - for image in columns.get("image_column", []): - if not image: - continue - - image_dict = self.encode_image(image, **self.encode_image_kwargs) - if (image_pixels := image_dict.get("image_pixels")) is not None: - input_metrics.image_pixels = ( - input_metrics.image_pixels or 0 - ) + image_pixels - if (image_bytes := image_dict.get("image_bytes")) is not None: - input_metrics.image_bytes = ( - input_metrics.image_bytes or 0 - ) + image_bytes - - arguments.body["messages"].append( - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": image_dict.get("image")} - ], - } - ) - - for video in columns.get("video_column", []): - if not video: - continue - - video_dict = self.encode_video(video, **self.encode_video_kwargs) - if (video_frames := video_dict.get("video_frames")) is not None: - input_metrics.video_frames = ( - input_metrics.video_frames or 0 - ) + video_frames - if (video_seconds := video_dict.get("video_seconds")) is not None: - input_metrics.video_seconds = ( - input_metrics.video_seconds or 0.0 - ) + video_seconds - if (video_bytes := video_dict.get("video_bytes")) is not None: - input_metrics.video_bytes = ( - input_metrics.video_bytes or 0 - ) + video_bytes - - arguments.body["messages"].append( - { - "role": "user", - "content": [ - {"type": "video_url", "video_url": video_dict.get("video")} - ], - } - ) - - for audio in columns.get("audio_column", []): - if not audio: - continue - - audio_dict = self.encode_audio( - audio, b64encode=True, **self.encode_audio_kwargs - ) - if (audio_samples := audio_dict.get("audio_samples")) is not None: - input_metrics.audio_samples = ( - input_metrics.audio_samples or 0 - ) + audio_samples - if (audio_seconds := audio_dict.get("audio_seconds")) is not None: - input_metrics.audio_seconds = ( - input_metrics.audio_seconds or 0.0 - ) + audio_seconds - if (audio_bytes := audio_dict.get("audio_bytes")) is not None: - input_metrics.audio_bytes = ( - input_metrics.audio_bytes or 0 - ) + audio_bytes - - arguments.body["messages"].append( - { - "role": "user", - "content": [ - { - "type": "input_audio", - "input_audio": { - "data": audio_dict.get("audio"), - "format": audio_dict.get("format"), - }, - } - ], - } - ) - - return GenerationRequest( - request_type="chat_completions", - arguments=arguments, - input_metrics=input_metrics, - output_metrics=output_metrics, - ) - - -@PreprocessorRegistry.register("audio_transcriptions") -class GenerativeAudioTranscriptionRequestFormatter(RequestFormatter): - def __init__( - self, - model: str, - extras: dict[str, Any] | GenerationRequestArguments | None = None, - stream: bool = True, - encode_kwargs: dict[str, Any] | None = None, - ): - self.model = model - self.extras = ( - GenerationRequestArguments(**extras) - if extras and isinstance(extras, dict) - else extras - ) - self.stream = stream - self.encode_audio_kwargs = encode_kwargs or {} 
- - def __call__( # noqa: C901 - self, columns: dict[str, list[Any]] - ) -> GenerationRequest: - arguments = GenerationRequestArguments(files={}) - arguments.body = {} # The type checker works best with body assigned here - input_metrics = UsageMetrics() - output_metrics = UsageMetrics() - - # Add model - if self.model is not None: - arguments.body["model"] = self.model - - # Configure streaming - if self.stream: - arguments.stream = True - arguments.body["stream"] = True - arguments.body["stream_options"] = {"include_usage": True} - - # Handle output tokens - if output_tokens := sum( - count for count in columns.get("output_tokens_count_column", []) if count - ): - output_metrics.text_tokens = output_tokens - - # Handle prompt tokens (for audio duration tracking) - if prompt_tokens := sum( - count for count in columns.get("prompt_tokens_count_column", []) if count - ): - input_metrics.text_tokens = prompt_tokens - - # Apply extra arguments - if self.extras: - arguments.model_combine(self.extras) - - # Build audio input - audio_columns = columns.get("audio_column", []) - if len(audio_columns) != 1: - raise ValueError( - f"GenerativeAudioTranscriptionRequestFormatter expects exactly " - f"one audio column, but got {len(audio_columns)}." - ) - - audio_dict = self.encode_audio( - audio_columns[0], b64encode=False, **self.encode_audio_kwargs - ) - input_metrics.audio_samples = audio_dict.get("audio_samples") - input_metrics.audio_seconds = audio_dict.get("audio_seconds") - input_metrics.audio_bytes = audio_dict.get("audio_bytes") - - arguments.files = { - "file": ( - audio_dict.get("file_name", "audio_input"), - audio_dict.get("audio"), - audio_dict.get("mimetype"), - ) - } - - # Build prompt - prefix = "".join(pre for pre in columns.get("prefix_column", []) if pre) - text = "".join(txt for txt in columns.get("text_column", []) if txt) - if prefix or text: - prompt = prefix + text - arguments.body["prompt"] = prompt - input_metrics.add_text_metrics(prompt) - - return GenerationRequest( - request_type="audio_transcriptions", - arguments=arguments, - input_metrics=input_metrics, - output_metrics=output_metrics, - ) - - -@PreprocessorRegistry.register("audio_translations") -class GenerativeAudioTranslationRequestFormatter( - GenerativeAudioTranscriptionRequestFormatter -): - def __call__(self, columns: dict[str, list[Any]]) -> GenerationRequest: - result = super().__call__(columns) - result.request_type = "audio_translations" - return result diff --git a/src/guidellm/data/preprocessors/preprocessor.py b/src/guidellm/data/preprocessors/preprocessor.py index 43fe20e9e..9700ef294 100644 --- a/src/guidellm/data/preprocessors/preprocessor.py +++ b/src/guidellm/data/preprocessors/preprocessor.py @@ -4,7 +4,6 @@ from datasets import Dataset, IterableDataset -from guidellm.schemas import GenerationRequest from guidellm.utils import RegistryMixin __all__ = ["DataDependentPreprocessor", "DatasetPreprocessor", "PreprocessorRegistry"] @@ -12,7 +11,7 @@ @runtime_checkable class DatasetPreprocessor(Protocol): - def __call__(self, item: dict[str, Any]) -> GenerationRequest | dict[str, Any]: ... + def __call__(self, item: dict[str, Any]) -> dict[str, Any]: ... @runtime_checkable diff --git a/src/guidellm/schemas/request.py b/src/guidellm/schemas/request.py index a5193474c..138032f77 100644 --- a/src/guidellm/schemas/request.py +++ b/src/guidellm/schemas/request.py @@ -210,12 +210,12 @@ class GenerationRequest(StandardBaseModel): "this will be used to determine the request url." 
), ) - arguments: GenerationRequestArguments = Field( + columns: dict[str, list[Any]] = Field( + default_factory=dict, description=( - "Payload for the request, structured as a dictionary of arguments to pass " - "to the respective backend method. For example, can contain " - "'json', 'headers', 'files', etc." - ) + "Columnar data associated with the request, structured as a dictionary " + "where keys are column names and values are lists of column entries." + ), ) input_metrics: UsageMetrics = Field( default_factory=UsageMetrics, diff --git a/src/guidellm/schemas/response.py b/src/guidellm/schemas/response.py index cb7ba79f4..5cdf666f9 100644 --- a/src/guidellm/schemas/response.py +++ b/src/guidellm/schemas/response.py @@ -114,9 +114,7 @@ def compile_stats( request_id=self.request_id, response_id=self.response_id, request_type=request.request_type, - request_args=str( - request.arguments.model_dump() if request.arguments else {} - ), + request_args=self.request_args, output=self.text, info=info, input_metrics=UsageMetrics(**input_metrics_dict),
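Net effect on the data path: rows flow through registry-resolved preprocessors (column mapping, then the new media encoders), a finalizer turns the column dict into a `GenerationRequest` that carries the raw columns plus usage metrics, and payload formatting is deferred to the backend at send time. A small sketch of the finalizer step alone, using a made-up column dict:

    from guidellm.data import FinalizerRegistry

    # Toy columns, shaped as they would be after the mapper/encoder preprocessors ran.
    columns = {
        "prefix_column": ["You are a helpful assistant."],
        "text_column": ["What is the capital of France?"],
        "prompt_tokens_count_column": [12],
        "output_tokens_count_column": [128],
    }

    finalizer_cls = FinalizerRegistry.get_registered_object("generative")
    finalizer = finalizer_cls(request_type="chat_completions")

    request = finalizer(columns)
    assert request.request_type == "chat_completions"
    assert request.input_metrics.text_tokens == 12  # summed from the count column
    assert request.output_metrics.text_tokens == 128
    # request.columns keeps the raw column data; the backend formatter builds the body later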