
Commit 730eeb1 (1 parent: 6d31244)

Initial state for datasets rework to enable multimodal and more complicated combinations


46 files changed: +2949 / -2084 lines

src/guidellm/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,8 @@
 import logging
 import os
 
+from datasets.utils.logging import disable_progress_bar
+
 with (
     open(os.devnull, "w") as devnull,  # noqa: PTH123
     contextlib.redirect_stderr(devnull),
@@ -19,6 +21,7 @@
 os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Silence warnings for tokenizers
 hf_logging.set_verbosity_error()
 logging.getLogger("transformers").setLevel(logging.ERROR)
+disable_progress_bar()
 
 from .logger import configure_logger, logger
 from .settings import (

src/guidellm/__main__.py

Lines changed: 88 additions & 9 deletions
@@ -56,6 +56,11 @@
 from guidellm.benchmark.scenario import (
     GenerativeTextScenario,
 )
+from guidellm.data import (
+    GenerativeDatasetArgs,
+    GenerativeRequestFormatter,
+    GenerativeRequestType,
+)
 from guidellm.mock_server import MockServer, MockServerConfig
 from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset
 from guidellm.scheduler import StrategyType
@@ -143,6 +148,7 @@ def benchmark():
 @click.option(
     "--data",
     type=str,
+    multiple=True,
     help=(
         "The HuggingFace dataset ID, a path to a HuggingFace dataset, "
         "a path to a data file csv, json, jsonl, or txt, "
@@ -197,9 +203,7 @@ def benchmark():
     default=None,
     help=(
         "A JSON string containing any arguments to pass to the backend as a "
-        "dict with **kwargs. Headers can be removed by setting their value to "
-        "null. For example: "
-        """'{"headers": {"Authorization": null, "Custom-Header": "Custom-Value"}}'"""
+        "dict with **kwargs."
     ),
 )
 @click.option(
@@ -234,19 +238,72 @@ def benchmark():
 @click.option(
     "--data-args",
     default=None,
-    callback=cli_tools.parse_json,
+    callback=(
+        lambda _ctx, _param, value: [
+            GenerativeDatasetArgs.model_validate_json(val)
+            if val
+            else GenerativeDatasetArgs()
+            for val in value
+        ]
+        if value
+        else None
+    ),
     help=(
         "A JSON string containing any arguments to pass to the dataset creation "
         "as a dict with **kwargs."
     ),
 )
+@click.option(
+    "--data-samples",
+    default=-1,
+    type=int,
+    help=(
+        "The number of samples to use from the dataset. If -1 (default), will use all "
+        "samples in the dataset."
+    ),
+)
 @click.option(
     "--data-sampler",
     default=None,
-    type=click.Choice(["random"]),
+    type=click.Choice(["shuffle"]),
+    help="The data sampler type to use.",
+)
+@click.option(
+    "--data-request-type",
+    default="text_completions",
+    type=str,
+    help=(
+        "The type of request to create for each data sample. "
+        f"For example, {list(get_literal_vals(GenerativeRequestType))}."
+    ),
+)
+@click.option(
+    "--data-request-template",
+    default=None,
+    help=(
+        "A Jinja2 template string or path to a Jinja2 template file to use for "
+        "creating requests from the data samples. If not provided, will use a "
+        "default template based on the request type."
+    ),
+)
+@click.option(
+    "--data-request-extras",
+    default=None,
+    callback=cli_tools.parse_json,
+    help=("A JSON string of extra data to include with each data request."),
+)
+@click.option(
+    "--data-request-nonstreaming",
+    is_flag=True,
+    help="Set this flag to disable streaming for the data requests.",
+)
+@click.option(
+    "--dataloader_kwargs",
+    default=None,
+    callback=cli_tools.parse_json,
     help=(
-        "The data sampler type to use. 'random' will add a random shuffle on the data. "
-        "Defaults to None"
+        "A JSON string containing any arguments to pass to the dataloader constructor "
+        "as a dict with **kwargs."
     ),
 )
 # Output configuration
@@ -387,7 +444,13 @@ def run(
     processor,
     processor_args,
     data_args,
+    data_samples,
     data_sampler,
+    data_request_type,
+    data_request_template,
+    data_request_extras,
+    data_request_nonstreaming,
+    dataloader_kwargs,
     # Output configuration
     output_path,
     output_formats,
@@ -420,7 +483,8 @@ def run(
     asyncio.run(
         benchmark_generative_text(
             target=target,
-            data=data,
+            data=list(data),
+            # Benchmark configuration
             profile=profile,
             rate=rate,
             random_seed=random_seed,
@@ -432,7 +496,22 @@ def run(
             processor=processor,
             processor_args=processor_args,
             data_args=data_args,
-            data_sampler=data_sampler,
+            data_samples=data_samples,
+            data_column_mapper=None,  # use default
+            data_request_formatter=GenerativeRequestFormatter(
+                request_type=data_request_type,
+                request_template=data_request_template,
+                request_extras=data_request_extras,
+                request_defaults=(
+                    {}  # disable defaults if non-streaming
+                    if data_request_nonstreaming
+                    else None
+                ),
+            ),
+            data_preprocessors=None,  # no preprocessors through CLI for now
+            dataloader_sampler=data_sampler,
+            dataloader_collate_fn=None,  # use default
+            dataloader_kwargs=dataloader_kwargs,
             # Output configuration
             output_path=output_path,
             output_formats=[
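Note: since --data-args is now repeatable and validated into GenerativeDatasetArgs models rather than a single JSON dict, the callback above can be read as the following standalone sketch. Only GenerativeDatasetArgs and its pydantic model_validate_json come from this diff; the helper name and the example values are hypothetical.

from guidellm.data import GenerativeDatasetArgs


def parse_data_args(values: tuple[str, ...]) -> list[GenerativeDatasetArgs] | None:
    # Mirrors the new --data-args callback: one JSON string per --data entry,
    # an empty string falls back to default dataset arguments, and no values
    # at all yields None so downstream defaults apply.
    if not values:
        return None
    return [
        GenerativeDatasetArgs.model_validate_json(val)
        if val
        else GenerativeDatasetArgs()
        for val in values
    ]


# Example: two --data entries where only the first passes overrides
# (the field names accepted inside the JSON depend on GenerativeDatasetArgs):
# parse_data_args(("{}", ""))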

src/guidellm/backends/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@
     GenerationRequest,
     GenerationRequestTimings,
     GenerationResponse,
+    GenerationTokenStats,
 )
 from .openai import OpenAIHTTPBackend
 
@@ -22,5 +23,6 @@
     "GenerationRequest",
     "GenerationRequestTimings",
     "GenerationResponse",
+    "GenerationTokenStats",
     "OpenAIHTTPBackend",
 ]

src/guidellm/backends/backend.py

Lines changed: 1 addition & 0 deletions
@@ -115,5 +115,6 @@ def requests_limit(self) -> int | None:
     async def default_model(self) -> str | None:
         """
         :return: The default model name or identifier for generation requests.
+            None if no default model is available.
         """
         ...

src/guidellm/backends/objects.py

Lines changed: 43 additions & 109 deletions
@@ -6,62 +6,51 @@
 implementations.
 """
 
-import uuid
-from typing import Any, Literal, Optional
+from __future__ import annotations
+
+from typing import Literal
 
 from pydantic import Field
 
+from guidellm.data import (
+    GenerationRequest,
+    GenerationRequestArguments,
+    GenerationRequestTimings,
+)
 from guidellm.scheduler import (
-    MeasuredRequestTimings,
     SchedulerMessagingPydanticRegistry,
 )
 from guidellm.utils import StandardBaseModel
 
 __all__ = [
     "GenerationRequest",
+    "GenerationRequestArguments",
     "GenerationRequestTimings",
     "GenerationResponse",
+    "GenerationTokenStats",
 ]
 
 
 @SchedulerMessagingPydanticRegistry.register()
-class GenerationRequest(StandardBaseModel):
-    """Request model for backend generation operations."""
+class GenerationTokenStats(StandardBaseModel):
+    """Token statistics for generation requests and responses."""
 
-    request_id: str = Field(
-        default_factory=lambda: str(uuid.uuid4()),
-        description="Unique identifier for the request.",
-    )
-    request_type: Literal["text_completions", "chat_completions"] = Field(
-        default="text_completions",
-        description=(
-            "Type of request. 'text_completions' uses backend.text_completions(), "
-            "'chat_completions' uses backend.chat_completions()."
-        ),
-    )
-    content: Any = Field(
-        description=(
-            "Request content. For text_completions: string or list of strings. "
-            "For chat_completions: string, list of messages, or raw content "
-            "(set raw_content=True in params)."
-        )
-    )
-    params: dict[str, Any] = Field(
-        default_factory=dict,
-        description=(
-            "Additional parameters passed to backend methods. "
-            "Common: max_tokens, temperature, stream."
-        ),
+    request: int | None = Field(
+        default=None, description="Number of tokens in the original request."
     )
-    stats: dict[Literal["prompt_tokens"], int] = Field(
-        default_factory=dict,
-        description="Request statistics including prompt token count.",
-    )
-    constraints: dict[Literal["output_tokens"], int] = Field(
-        default_factory=dict,
-        description="Request constraints such as maximum output tokens.",
+    response: int | None = Field(
+        default=None, description="Number of tokens in the generated response."
    )
 
+    def value(
+        self, preference: Literal["request", "response"] | None = None
+    ) -> int | None:
+        if preference == "request":
+            return self.request
+        if preference == "response":
+            return self.response
+        return self.response if self.response is not None else self.request
+
 
 @SchedulerMessagingPydanticRegistry.register()
 class GenerationResponse(StandardBaseModel):
@@ -70,87 +59,32 @@ class GenerationResponse(StandardBaseModel):
     request_id: str = Field(
         description="Unique identifier matching the original GenerationRequest."
     )
-    request_args: dict[str, Any] = Field(
+    request_args: GenerationRequestArguments = Field(
         description="Arguments passed to the backend for this request."
     )
-    value: Optional[str] = Field(
+    text: str | None = Field(
         default=None,
-        description="Complete generated text content. None for streaming responses.",
-    )
-    delta: Optional[str] = Field(
-        default=None, description="Incremental text content for streaming responses."
+        description="The generated response text.",
     )
     iterations: int = Field(
         default=0, description="Number of generation iterations completed."
     )
-    request_prompt_tokens: Optional[int] = Field(
-        default=None, description="Token count from the original request prompt."
-    )
-    request_output_tokens: Optional[int] = Field(
-        default=None,
-        description="Expected output token count from the original request.",
-    )
-    response_prompt_tokens: Optional[int] = Field(
-        default=None, description="Actual prompt token count reported by the backend."
+
+    prompt_stats: GenerationTokenStats = Field(
+        default_factory=GenerationTokenStats,
+        description="Token statistics from the prompt.",
     )
-    response_output_tokens: Optional[int] = Field(
-        default=None, description="Actual output token count reported by the backend."
+    output_stats: GenerationTokenStats = Field(
+        default_factory=GenerationTokenStats,
+        description="Token statistics from the generated output.",
    )
 
-    @property
-    def prompt_tokens(self) -> Optional[int]:
-        """
-        :return: The number of prompt tokens used in the request
-            (response_prompt_tokens if available, otherwise request_prompt_tokens).
-        """
-        return self.response_prompt_tokens or self.request_prompt_tokens
-
-    @property
-    def output_tokens(self) -> Optional[int]:
-        """
-        :return: The number of output tokens generated in the response
-            (response_output_tokens if available, otherwise request_output_tokens).
-        """
-        return self.response_output_tokens or self.request_output_tokens
-
-    @property
-    def total_tokens(self) -> Optional[int]:
-        """
-        :return: The total number of tokens used in the request and response.
-            Sum of prompt_tokens and output_tokens.
-        """
-        if self.prompt_tokens is None or self.output_tokens is None:
-            return None
-        return self.prompt_tokens + self.output_tokens
-
-    def preferred_prompt_tokens(
-        self, preferred_source: Literal["request", "response"]
-    ) -> Optional[int]:
-        if preferred_source == "request":
-            return self.request_prompt_tokens or self.response_prompt_tokens
-        else:
-            return self.response_prompt_tokens or self.request_prompt_tokens
-
-    def preferred_output_tokens(
-        self, preferred_source: Literal["request", "response"]
-    ) -> Optional[int]:
-        if preferred_source == "request":
-            return self.request_output_tokens or self.response_output_tokens
-        else:
-            return self.response_output_tokens or self.request_output_tokens
-
-
-@SchedulerMessagingPydanticRegistry.register()
-@MeasuredRequestTimings.register("generation_request_timings")
-class GenerationRequestTimings(MeasuredRequestTimings):
-    """Timing model for tracking generation request lifecycle events."""
+    def total_tokens(
+        self, preference: Literal["request", "response"] | None = None
+    ) -> int | None:
+        prompt_tokens = self.prompt_stats.value(preference=preference)
+        output_tokens = self.output_stats.value(preference=preference)
 
-    timings_type: Literal["generation_request_timings"] = "generation_request_timings"
-    first_iteration: Optional[float] = Field(
-        default=None,
-        description="Unix timestamp when the first generation iteration began.",
-    )
-    last_iteration: Optional[float] = Field(
-        default=None,
-        description="Unix timestamp when the last generation iteration completed.",
-    )
+        if prompt_tokens is None and output_tokens is None:
+            return None
+        return (prompt_tokens or 0) + (output_tokens or 0)
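Note: a short illustration of the token accounting introduced above, assuming the GenerationTokenStats definition from this diff. The import path follows the new export added to guidellm.backends; the example token counts are made up.

from guidellm.backends import GenerationTokenStats

# Prompt tokens: the request claimed 512, the backend reported 509.
prompt = GenerationTokenStats(request=512, response=509)
# Output tokens: only the backend-reported count is known.
output = GenerationTokenStats(response=128)

# value() returns the preferred source verbatim when a preference is given...
assert prompt.value(preference="request") == 512
assert output.value(preference="request") is None  # no fallback with an explicit preference
# ...and prefers the response, falling back to the request, when it is not.
assert prompt.value() == 509
assert output.value() == 128

# GenerationResponse.total_tokens() applies the same resolution to prompt_stats
# and output_stats, counting a side that resolves to None as zero:
total = (prompt.value() or 0) + (output.value() or 0)  # 509 + 128 = 637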
