Commit 4bff34a

Convert new types to TypeAliasTypes

1 parent 2c70edd commit 4bff34a

File tree: 4 files changed (+76, -46 lines)

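For context on what the commit does: the benchmark type aliases were previously plain union expressions, and this change wraps each one in TypeAliasType (added to typing in Python 3.12, with a backport in typing_extensions). A minimal sketch of the difference, using throwaway names rather than guidellm's own:

from typing import TypeAliasType

# A plain alias is just the union object itself; it carries no name of
# its own, so reprs and type-checker messages expand it in full.
PlainAlias = str | list[str]

# TypeAliasType records the alias name and keeps the underlying union
# reachable through __value__, so tools can refer to the alias by name.
NamedAlias = TypeAliasType("NamedAlias", str | list[str])

print(NamedAlias)            # NamedAlias
print(NamedAlias.__value__)  # str | list[str]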

src/guidellm/benchmark/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -46,9 +46,17 @@
     enable_scenarios,
     get_builtin_scenarios,
 )
+from .types import (
+    AggregatorInputT,
+    DataInputT,
+    OutputFormatT,
+    ProcessorInputT,
+    ProgressInputT,
+)
 
 __all__ = [
     "Aggregator",
+    "AggregatorInputT",
     "AggregatorState",
     "AsyncProfile",
     "Benchmark",
@@ -60,6 +68,7 @@
     "BenchmarkerProgressGroup",
     "CompilableAggregator",
     "ConcurrentProfile",
+    "DataInputT",
     "GenerativeBenchmark",
     "GenerativeBenchmarkerCSV",
     "GenerativeBenchmarkerConsole",
@@ -73,8 +82,11 @@
     "GenerativeStatsProgressAggregator",
     "GenerativeTextScenario",
     "InjectExtrasAggregator",
+    "OutputFormatT",
+    "ProcessorInputT",
     "Profile",
     "ProfileType",
+    "ProgressInputT",
     "Scenario",
     "SchedulerStatsAggregator",
     "SerializableAggregator",

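With the __init__.py re-export in place, downstream code can pull the aliases from the package root instead of the types module. A small hedged sketch (the helper name and body are illustrative, not part of guidellm; only the import path comes from this commit):

from guidellm.benchmark import DataInputT, OutputFormatT

# Hypothetical helper showing the re-exported aliases used as annotations.
def describe_run(data: DataInputT, output_formats: OutputFormatT = ("console",)) -> str:
    return f"data={type(data).__name__}, formats={output_formats!r}"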
src/guidellm/benchmark/entrypoints.py

Lines changed: 28 additions & 21 deletions
@@ -24,11 +24,11 @@
 from guidellm.benchmark.progress import BenchmarkerProgressGroup
 from guidellm.benchmark.scenario import enable_scenarios
 from guidellm.benchmark.types import (
-    OutputFormatType,
-    DataInputType,
-    ProcessorInputType,
-    ProgressInputType,
-    AggregatorInputType
+    AggregatorInputT,
+    DataInputT,
+    OutputFormatT,
+    ProcessorInputT,
+    ProgressInputT,
 )
 from guidellm.request import GenerativeRequestLoader
 from guidellm.scheduler import (
@@ -49,16 +49,15 @@
 
 # Helper functions
 
+
 async def initialize_backend(
     backend: BackendType | Backend,
     target: str,
     model: str | None,
     backend_kwargs: dict[str, Any] | None,
 ) -> Backend:
     backend = (
-        Backend.create(
-            backend, target=target, model=model, **(backend_kwargs or {})
-        )
+        Backend.create(backend, target=target, model=model, **(backend_kwargs or {}))
         if not isinstance(backend, Backend)
         else backend
     )
@@ -95,18 +94,19 @@ async def resolve_profile(
     )
     return profile
 
+
 async def resolve_output_formats(
-    output_formats: OutputFormatType,
+    output_formats: OutputFormatT,
     output_path: str | Path | None,
 ) -> dict[str, GenerativeBenchmarkerOutput]:
-    output_formats = GenerativeBenchmarkerOutput.resolve(
+    return GenerativeBenchmarkerOutput.resolve(
         output_formats=(output_formats or {}), output_path=output_path
     )
-    return output_formats
+
 
 async def finalize_outputs(
     report: GenerativeBenchmarksReport,
-    resolved_output_formats: dict[str, GenerativeBenchmarkerOutput]
+    resolved_output_formats: dict[str, GenerativeBenchmarkerOutput],
 ):
     output_format_results = {}
     for key, output in resolved_output_formats.items():
@@ -122,7 +122,7 @@
 @enable_scenarios
 async def benchmark_generative_text(  # noqa: C901
     target: str,
-    data: DataInputType,
+    data: DataInputT,
     profile: StrategyType | ProfileType | Profile,
     rate: list[float] | None = None,
     random_seed: int = 42,
@@ -131,18 +131,18 @@ async def benchmark_generative_text(  # noqa: C901
     backend_kwargs: dict[str, Any] | None = None,
     model: str | None = None,
     # Data configuration
-    processor: ProcessorInputType | None = None,
+    processor: ProcessorInputT | None = None,
     processor_args: dict[str, Any] | None = None,
     data_args: dict[str, Any] | None = None,
     data_sampler: Literal["random"] | None = None,
     # Output configuration
     output_path: str | Path | None = _CURRENT_WORKING_DIR,
-    output_formats: OutputFormatType = ("console", "json", "html", "csv"),
+    output_formats: OutputFormatT = ("console", "json", "html", "csv"),
     # Updates configuration
-    progress: ProgressInputType | None = None,
+    progress: ProgressInputT | None = None,
     print_updates: bool = False,
     # Aggregators configuration
-    add_aggregators: AggregatorInputType | None = None,
+    add_aggregators: AggregatorInputT | None = None,
     warmup: float | None = None,
     cooldown: float | None = None,
     request_samples: int | None = 20,
@@ -259,7 +259,9 @@
     )
 
     with console.print_update_step(title="Resolving output formats") as console_step:
-        resolved_output_formats = await resolve_output_formats(output_formats, output_path)
+        resolved_output_formats = await resolve_output_formats(
+            output_formats, output_path
+        )
         console_step.finish(
             title="Output formats resolved",
             details={key: str(val) for key, val in resolved_output_formats.items()},
@@ -314,7 +316,7 @@
 async def reimport_benchmarks_report(
     file: Path,
     output_path: Path | None,
-    output_formats: OutputFormatType = ("console", "json", "html", "csv"),
+    output_formats: OutputFormatT = ("console", "json", "html", "csv"),
 ) -> tuple[GenerativeBenchmarksReport, dict[str, Any]]:
     """
     The command-line entry point for re-importing and displaying an
@@ -326,10 +328,15 @@ async def reimport_benchmarks_report(
         title=f"Loading benchmarks from {file}"
     ) as console_step:
         report = GenerativeBenchmarksReport.load_file(file)
-        console_step.finish(f"Import of old benchmarks complete; loaded {len(report.benchmarks)} benchmark(s)")
+        console_step.finish(
+            "Import of old benchmarks complete;"
+            f" loaded {len(report.benchmarks)} benchmark(s)"
+        )
 
     with console.print_update_step(title="Resolving output formats") as console_step:
-        resolved_output_formats = await resolve_output_formats(output_formats, output_path)
+        resolved_output_formats = await resolve_output_formats(
+            output_formats, output_path
+        )
         console_step.finish(
             title="Output formats resolved",
             details={key: str(val) for key, val in resolved_output_formats.items()},
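As a usage note, resolve_output_formats above normalizes any OutputFormatT shape into a dict of GenerativeBenchmarkerOutput instances. A hedged sketch of calling it directly (the exact keys produced depend on GenerativeBenchmarkerOutput.resolve; ["console", "json"] is the expected result, not a guarantee):

import asyncio
from pathlib import Path

from guidellm.benchmark.entrypoints import resolve_output_formats

async def main() -> None:
    # A tuple of format names is one accepted OutputFormatT shape; a list,
    # a dict of format -> options, or None should also be accepted.
    outputs = await resolve_output_formats(("console", "json"), Path.cwd())
    print(sorted(outputs))  # expected: ["console", "json"]

asyncio.run(main())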

src/guidellm/benchmark/scenario.py

Lines changed: 7 additions & 10 deletions
@@ -9,21 +9,18 @@
 import yaml
 from loguru import logger
 from pydantic import BeforeValidator, Field, PositiveFloat, PositiveInt, SkipValidation
-from transformers.tokenization_utils_base import (  # type: ignore[import]
-    PreTrainedTokenizerBase,
-)
 
 from guidellm.backends import Backend, BackendType
 from guidellm.benchmark.profile import Profile, ProfileType
-from guidellm.benchmark.types import DataInputType, ProcessorInputType, AggregatorInputType
+from guidellm.benchmark.types import AggregatorInputT, DataInputT, ProcessorInputT
 from guidellm.scheduler import StrategyType
 from guidellm.utils import StandardBaseModel
 
-__ALL__ = [
-    "Scenario",
+__all__ = [
     "GenerativeTextScenario",
-    "get_builtin_scenarios",
+    "Scenario",
     "enable_scenarios",
+    "get_builtin_scenarios",
 ]
 
 SCENARIO_DIR = Path(__file__).parent / "scenarios/"
@@ -111,7 +108,7 @@ class Config:
         arbitrary_types_allowed = True
 
     data: Annotated[
-        DataInputType,
+        DataInputT,
        # BUG: See https://github.com/pydantic/pydantic/issues/9541
        SkipValidation,
    ]
@@ -125,12 +122,12 @@ class Config:
     backend_kwargs: dict[str, Any] | None = None
     model: str | None = None
     # Data configuration
-    processor: ProcessorInputType | None = None
+    processor: ProcessorInputT | None = None
     processor_args: dict[str, Any] | None = None
     data_args: dict[str, Any] | None = None
     data_sampler: Literal["random"] | None = None
     # Aggregators configuration
-    add_aggregators: AggregatorInputType | None = None
+    add_aggregators: AggregatorInputT | None = None
     warmup: Annotated[float | None, Field(gt=0, le=1)] = None
     cooldown: Annotated[float | None, Field(gt=0, le=1)] = None
     request_samples: PositiveInt | None = 20
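The SkipValidation annotation kept on the data field above is a Pydantic v2 escape hatch: the field keeps its type annotation for tooling, but Pydantic does not validate the value, which sidesteps the linked issue for arbitrary dataset objects. A minimal self-contained sketch of the same pattern (names are illustrative, not guidellm's):

from typing import Annotated
from pydantic import BaseModel, SkipValidation

class Example(BaseModel):
    # The annotation stays visible to type checkers, but Pydantic will
    # not validate or coerce the value against list[int].
    data: Annotated[list[int], SkipValidation]

print(Example(data="not a list").data)  # not a list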

src/guidellm/benchmark/types.py

Lines changed: 29 additions & 15 deletions
@@ -1,45 +1,59 @@
 from __future__ import annotations
+
 from collections.abc import Iterable
-from typing import Any
 from pathlib import Path
-from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
-
-from guidellm.benchmark.output import (
-    GenerativeBenchmarkerOutput,
-)
+from typing import Any, TypeAliasType
 
+from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
 from transformers import (  # type: ignore[import]
     PreTrainedTokenizerBase,
 )
 
-from guidellm.benchmark.progress import BenchmarkerProgress
-
 from guidellm.benchmark.aggregator import (
     Aggregator,
     CompilableAggregator,
 )
+from guidellm.benchmark.output import (
+    GenerativeBenchmarkerOutput,
+)
+from guidellm.benchmark.progress import BenchmarkerProgress
+
+__all__ = [
+    "AggregatorInputT",
+    "DataInputT",
+    "OutputFormatT",
+    "ProcessorInputT",
+    "ProgressInputT",
+]
 
 
-DataInputType = (
+DataInputT = TypeAliasType(
+    "DataInputT",
     Iterable[str]
     | Iterable[dict[str, Any]]
     | Dataset
     | DatasetDict
     | IterableDataset
     | IterableDatasetDict
     | str
-    | Path
+    | Path,
 )
 
-OutputFormatType = (
+OutputFormatT = TypeAliasType(
+    "OutputFormatT",
     tuple[str, ...]
     | list[str]
     | dict[str, str | dict[str, Any] | GenerativeBenchmarkerOutput]
-    | None
+    | None,
 )
 
-ProcessorInputType = str | Path | PreTrainedTokenizerBase
+ProcessorInputT = TypeAliasType("ProcessorInputT", str | Path | PreTrainedTokenizerBase)
 
-ProgressInputType = tuple[str, ...] | list[str] | list[BenchmarkerProgress]
+ProgressInputT = TypeAliasType(
+    "ProgressInputT", tuple[str, ...] | list[str] | list[BenchmarkerProgress]
+)
 
-AggregatorInputType = dict[str, str | dict[str, Any] | Aggregator | CompilableAggregator]
+AggregatorInputT = TypeAliasType(
+    "AggregatorInputT",
+    dict[str, str | dict[str, Any] | Aggregator | CompilableAggregator],
+)
