Skip to content

Commit aa03d52

Browse files
authored
Merge branch 'main' into support_sharegpt
2 parents 2726d50 + 9d9392b commit aa03d52

File tree

16 files changed

+133
-42
lines changed

16 files changed

+133
-42
lines changed

src/guidellm/__main__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
try:
3434
import uvloop
3535
except ImportError:
36-
uvloop = None # type: ignore[assignment] # Optional dependency
36+
uvloop = None # type: ignore[assignment] # Optional dependency
3737

3838
from guidellm.backends import BackendType
3939
from guidellm.benchmark import (
@@ -116,6 +116,7 @@ def benchmark():
116116
)
117117
@click.option(
118118
"--scenario",
119+
"-c",
119120
type=cli_tools.Union(
120121
click.Path(
121122
exists=True,
@@ -392,8 +393,10 @@ def run(**kwargs):
392393
disable_progress = kwargs.pop("disable_progress", False)
393394

394395
try:
396+
# Only set CLI args that differ from click defaults
397+
new_kwargs = cli_tools.set_if_not_default(click.get_current_context(), **kwargs)
395398
args = BenchmarkGenerativeTextArgs.create(
396-
scenario=kwargs.pop("scenario", None), **kwargs
399+
scenario=new_kwargs.pop("scenario", None), **new_kwargs
397400
)
398401
except ValidationError as err:
399402
# Translate pydantic validation error to click argument error

src/guidellm/benchmark/benchmarker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import uuid
1414
from abc import ABC
1515
from collections.abc import AsyncIterator, Iterable
16-
from typing import Generic
16+
from typing import Any, Generic
1717

1818
from guidellm.benchmark.profile import Profile
1919
from guidellm.benchmark.progress import BenchmarkerProgress

src/guidellm/benchmark/schemas.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,17 @@
2323
from typing import Any, ClassVar, Literal, TypeVar, cast
2424

2525
import yaml
26-
from pydantic import ConfigDict, Field, computed_field, model_serializer
26+
from pydantic import (
27+
AliasChoices,
28+
AliasGenerator,
29+
ConfigDict,
30+
Field,
31+
ValidationError,
32+
ValidatorFunctionWrapHandler,
33+
computed_field,
34+
field_validator,
35+
model_serializer,
36+
)
2737
from torch.utils.data import Sampler
2838
from transformers import PreTrainedTokenizerBase
2939

@@ -1142,7 +1152,8 @@ def update_estimate(
11421152
)
11431153
request_duration = (
11441154
(request_end_time - request_start_time)
1145-
if request_end_time and request_start_time else None
1155+
if request_end_time and request_start_time
1156+
else None
11461157
)
11471158

11481159
# Always track concurrency
@@ -1669,7 +1680,7 @@ def compile(
16691680
estimated_state: EstimatedBenchmarkState,
16701681
scheduler_state: SchedulerState,
16711682
profile: Profile,
1672-
requests: Iterable,
1683+
requests: Iterable, # noqa: ARG003
16731684
backend: BackendInterface,
16741685
environment: Environment,
16751686
strategy: SchedulingStrategy,
@@ -1787,9 +1798,8 @@ def create(
17871798
scenario_data = scenario_data["args"]
17881799
constructor_kwargs.update(scenario_data)
17891800

1790-
for key, value in kwargs.items():
1791-
if value != cls.get_default(key):
1792-
constructor_kwargs[key] = value
1801+
# Apply overrides from kwargs
1802+
constructor_kwargs.update(kwargs)
17931803

17941804
return cls.model_validate(constructor_kwargs)
17951805

@@ -1818,13 +1828,19 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
18181828
else:
18191829
return factory({}) # type: ignore[call-arg] # Confirmed correct at runtime by code above
18201830

1821-
1822-
18231831
model_config = ConfigDict(
18241832
extra="ignore",
18251833
use_enum_values=True,
18261834
from_attributes=True,
18271835
arbitrary_types_allowed=True,
1836+
validate_by_alias=True,
1837+
validate_by_name=True,
1838+
alias_generator=AliasGenerator(
1839+
# Support field names with hyphens
1840+
validation_alias=lambda field_name: AliasChoices(
1841+
field_name, field_name.replace("_", "-")
1842+
),
1843+
),
18281844
)
18291845

18301846
# Required
@@ -1838,7 +1854,7 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
18381854
profile: StrategyType | ProfileType | Profile = Field(
18391855
default="sweep", description="Benchmark profile or scheduling strategy type"
18401856
)
1841-
rate: float | list[float] | None = Field(
1857+
rate: list[float] | None = Field(
18421858
default=None, description="Request rate(s) for rate-based scheduling"
18431859
)
18441860
# Backend configuration
@@ -1871,6 +1887,12 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
18711887
data_request_formatter: DatasetPreprocessor | dict[str, str] | str = Field(
18721888
default="chat_completions",
18731889
description="Request formatting preprocessor or template name",
1890+
validation_alias=AliasChoices(
1891+
"data_request_formatter",
1892+
"data-request-formatter",
1893+
"request_type",
1894+
"request-type",
1895+
),
18741896
)
18751897
data_collator: Callable | Literal["generative"] | None = Field(
18761898
default="generative", description="Data collator for batch processing"
@@ -1931,6 +1953,26 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
19311953
default=None, description="Maximum global error rate (0-1) before stopping"
19321954
)
19331955

1956+
@field_validator("data", "data_args", "rate", mode="wrap")
1957+
@classmethod
1958+
def single_to_list(
1959+
cls, value: Any, handler: ValidatorFunctionWrapHandler
1960+
) -> list[Any]:
1961+
"""
1962+
Ensures field is always a list.
1963+
1964+
:param value: Input value for the 'data' field
1965+
:return: List of data sources
1966+
"""
1967+
try:
1968+
return handler(value)
1969+
except ValidationError as err:
1970+
# If validation fails, try wrapping the value in a list
1971+
if err.errors()[0]["type"] == "list_type":
1972+
return handler([value])
1973+
else:
1974+
raise
1975+
19341976
@model_serializer
19351977
def serialize_model(self):
19361978
"""

src/guidellm/data/deserializers/deserializer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,10 @@ def _deserialize_with_registered_deserializers(
107107

108108
if len(errors) > 0:
109109
err_msgs = ""
110+
110111
def sort_key(item):
111112
return (isinstance(item[1], DataNotSupportedError), item[0])
113+
112114
for key, err in sorted(errors.items(), key=sort_key):
113115
err_msgs += f"\n - Deserializer '{key}': ({type(err).__name__}) {err}"
114116
raise ValueError(
@@ -141,4 +143,3 @@ def _deserialize_with_specified_deserializer(
141143
random_seed=random_seed,
142144
**data_kwargs,
143145
)
144-

src/guidellm/data/deserializers/synthetic.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import yaml
1010
from datasets import Features, IterableDataset, Value
1111
from faker import Faker
12-
from pydantic import ConfigDict, Field, model_validator
12+
from pydantic import ConfigDict, Field, ValidationError, model_validator
1313
from transformers import PreTrainedTokenizerBase
1414

1515
from guidellm.data.deserializers.deserializer import (
@@ -242,6 +242,10 @@ def __call__(
242242
if (config := self._load_config_str(data)) is not None:
243243
return self(config, processor_factory, random_seed, **data_kwargs)
244244

245+
# Try to parse dict-like data directly
246+
if (config := self._load_config_dict(data)) is not None:
247+
return self(config, processor_factory, random_seed, **data_kwargs)
248+
245249
if not isinstance(data, SyntheticTextDatasetConfig):
246250
raise DataNotSupportedError(
247251
"Unsupported data for SyntheticTextDatasetDeserializer, "
@@ -266,6 +270,15 @@ def __call__(
266270
),
267271
)
268272

273+
def _load_config_dict(self, data: Any) -> SyntheticTextDatasetConfig | None:
274+
if not isinstance(data, dict | list):
275+
return None
276+
277+
try:
278+
return SyntheticTextDatasetConfig.model_validate(data)
279+
except ValidationError:
280+
return None
281+
269282
def _load_config_file(self, data: Any) -> SyntheticTextDatasetConfig | None:
270283
if (not isinstance(data, str) and not isinstance(data, Path)) or (
271284
not Path(data).is_file()

src/guidellm/data/loaders.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
__all__ = ["DataLoader", "DatasetsIterator"]
1818

1919

20-
2120
class DatasetsIterator(TorchIterableDataset):
2221
def __init__(
2322
self,

src/guidellm/data/preprocessors/formatters.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,7 @@ def __init__(
5656
self.stream: bool = stream
5757
self.max_tokens: int | None = max_tokens or max_completion_tokens
5858

59-
def __call__(
60-
self, columns: dict[str, list[Any]]
61-
) -> GenerationRequest:
59+
def __call__(self, columns: dict[str, list[Any]]) -> GenerationRequest:
6260
"""
6361
:param columns: A dict of GenerativeDatasetColumnType to Any
6462
"""
@@ -396,9 +394,7 @@ def __call__( # noqa: C901
396394
class GenerativeAudioTranslationRequestFormatter(
397395
GenerativeAudioTranscriptionRequestFormatter
398396
):
399-
def __call__(
400-
self, columns: dict[str, list[Any]]
401-
) -> GenerationRequest:
397+
def __call__(self, columns: dict[str, list[Any]]) -> GenerationRequest:
402398
result = super().__call__(columns)
403399
result.request_type = "audio_translations"
404400
return result

src/guidellm/data/preprocessors/mappers.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,7 @@ def __init__(
167167
dict[GenerativeDatasetColumnType, list[tuple[int, str]]] | None
168168
)
169169

170-
def __call__(
171-
self, row: dict[str, Any]
172-
) -> dict[str, list[Any]]:
170+
def __call__(self, row: dict[str, Any]) -> dict[str, list[Any]]:
173171
if self.datasets_column_mappings is None:
174172
raise ValueError("DefaultGenerativeColumnMapper not setup with data.")
175173

src/guidellm/data/preprocessors/preprocessor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212

1313
@runtime_checkable
1414
class DatasetPreprocessor(Protocol):
15-
def __call__(self, item: dict[str, Any]) -> (
16-
GenerationRequest | dict[str, Any]): ...
15+
def __call__(self, item: dict[str, Any]) -> GenerationRequest | dict[str, Any]: ...
1716

1817

1918
@runtime_checkable

src/guidellm/preprocess/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def process_dataset(
238238
prompt_tokens: str | Path,
239239
output_tokens: str | Path,
240240
processor_args: dict[str, Any] | None = None,
241-
data_args: dict[str, Any] | None = None,
241+
data_args: dict[str, Any] | None = None, # noqa: ARG001
242242
short_prompt_strategy: ShortPromptStrategy = ShortPromptStrategy.IGNORE,
243243
pad_char: str | None = None,
244244
concat_delimiter: str | None = None,

0 commit comments

Comments
 (0)