1
1
from __future__ import annotations
2
2
3
- from collections .abc import Iterable
4
3
from pathlib import Path
5
4
from typing import Any , Literal
6
5
7
- from datasets import Dataset , DatasetDict , IterableDataset , IterableDatasetDict
8
- from transformers import ( # type: ignore[import]
9
- PreTrainedTokenizerBase ,
10
- )
11
-
12
6
from guidellm .backends import (
13
7
Backend ,
14
8
BackendType ,
15
9
GenerationRequest ,
16
10
GenerationResponse ,
17
11
)
18
12
from guidellm .benchmark .aggregator import (
19
- Aggregator ,
20
- CompilableAggregator ,
21
13
GenerativeRequestsAggregator ,
22
14
GenerativeStatsProgressAggregator ,
23
15
SchedulerStatsAggregator ,
29
21
GenerativeBenchmarkerOutput ,
30
22
)
31
23
from guidellm .benchmark .profile import Profile , ProfileType
32
- from guidellm .benchmark .progress import (
33
- BenchmarkerProgress ,
34
- BenchmarkerProgressGroup ,
35
- )
24
+ from guidellm .benchmark .progress import BenchmarkerProgressGroup
36
25
from guidellm .benchmark .scenario import enable_scenarios
26
+ from guidellm .benchmark .type import OutputFormatType , DataInputType , ProcessorInputType , ProgressInputType , \
27
+ AggregatorInputType
37
28
from guidellm .request import GenerativeRequestLoader
38
29
from guidellm .scheduler import (
39
30
ConstraintInitializer ,
51
42
# Captured once at import time so it serves as the default `output_path`
# for benchmark results; callers that chdir later still write here unless
# they pass an explicit path. NOTE(review): spacing (`Path .cwd ()`) looks
# like rendering/extraction artifacts of the diff view, not real code style.
_CURRENT_WORKING_DIR = Path .cwd ()
52
43
53
44
54
- # Data types
55
-
56
- DataType = (
57
- Iterable [str ]
58
- | Iterable [dict [str , Any ]]
59
- | Dataset
60
- | DatasetDict
61
- | IterableDataset
62
- | IterableDatasetDict
63
- | str
64
- | Path
65
- )
66
-
67
- OutputFormatType = (
68
- tuple [str , ...]
69
- | list [str ]
70
- | dict [str , str | dict [str , Any ] | GenerativeBenchmarkerOutput ]
71
- | None
72
- )
73
-
74
-
75
45
# Helper functions
76
46
77
47
async def initialize_backend (
@@ -147,7 +117,7 @@ async def finalize_outputs(
147
117
@enable_scenarios
148
118
async def benchmark_generative_text ( # noqa: C901
149
119
target : str ,
150
- data : DataType ,
120
+ data : DataInputType ,
151
121
profile : StrategyType | ProfileType | Profile ,
152
122
rate : list [float ] | None = None ,
153
123
random_seed : int = 42 ,
@@ -156,20 +126,18 @@ async def benchmark_generative_text( # noqa: C901
156
126
backend_kwargs : dict [str , Any ] | None = None ,
157
127
model : str | None = None ,
158
128
# Data configuration
159
- processor : str | Path | PreTrainedTokenizerBase | None = None ,
129
+ processor : ProcessorInputType | None = None ,
160
130
processor_args : dict [str , Any ] | None = None ,
161
131
data_args : dict [str , Any ] | None = None ,
162
132
data_sampler : Literal ["random" ] | None = None ,
163
133
# Output configuration
164
134
output_path : str | Path | None = _CURRENT_WORKING_DIR ,
165
135
output_formats : OutputFormatType = ("console" , "json" , "html" , "csv" ),
166
136
# Updates configuration
167
- progress : tuple [ str , ...] | list [ str ] | list [ BenchmarkerProgress ] | None = None ,
137
+ progress : ProgressInputType | None = None ,
168
138
print_updates : bool = False ,
169
139
# Aggregators configuration
170
- add_aggregators : (
171
- dict [str , str | dict [str , Any ] | Aggregator | CompilableAggregator ] | None
172
- ) = None ,
140
+ add_aggregators : AggregatorInputType | None = None ,
173
141
warmup : float | None = None ,
174
142
cooldown : float | None = None ,
175
143
request_samples : int | None = 20 ,
0 commit comments