Skip to content

Commit bb74a55

Browse files
committed
Share type aliases between entrypoints and scenario
1 parent 03f9085 commit bb74a55

File tree

3 files changed

+56
-57
lines changed

3 files changed

+56
-57
lines changed

src/guidellm/benchmark/entrypoints.py

Lines changed: 7 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,15 @@
11
from __future__ import annotations
22

3-
from collections.abc import Iterable
43
from pathlib import Path
54
from typing import Any, Literal
65

7-
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
8-
from transformers import ( # type: ignore[import]
9-
PreTrainedTokenizerBase,
10-
)
11-
126
from guidellm.backends import (
137
Backend,
148
BackendType,
159
GenerationRequest,
1610
GenerationResponse,
1711
)
1812
from guidellm.benchmark.aggregator import (
19-
Aggregator,
20-
CompilableAggregator,
2113
GenerativeRequestsAggregator,
2214
GenerativeStatsProgressAggregator,
2315
SchedulerStatsAggregator,
@@ -29,11 +21,10 @@
2921
GenerativeBenchmarkerOutput,
3022
)
3123
from guidellm.benchmark.profile import Profile, ProfileType
32-
from guidellm.benchmark.progress import (
33-
BenchmarkerProgress,
34-
BenchmarkerProgressGroup,
35-
)
24+
from guidellm.benchmark.progress import BenchmarkerProgressGroup
3625
from guidellm.benchmark.scenario import enable_scenarios
26+
from guidellm.benchmark.type import OutputFormatType, DataInputType, ProcessorInputType, ProgressInputType, \
27+
AggregatorInputType
3728
from guidellm.request import GenerativeRequestLoader
3829
from guidellm.scheduler import (
3930
ConstraintInitializer,
@@ -51,27 +42,6 @@
5142
_CURRENT_WORKING_DIR = Path.cwd()
5243

5344

54-
# Data types
55-
56-
DataType = (
57-
Iterable[str]
58-
| Iterable[dict[str, Any]]
59-
| Dataset
60-
| DatasetDict
61-
| IterableDataset
62-
| IterableDatasetDict
63-
| str
64-
| Path
65-
)
66-
67-
OutputFormatType = (
68-
tuple[str, ...]
69-
| list[str]
70-
| dict[str, str | dict[str, Any] | GenerativeBenchmarkerOutput]
71-
| None
72-
)
73-
74-
7545
# Helper functions
7646

7747
async def initialize_backend(
@@ -147,7 +117,7 @@ async def finalize_outputs(
147117
@enable_scenarios
148118
async def benchmark_generative_text( # noqa: C901
149119
target: str,
150-
data: DataType,
120+
data: DataInputType,
151121
profile: StrategyType | ProfileType | Profile,
152122
rate: list[float] | None = None,
153123
random_seed: int = 42,
@@ -156,20 +126,18 @@ async def benchmark_generative_text( # noqa: C901
156126
backend_kwargs: dict[str, Any] | None = None,
157127
model: str | None = None,
158128
# Data configuration
159-
processor: str | Path | PreTrainedTokenizerBase | None = None,
129+
processor: ProcessorInputType | None = None,
160130
processor_args: dict[str, Any] | None = None,
161131
data_args: dict[str, Any] | None = None,
162132
data_sampler: Literal["random"] | None = None,
163133
# Output configuration
164134
output_path: str | Path | None = _CURRENT_WORKING_DIR,
165135
output_formats: OutputFormatType = ("console", "json", "html", "csv"),
166136
# Updates configuration
167-
progress: tuple[str, ...] | list[str] | list[BenchmarkerProgress] | None = None,
137+
progress: ProgressInputType | None = None,
168138
print_updates: bool = False,
169139
# Aggregators configuration
170-
add_aggregators: (
171-
dict[str, str | dict[str, Any] | Aggregator | CompilableAggregator] | None
172-
) = None,
140+
add_aggregators: AggregatorInputType | None = None,
173141
warmup: float | None = None,
174142
cooldown: float | None = None,
175143
request_samples: int | None = 20,

src/guidellm/benchmark/scenario.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,21 @@
11
from __future__ import annotations
22

33
import json
4-
from collections.abc import Iterable
54
from functools import cache, wraps
65
from inspect import Parameter, signature
76
from pathlib import Path
87
from typing import Annotated, Any, Callable, Literal, TypeVar
98

109
import yaml
11-
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
1210
from loguru import logger
1311
from pydantic import BeforeValidator, Field, PositiveFloat, PositiveInt, SkipValidation
1412
from transformers.tokenization_utils_base import ( # type: ignore[import]
1513
PreTrainedTokenizerBase,
1614
)
1715

1816
from guidellm.backends import Backend, BackendType
19-
from guidellm.benchmark.aggregator import (
20-
Aggregator,
21-
CompilableAggregator,
22-
)
2317
from guidellm.benchmark.profile import Profile, ProfileType
18+
from guidellm.benchmark.type import DataInputType, ProcessorInputType, AggregatorInputType
2419
from guidellm.scheduler import StrategyType
2520
from guidellm.utils import StandardBaseModel
2621

@@ -116,14 +111,7 @@ class Config:
116111
arbitrary_types_allowed = True
117112

118113
data: Annotated[
119-
Iterable[str]
120-
| Iterable[dict[str, Any]]
121-
| Dataset
122-
| DatasetDict
123-
| IterableDataset
124-
| IterableDatasetDict
125-
| str
126-
| Path,
114+
DataInputType,
127115
# BUG: See https://github.com/pydantic/pydantic/issues/9541
128116
SkipValidation,
129117
]
@@ -137,14 +125,12 @@ class Config:
137125
backend_kwargs: dict[str, Any] | None = None
138126
model: str | None = None
139127
# Data configuration
140-
processor: str | Path | PreTrainedTokenizerBase | None = None
128+
processor: ProcessorInputType | None = None
141129
processor_args: dict[str, Any] | None = None
142130
data_args: dict[str, Any] | None = None
143131
data_sampler: Literal["random"] | None = None
144132
# Aggregators configuration
145-
add_aggregators: (
146-
dict[str, str | dict[str, Any] | Aggregator | CompilableAggregator] | None
147-
) = None
133+
add_aggregators: AggregatorInputType | None = None
148134
warmup: Annotated[float | None, Field(gt=0, le=1)] = None
149135
cooldown: Annotated[float | None, Field(gt=0, le=1)] = None
150136
request_samples: PositiveInt | None = 20

src/guidellm/benchmark/type.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from __future__ import annotations
2+
from collections.abc import Iterable
3+
from typing import Any
4+
from pathlib import Path
5+
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
6+
7+
from guidellm.benchmark.output import (
8+
GenerativeBenchmarkerOutput,
9+
)
10+
11+
from transformers import ( # type: ignore[import]
12+
PreTrainedTokenizerBase,
13+
)
14+
15+
from guidellm.benchmark.progress import BenchmarkerProgress
16+
17+
from guidellm.benchmark.aggregator import (
18+
Aggregator,
19+
CompilableAggregator,
20+
)
21+
22+
23+
DataInputType = (
24+
Iterable[str]
25+
| Iterable[dict[str, Any]]
26+
| Dataset
27+
| DatasetDict
28+
| IterableDataset
29+
| IterableDatasetDict
30+
| str
31+
| Path
32+
)
33+
34+
OutputFormatType = (
35+
tuple[str, ...]
36+
| list[str]
37+
| dict[str, str | dict[str, Any] | GenerativeBenchmarkerOutput]
38+
| None
39+
)
40+
41+
ProcessorInputType = str | Path | PreTrainedTokenizerBase
42+
43+
ProgressInputType = tuple[str, ...] | list[str] | list[BenchmarkerProgress]
44+
45+
AggregatorInputType = dict[str, str | dict[str, Any] | Aggregator | CompilableAggregator]

0 commit comments

Comments
 (0)