Commit 34e190a

Rework entrypoint to match new preprocess and finalizer flow
Signed-off-by: Samuel Monson <[email protected]>
1 parent 133c2ec · commit 34e190a

File tree: 4 files changed, +124 −112 lines

src/guidellm/__main__.py

Lines changed: 11 additions & 22 deletions
@@ -178,19 +178,13 @@ def benchmark():
 # Data configuration
 @click.option(
     "--request-type",
-    default=BenchmarkGenerativeTextArgs.get_default("data_request_formatter"),
+    default=BenchmarkGenerativeTextArgs.get_default("request_type"),
     type=click.Choice(list(get_literal_vals(GenerativeRequestType))),
     help=(
         f"Request type to create for each data sample. "
         f"Options: {', '.join(get_literal_vals(GenerativeRequestType))}."
     ),
 )
-@click.option(
-    "--request-formatter-kwargs",
-    default=None,
-    callback=cli_tools.parse_json,
-    help="JSON string of arguments to pass to the request formatter.",
-)
 @click.option(
     "--processor",
     default=BenchmarkGenerativeTextArgs.get_default("processor"),
@@ -223,10 +217,17 @@ def benchmark():
     ),
 )
 @click.option(
-    "--data-column-mapper",
-    default=BenchmarkGenerativeTextArgs.get_default("data_column_mapper"),
+    "--data-preprocessors",
+    default=BenchmarkGenerativeTextArgs.get_default("data_preprocessors"),
+    callback=cli_tools.parse_json,
+    multiple=True,
+    help="JSON string of preprocessors to apply to the dataset.",
+)
+@click.option(
+    "--data-finalizer",
+    default=BenchmarkGenerativeTextArgs.get_default("data_finalizer"),
     callback=cli_tools.parse_json,
-    help="JSON string of column mappings to apply to the dataset.",
+    help="JSON string of finalizer to convert dataset rows to requests.",
 )
 @click.option(
     "--data-sampler",
@@ -386,18 +387,6 @@ def run(**kwargs):
     # Only set CLI args that differ from click defaults
     kwargs = cli_tools.set_if_not_default(click.get_current_context(), **kwargs)
 
-    # Handle remapping for request params
-    request_type = kwargs.pop("request_type", None)
-    request_formatter_kwargs = kwargs.pop("request_formatter_kwargs", None)
-    if request_type is not None:
-        kwargs["data_request_formatter"] = (
-            request_type
-            if not request_formatter_kwargs
-            else {"request_type": request_type, **request_formatter_kwargs}
-        )
-    elif request_formatter_kwargs is not None:
-        kwargs["data_request_formatter"] = request_formatter_kwargs
-
     # Handle output path remapping
     if (output_path := kwargs.pop("output_path", None)) is not None:
         path = Path(output_path)
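
Taken together, the __main__.py changes drop the old --request-formatter-kwargs remapping in favor of three flags that pass straight through to BenchmarkGenerativeTextArgs: --request-type, --data-preprocessors (repeatable, JSON-parsed), and --data-finalizer. A hypothetical invocation under the new surface might look like the following; the --target and --data options are assumed from the existing CLI rather than shown in this diff, and the exact quoting accepted by cli_tools.parse_json is an assumption:

guidellm benchmark run \
    --target http://localhost:8000 \
    --data dataset.jsonl \
    --request-type chat_completions \
    --data-preprocessors '{"type": "generative_column_mapper"}' \
    --data-preprocessors '{"type": "encode_image"}' \
    --data-finalizer generative_text_finalizer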

src/guidellm/benchmark/entrypoints.py

Lines changed: 74 additions & 57 deletions
@@ -13,7 +13,7 @@
 
 from collections.abc import Callable, Mapping, MutableMapping
 from pathlib import Path
-from typing import Any, Literal
+from typing import Any, Literal, TypeVar
 
 from torch.utils.data import Sampler
 from transformers import PreTrainedTokenizerBase
@@ -36,21 +36,21 @@
 from guidellm.benchmark.schemas.base import TransientPhaseConfig
 from guidellm.data import (
     DataLoader,
+    DatasetFinalizer,
     DatasetPreprocessor,
+    FinalizerRegistry,
     GenerativeRequestCollator,
     PreprocessorRegistry,
     ProcessorFactory,
-    RequestFormatter,
 )
-from guidellm.data.preprocessors import GenerativeColumnMapper
 from guidellm.scheduler import (
     ConstraintInitializer,
     NonDistributedEnvironment,
     StrategyType,
 )
 from guidellm.schemas import GenerationRequest, GenerationResponse
-from guidellm.settings import settings
 from guidellm.utils import Console, InfoMixin
+from guidellm.utils.registry import RegistryMixin
 
 __all__ = [
     "benchmark_generative_text",
@@ -178,19 +178,66 @@ async def resolve_processor(
     return processor
 
 
+BaseTypeT = TypeVar("BaseTypeT")
+
+
+def resolve_item_from_registry(
+    base_type: type[BaseTypeT],
+    registry: type[RegistryMixin],
+    item: Any,
+    extras: dict[str, Any] | None = None,
+) -> BaseTypeT:
+    """
+    Resolve an item from a registry, instantiating it if necessary.
+
+    :param base_type: The expected base type of the item
+    :param item: The item to resolve, either an instance or a string identifier
+    :param registry: The registry to use for resolving string identifiers
+    :return: The resolved item as an instance of the base type
+    :raises ValueError: If the item cannot be resolved from the registry
+    :raises TypeError: If the resolved item is not of the expected base type
+    """
+    if isinstance(item, base_type):
+        return item
+    else:
+        if isinstance(item, str):
+            item_type = item
+            kwargs = {}
+        else:
+            item_dict = dict(item)
+            item_type = item_dict.pop("type", None)
+            if item_type is None:
+                raise ValueError(
+                    f"Item dictionary must contain a 'type' key to resolve from "
+                    f"{registry.__class__.__name__}."
+                )
+            kwargs = item_dict
+
+        if (item_class := registry.get_registered_object(item_type)) is None:
+            raise ValueError(
+                f"Item type '{item_type}' is not registered in the "
+                f"{registry.__class__.__name__}."
+            )
+        if not issubclass(item_class, base_type):
+            raise TypeError(
+                f"Resolved item type '{item_type}' is not a subclass of "
+                f"{base_type.__name__}."
+            )
+        if extras:
+            kwargs.update(extras)
+        return item_class(**kwargs)
+
+
 async def resolve_request_loader(
     data: list[Any],
     model: str,
+    request_type: str,
     data_args: list[dict[str, Any]] | None,
     data_samples: int,
     processor: ProcessorInputT | None,
     processor_args: dict[str, Any] | None,
-    data_column_mapper: (
-        DatasetPreprocessor
-        | dict[str, str | list[str]]
-        | Literal["generative_column_mapper"]
-    ),
-    data_request_formatter: (RequestFormatter | dict[str, str] | str),
+    data_preprocessors: list[DatasetPreprocessor | dict[str, str | list[str]] | str],
+    data_finalizer: (DatasetFinalizer | dict[str, Any] | str),
     data_collator: Callable | Literal["generative"] | None,
     data_sampler: Sampler[int] | Literal["shuffle"] | None,
     data_num_workers: int | None,
@@ -232,54 +279,22 @@ async def resolve_request_loader(
         else None
     )
 
-    data_column_mapper_instance: DatasetPreprocessor
-    if isinstance(data_column_mapper, DatasetPreprocessor):
-        data_column_mapper_instance = data_column_mapper
-    else:
-        column_mappings = (
-            data_column_mapper if isinstance(data_column_mapper, dict) else None
-        )
-        data_column_mapper_instance = GenerativeColumnMapper(
-            column_mappings=column_mappings  # type: ignore[arg-type]
-        )
-
-    data_request_formatter_instance: RequestFormatter
-    if isinstance(data_request_formatter, RequestFormatter):
-        data_request_formatter_instance = data_request_formatter
-    else:
-        if isinstance(data_request_formatter, str):
-            request_type = data_request_formatter
-            formatter_kwargs: dict[str, Any] = {}
-        else:
-            # Extract request_type from formatter dictionary
-            formatter_dict = dict(data_request_formatter)
-            request_type = formatter_dict.pop("request_type", settings.preferred_route)
-            formatter_kwargs = formatter_dict
-
-        if (
-            formatter_class := PreprocessorRegistry.get_registered_object(request_type)
-        ) is None:
-            raise ValueError(
-                f"Request formatter '{request_type}' is not registered in the "
-                f"PreprocessorRegistry."
-            )
-        if not issubclass(formatter_class, RequestFormatter):
-            raise TypeError(
-                f"Request formatter '{request_type}' is not a subclass of "
-                f"RequestFormatter."
-            )
-
-        data_request_formatter_instance = formatter_class(
-            model=model,
-            **formatter_kwargs,
-        )
-
-    # Cast to proper types for the DataLoader preprocessors list
     preprocessors_list: list[DatasetPreprocessor] = [
-        data_column_mapper_instance,
-        data_request_formatter_instance,
+        resolve_item_from_registry(
+            DatasetPreprocessor,  # type: ignore [type-abstract]
+            PreprocessorRegistry,
+            preprocessor,
+        )
+        for preprocessor in data_preprocessors
     ]
 
+    finalizer_instance = resolve_item_from_registry(
+        DatasetFinalizer,  # type: ignore [type-abstract]
+        FinalizerRegistry,
        data_finalizer,
+        extras={"request_type": request_type},
+    )
+
     request_loader: DataLoader[GenerationRequest] = DataLoader(
         data=data,
         data_args=data_args,
@@ -289,6 +304,7 @@ async def resolve_request_loader(
             processor_args=processor_args,
         ),
         preprocessors=preprocessors_list,
+        finalizer=finalizer_instance,
         collator=(
             data_collator if callable(data_collator) else GenerativeRequestCollator()
         ),
@@ -460,12 +476,13 @@ async def benchmark_generative_text(
     request_loader = await resolve_request_loader(
         data=args.data,
         model=model,
+        request_type=args.request_type,
         data_args=args.data_args,
        data_samples=args.data_samples,
         processor=processor,
         processor_args=args.processor_args,
-        data_column_mapper=args.data_column_mapper,
-        data_request_formatter=args.data_request_formatter,
+        data_preprocessors=args.data_preprocessors,
+        data_finalizer=args.data_finalizer,
         data_collator=args.data_collator,
         data_sampler=args.data_sampler,
         data_num_workers=args.data_num_workers,
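
The new resolve_item_from_registry helper accepts three shapes: an already-built instance (returned unchanged), a string identifier, or a dict whose "type" key names the registered class and whose remaining keys become constructor kwargs, with extras merged in last. Below is a self-contained sketch of that contract, assuming the helper is importable from guidellm.benchmark.entrypoints as added above; DemoRegistry, Preprocessor, and Uppercase are illustrative stand-ins, not guidellm types:

from guidellm.benchmark.entrypoints import resolve_item_from_registry


class DemoRegistry:
    """Stand-in for guidellm's RegistryMixin-backed registries."""

    _objects: dict[str, type] = {}

    @classmethod
    def register(cls, name: str, obj: type) -> None:
        cls._objects[name] = obj

    @classmethod
    def get_registered_object(cls, name: str) -> type | None:
        return cls._objects.get(name)


class Preprocessor:
    """Illustrative base type."""


class Uppercase(Preprocessor):
    def __init__(self, field: str = "text", request_type: str | None = None) -> None:
        self.field = field
        self.request_type = request_type


DemoRegistry.register("uppercase", Uppercase)

# String form: the name is looked up and the class constructed with no kwargs.
p1 = resolve_item_from_registry(Preprocessor, DemoRegistry, "uppercase")

# Dict form: "type" selects the class; remaining keys become constructor kwargs.
p2 = resolve_item_from_registry(
    Preprocessor, DemoRegistry, {"type": "uppercase", "field": "prompt"}
)

# extras are merged last; this mirrors how the finalizer receives request_type.
p3 = resolve_item_from_registry(
    Preprocessor,
    DemoRegistry,
    "uppercase",
    extras={"request_type": "chat_completions"},
)
assert p2.field == "prompt" and p3.request_type == "chat_completions"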

src/guidellm/benchmark/schemas/generative/entrypoints.py

Lines changed: 36 additions & 33 deletions
@@ -35,9 +35,10 @@
 from guidellm.benchmark.profiles import Profile, ProfileType
 from guidellm.benchmark.scenarios import get_builtin_scenarios
 from guidellm.benchmark.schemas.base import TransientPhaseConfig
-from guidellm.data import DatasetPreprocessor, RequestFormatter
+from guidellm.data import DatasetFinalizer, DatasetPreprocessor
 from guidellm.scheduler import StrategyType
 from guidellm.schemas import StandardBaseModel
+from guidellm.settings import settings
 
 __all__ = ["BenchmarkGenerativeTextArgs"]
 
@@ -179,6 +180,13 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
     backend_kwargs: dict[str, Any] | None = Field(
         default=None, description="Additional backend configuration arguments"
     )
+    request_type: str = Field(
+        default_factory=lambda: settings.preferred_route,
+        description=(
+            "Request type for backend operations;"
+            " shorthand for backend_kwargs['request_type']"
+        ),
+    )
     model: str | None = Field(default=None, description="Model identifier for backend")
     # Data configuration
     processor: str | Path | PreTrainedTokenizerBase | None = Field(
@@ -194,23 +202,21 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
     data_samples: int = Field(
         default=-1, description="Number of samples to use from datasets (-1 for all)"
     )
-    data_column_mapper: (
-        DatasetPreprocessor
-        | dict[str, str | list[str]]
-        | Literal["generative_column_mapper"]
-    ) = Field(
-        default="generative_column_mapper",
-        description="Column mapping preprocessor for dataset fields",
+    # TODO: Make it easy to cutomize preprocessors without editing the full list
+    data_preprocessors: list[DatasetPreprocessor | dict[str, str | list[str]] | str] = (
+        Field(
+            default_factory=lambda: [  # type: ignore [arg-type]
+                "generative_column_mapper",
+                "encode_audio",
+                "encode_image",
+                "encode_video",
+            ],
+            description="List of dataset preprocessors to apply in order",
+        )
     )
-    data_request_formatter: RequestFormatter | dict[str, Any] | str = Field(
-        default="chat_completions",
-        description="Request formatting preprocessor or template name",
-        validation_alias=AliasChoices(
-            "data_request_formatter",
-            "data-request-formatter",
-            "request_type",
-            "request-type",
-        ),
+    data_finalizer: DatasetFinalizer | str | dict[str, Any] = Field(
+        default="generative_text_finalizer",
+        description="Finalizer for preparing data samples into requests",
     )
     data_collator: Callable | Literal["generative"] | None = Field(
         default="generative", description="Data collator for batch processing"
@@ -284,7 +290,7 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
         default=None, description="Maximum global error rate (0-1) before stopping"
     )
 
-    @field_validator("data", "data_args", "rate", mode="wrap")
+    @field_validator("data", "data_args", "rate", "data_preprocessors", mode="wrap")
     @classmethod
     def single_to_list(
         cls, value: Any, handler: ValidatorFunctionWrapHandler
@@ -323,28 +329,25 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
         """Serialize data_collator to string or None."""
         return data_collator if isinstance(data_collator, str) else None
 
-    @field_serializer("data_column_mapper")
+    @field_serializer("data_preprocessors")
     def serialize_data_column_mapper(
         self,
-        data_column_mapper: (
-            DatasetPreprocessor
-            | dict[str, str | list[str]]
-            | Literal["generative_column_mapper"]
-        ),
-    ) -> dict | str:
+        data_preprocessors: list[
+            DatasetPreprocessor | dict[str, str | list[str]] | str
+        ],
+    ) -> list[dict | str]:
         """Serialize data_column_mapper to dict or string."""
-        return data_column_mapper if isinstance(data_column_mapper, dict | str) else {}
+        return [
+            (preprocessor if isinstance(preprocessor, dict | str) else {})
+            for preprocessor in data_preprocessors
+        ]
 
-    @field_serializer("data_request_formatter")
+    @field_serializer("data_finalizer")
     def serialize_data_request_formatter(
-        self, data_request_formatter: RequestFormatter | dict[str, Any] | str
+        self, data_finalizer: DatasetFinalizer | dict[str, Any] | str
     ) -> dict | str:
         """Serialize data_request_formatter to dict or string."""
-        return (
-            data_request_formatter
-            if isinstance(data_request_formatter, dict | str)
-            else {}
-        )
+        return data_finalizer if isinstance(data_finalizer, dict | str) else {}
 
     @field_serializer("data_sampler")
     def serialize_data_sampler(
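
Adding "data_preprocessors" to the single_to_list wrap validator means a bare scalar is transparently wrapped into a one-element list before the normal list validation runs. A minimal pydantic sketch of the same wrap-validator pattern follows; ToyArgs is illustrative, not the real BenchmarkGenerativeTextArgs:

from typing import Any

from pydantic import BaseModel, Field, ValidatorFunctionWrapHandler, field_validator


class ToyArgs(BaseModel):
    data_preprocessors: list[str] = Field(default_factory=list)

    @field_validator("data_preprocessors", mode="wrap")
    @classmethod
    def single_to_list(
        cls, value: Any, handler: ValidatorFunctionWrapHandler
    ) -> list[Any]:
        # Wrap a bare scalar in a list, then delegate to the standard
        # list[str] validation via the wrap handler.
        if value is not None and not isinstance(value, list):
            value = [value]
        return handler(value)


assert ToyArgs(data_preprocessors="encode_image").data_preprocessors == ["encode_image"]
assert ToyArgs(data_preprocessors=["a", "b"]).data_preprocessors == ["a", "b"]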

src/guidellm/data/loaders.py

Lines changed: 3 additions & 0 deletions
@@ -142,6 +142,7 @@ def __init__(
         data_samples: int,
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         preprocessors: list[DatasetPreprocessor | DataDependentPreprocessor],
+        finalizer: DatasetFinalizer[DataT],
         collator: Callable,
         sampler: Sampler[int] | Literal["shuffle"] | None = None,
         num_workers: int | None = 1,
@@ -154,6 +155,7 @@ def __init__(
             data_samples=data_samples,
             processor_factory=processor_factory,
             preprocessors=preprocessors,
+            finalizer=finalizer,
             random_seed=random_seed,
         )
         self._info: dict[str, Any] = {
@@ -163,6 +165,7 @@ def __init__(
             "preprocessors": [
                 preprocessor.__class__.__name__ for preprocessor in preprocessors
             ],
+            "finalizer": finalizer.__class__.__name__,
             "collator": collator.__class__.__name__,
             "sampler": str(sampler),
             "num_workers": num_workers,
