
Commit 9d98291

Merge branch 'main' into fix-muliple-rates
2 parents a75eb6f + b05b0e7 commit 9d98291

18 files changed: +193 -161 lines changed

src/guidellm/benchmark/benchmarker.py

Lines changed: 2 additions & 0 deletions

@@ -57,6 +57,7 @@ async def run(
         backend: BackendInterface[RequestT, ResponseT],
         profile: Profile,
         environment: Environment,
+        data: list[Any],
         progress: BenchmarkerProgress[BenchmarkT] | None = None,
         sample_requests: int | None = 20,
         warmup: float | None = None,
@@ -149,6 +150,7 @@ async def run(
             environment=environment,
             strategy=strategy,
             constraints=constraints,
+            data=data,
         )
         if progress:
             await progress.on_benchmark_complete(benchmark)
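
Taken together with the entrypoints and schemas changes below, this threads the raw data specification from the CLI all the way into the compiled benchmark. A minimal, runnable sketch of the pattern (a toy class, not guidellm's actual Benchmarker):

    from typing import Any

    class MiniBenchmarker:
        # Toy model of the change: run() now accepts ``data`` and forwards it
        # to compile(), so the compiled benchmark records which data source
        # was benchmarked.
        def run(self, data: list[Any]) -> dict[str, Any]:
            # ...scheduling and measurement elided...
            return self.compile(data=data)

        def compile(self, data: list[Any]) -> dict[str, Any]:
            # Mirrors schemas.py below: requests metadata now carries the raw
            # data spec instead of object info extracted via InfoMixin.
            return {"benchmarker": {"requests": {"data": data}}}

    print(MiniBenchmarker().run(data=["prompts.jsonl"]))
    # {'benchmarker': {'requests': {'data': ['prompts.jsonl']}}}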

src/guidellm/benchmark/entrypoints.py

Lines changed: 1 addition & 0 deletions

@@ -436,6 +436,7 @@ async def benchmark_generative_text(
        backend=backend,
        profile=profile,
        environment=NonDistributedEnvironment(),
+       data=args.data,
        progress=progress,
        sample_requests=args.sample_requests,
        warmup=args.warmup,

src/guidellm/benchmark/output.py

Lines changed: 3 additions & 1 deletion

@@ -649,6 +649,8 @@ def _get_benchmark_status_metrics_stats(
     status_dist_summary: StatusDistributionSummary = getattr(
         benchmark.metrics, metric
     )
+    if not hasattr(status_dist_summary, status):
+        return [], []
     dist_summary: DistributionSummary = getattr(status_dist_summary, status)

     headers = [
@@ -688,7 +690,7 @@ def _get_benchmark_extras_headers_and_values(
     values: list[str] = [
         benchmark.benchmarker.profile.model_dump_json(),
         json.dumps(benchmark.benchmarker.backend),
-        json.dumps(benchmark.benchmarker.requests["attributes"]["data"]),
+        json.dumps(benchmark.benchmarker.requests["data"]),
     ]

     if len(headers) != len(values):
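
The added guard skips status buckets that a given summary does not expose, instead of letting getattr() raise AttributeError. A self-contained illustration of the pattern (the stub class is hypothetical, not guidellm's StatusDistributionSummary):

    from dataclasses import dataclass

    @dataclass
    class StatusSummaryStub:
        successful: float = 1.0  # deliberately no "errored" attribute

    def stats_for(summary: object, status: str) -> tuple[list, list]:
        # Same guard as above: return empty headers/values when the
        # requested status bucket is missing.
        if not hasattr(summary, status):
            return [], []
        return ["value"], [getattr(summary, status)]

    print(stats_for(StatusSummaryStub(), "successful"))  # (['value'], [1.0])
    print(stats_for(StatusSummaryStub(), "errored"))     # ([], [])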

src/guidellm/benchmark/schemas.py

Lines changed: 2 additions & 1 deletion

@@ -1674,6 +1674,7 @@ def compile(
         environment: Environment,
         strategy: SchedulingStrategy,
         constraints: dict[str, dict[str, Any]],
+        data: list[Any],
     ) -> GenerativeBenchmark:
         """
         Compile final generative benchmark from accumulated state.
@@ -1702,7 +1703,7 @@ def compile(
             ),
             benchmarker=BenchmarkerDict(
                 profile=profile,
-                requests=InfoMixin.extract_from_obj(requests),
+                requests={"data": data},
                 backend=backend.info,
                 environment=environment.info,
             ),
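
This simplification pairs with the output.py fix above: compile() now stores a plain {"data": ...} dict, so the report writer indexes requests["data"] directly rather than the old requests["attributes"]["data"] path. A quick check of the new shape (values illustrative):

    import json

    requests = {"data": ["prompts.jsonl"]}  # shape produced by compile() now
    print(json.dumps(requests["data"]))     # what output.py dumps: ["prompts.jsonl"]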
src/guidellm/data/deserializers/deserializer.py

Lines changed: 79 additions & 44 deletions

@@ -1,10 +1,9 @@
 from __future__ import annotations

-import contextlib
 from collections.abc import Callable
 from typing import Any, Protocol, Union, runtime_checkable

-from datasets import Dataset, IterableDataset
+from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
 from transformers import PreTrainedTokenizerBase

 from guidellm.data.utils import resolve_dataset_split
@@ -29,7 +28,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]: ...
+    ) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: ...


 class DatasetDeserializerFactory(
@@ -47,51 +46,16 @@ def deserialize(
         remove_columns: list[str] | None = None,
         **data_kwargs: dict[str, Any],
     ) -> Dataset | IterableDataset:
-        dataset = None
+        dataset: Dataset

         if type_ is None:
-            errors = []
-            # Note: There is no priority order for the deserializers, so all deserializers
-            # must be mutually exclusive to ensure deterministic behavior.
-            for name, deserializer in cls.registry.items():
-                deserializer_fn: DatasetDeserializer = (
-                    deserializer() if isinstance(deserializer, type) else deserializer
-                )
-
-                try:
-                    with contextlib.suppress(DataNotSupportedError):
-                        dataset = deserializer_fn(
-                            data=data,
-                            processor_factory=processor_factory,
-                            random_seed=random_seed,
-                            **data_kwargs,
-                        )
-                except Exception as e:
-                    errors.append(e)
-
-                if dataset is not None:
-                    break  # Found one that works. Continuing could overwrite it.
-
-            if dataset is None and len(errors) > 0:
-                raise DataNotSupportedError(f"data deserialization failed; {len(errors)} errors occurred while "
-                    f"attempting to deserialize data {data}: {errors}")
-
-        elif deserializer := cls.get_registered_object(type_) is not None:
-            deserializer_fn: DatasetDeserializer = (
-                deserializer() if isinstance(deserializer, type) else deserializer
+            dataset = cls._deserialize_with_registered_deserializers(
+                data, processor_factory, random_seed, **data_kwargs
             )

-            dataset = deserializer_fn(
-                data=data,
-                processor_factory=processor_factory,
-                random_seed=random_seed,
-                **data_kwargs,
-            )
-
-        if dataset is None:
-            raise DataNotSupportedError(
-                f"No suitable deserializer found for data {data} "
-                f"with kwargs {data_kwargs} and deserializer type {type_}."
+        else:
+            dataset = cls._deserialize_with_specified_deserializer(
+                data, type_, processor_factory, random_seed, **data_kwargs
             )

         if resolve_split:
@@ -107,3 +71,74 @@ def deserialize(
             dataset = dataset.remove_columns(remove_columns)

         return dataset
+
+    @classmethod
+    def _deserialize_with_registered_deserializers(
+        cls,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int = 42,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        if cls.registry is None:
+            raise RuntimeError("registry is None; cannot deserialize dataset")
+        dataset: Dataset | None = None
+
+        errors: dict[str, Exception] = {}
+        # Note: There is no priority order for the deserializers, so all
+        # deserializers must be mutually exclusive to ensure deterministic
+        # behavior.
+        for _name, deserializer in cls.registry.items():
+            deserializer_fn: DatasetDeserializer = (
+                deserializer() if isinstance(deserializer, type) else deserializer
+            )
+
+            try:
+                dataset = deserializer_fn(
+                    data=data,
+                    processor_factory=processor_factory,
+                    random_seed=random_seed,
+                    **data_kwargs,
+                )
+            except Exception as e:  # noqa: BLE001 # The exceptions are saved.
+                errors[_name] = e
+
+            if dataset is not None:
+                return dataset  # Success
+
+        if len(errors) > 0:
+            err_msgs = ""
+
+            def sort_key(item):
+                return (isinstance(item[1], DataNotSupportedError), item[0])
+
+            for key, err in sorted(errors.items(), key=sort_key):
+                err_msgs += f"\n  - Deserializer '{key}': ({type(err).__name__}) {err}"
+            raise ValueError(
+                "Data deserialization failed, likely because the input doesn't "
+                f"match any of the input formats. See the {len(errors)} error(s) that "
+                f"occurred while attempting to deserialize the data {data}:{err_msgs}"
+            )
+        return dataset
+
+    @classmethod
+    def _deserialize_with_specified_deserializer(
+        cls,
+        data: Any,
+        type_: str,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int = 42,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        deserializer_from_type = cls.get_registered_object(type_)
+        if deserializer_from_type is None:
+            raise ValueError(f"Deserializer type '{type_}' is not registered.")
+        if isinstance(deserializer_from_type, type):
+            deserializer_fn = deserializer_from_type()
+        else:
+            deserializer_fn = deserializer_from_type
+
+        return deserializer_fn(
+            data=data,
+            processor_factory=processor_factory,
+            random_seed=random_seed,
+            **data_kwargs,
+        )
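
Distilled to pure Python, the new fallback path tries every registered deserializer, keeps each failure keyed by name, and raises one aggregated error with DataNotSupportedError entries sorted last, so unexpected failure types surface first in the message. A runnable sketch with hypothetical deserializers, not guidellm's real registry:

    class DataNotSupportedError(ValueError):
        """Raised when a deserializer does not recognize the input."""

    def csv_deserializer(data):
        raise DataNotSupportedError("not a CSV file")

    def json_deserializer(data):
        raise RuntimeError("bad schema")

    REGISTRY = {"csv": csv_deserializer, "json": json_deserializer}

    def try_all(data):
        errors: dict[str, Exception] = {}
        for name, fn in REGISTRY.items():
            try:
                return fn(data)  # first success wins
            except Exception as e:  # noqa: BLE001 - errors are aggregated
                errors[name] = e

        def sort_key(item):
            # False sorts before True: non-DataNotSupportedError first.
            return (isinstance(item[1], DataNotSupportedError), item[0])

        msgs = "".join(
            f"\n  - Deserializer '{k}': ({type(v).__name__}) {v}"
            for k, v in sorted(errors.items(), key=sort_key)
        )
        raise ValueError(f"Data deserialization failed:{msgs}")

    try:
        try_all("mystery-input")
    except ValueError as e:
        print(e)
        # Data deserialization failed:
        #   - Deserializer 'json': (RuntimeError) bad schema
        #   - Deserializer 'csv': (DataNotSupportedError) not a CSV file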

src/guidellm/data/deserializers/file.py

Lines changed: 14 additions & 14 deletions

@@ -34,11 +34,11 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)  # Ignore unused args format errors

         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() not in {".txt", ".text"}
@@ -62,10 +62,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".csv"
@@ -86,10 +86,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() not in {".json", ".jsonl"}
@@ -110,10 +110,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".parquet"
@@ -134,10 +134,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".arrow"
@@ -158,10 +158,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() not in {".hdf5", ".h5"}
@@ -185,7 +185,7 @@ def __call__(
     ) -> dict[str, list]:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".db"
@@ -209,7 +209,7 @@ def __call__(
     ) -> dict[str, list]:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".tar"
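
The isinstance updates swap the tuple form for PEP 604 union syntax, which isinstance() accepts natively on Python 3.10+ (so this change presumably assumes a 3.10 floor for the project). Both spellings are equivalent at runtime:

    from pathlib import Path

    data = "dataset.csv"
    # Tuple form (pre-3.10 compatible) vs. union form (3.10+): same result.
    assert isinstance(data, (str, Path)) == isinstance(data, str | Path)
    print(isinstance(Path(data), str | Path))  # True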

src/guidellm/data/deserializers/huggingface.py

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
         _ = (processor_factory, random_seed)

         if isinstance(
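
The widened annotation reflects that datasets.load_dataset can hand back a DatasetDict or IterableDatasetDict (one entry per split) rather than a single Dataset; the factory's resolve_split step then narrows the result. A small in-memory illustration using the real datasets API:

    from datasets import Dataset, DatasetDict

    ds = Dataset.from_dict({"text": ["hello", "world"]})
    dd = DatasetDict({"train": ds})  # shape load_dataset often returns
    # Downstream code narrows a dict-of-splits to a single split, e.g.:
    train = dd["train"] if isinstance(dd, DatasetDict) else dd
    print(type(train).__name__, train.num_rows)  # Dataset 2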
