Skip to content

Commit 81d8ed3

Browse files
committed
Add turns support to synthetic dataset
1 parent cc9d26f commit 81d8ed3

File tree

1 file changed

+52
-25
lines changed

1 file changed

+52
-25
lines changed

src/guidellm/dataset/synthetic.py

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import random
33
from collections.abc import Iterable, Iterator
44
from pathlib import Path
5-
from typing import Any, Literal, Optional, Union
5+
from typing import Any, Optional, TypedDict, Union
66

77
import yaml
88
from datasets import (
@@ -63,6 +63,26 @@ class SyntheticDatasetConfig(BaseModel):
6363
gt=0,
6464
default=None,
6565
)
66+
turns: int = Field(
67+
description="The number of turns in the conversation.",
68+
gt=0,
69+
default=1,
70+
)
71+
turns_stdev: Optional[int] = Field(
72+
description="The standard deviation of the number of turns.",
73+
gt=0,
74+
default=None,
75+
)
76+
turns_min: Optional[int] = Field(
77+
description="The minimum number of turns in the conversation.",
78+
gt=0,
79+
default=None,
80+
)
81+
turns_max: Optional[int] = Field(
82+
description="The maximum number of turns in the conversation.",
83+
gt=0,
84+
default=None,
85+
)
6686
samples: int = Field(
6787
description="The number of samples to generate for the dataset.",
6888
gt=0,
@@ -118,14 +138,13 @@ def parse_config_file(data: Union[str, Path]) -> "SyntheticDatasetConfig":
118138
return SyntheticDatasetConfig(**config_dict)
119139

120140

121-
class SyntheticTextItemsGenerator(
122-
Iterable[
123-
dict[
124-
Literal["prompt", "prompt_tokens_count", "output_tokens_count"],
125-
Union[str, int],
126-
]
127-
]
128-
):
141+
class SyntheticDatasetRow(TypedDict):
142+
prompt: list[str]
143+
prompt_tokens_count: list[int]
144+
output_tokens_count: list[int]
145+
146+
147+
class SyntheticTextItemsGenerator(Iterable[SyntheticDatasetRow]):
129148
def __init__(
130149
self,
131150
config: SyntheticDatasetConfig,
@@ -141,12 +160,7 @@ def __init__(
141160

142161
def __iter__(
143162
self,
144-
) -> Iterator[
145-
dict[
146-
Literal["prompt", "prompt_tokens_count", "output_tokens_count"],
147-
Union[str, int],
148-
]
149-
]:
163+
) -> Iterator[SyntheticDatasetRow]:
150164
prompt_tokens_sampler = IntegerRangeSampler(
151165
average=self.config.prompt_tokens,
152166
variance=self.config.prompt_tokens_stdev,
@@ -161,20 +175,33 @@ def __iter__(
161175
max_value=self.config.output_tokens_max,
162176
random_seed=self.random_seed + 1, # ensure diff dist from prompts
163177
)
178+
turns_sampler = IntegerRangeSampler(
179+
average=self.config.turns,
180+
variance=self.config.turns_stdev,
181+
min_value=self.config.turns_min,
182+
max_value=self.config.turns_max,
183+
random_seed=self.random_seed + 7, # ensure diff dist
184+
)
164185
# ensure diff distribution from output tokens
165186
rand = random.Random(self.random_seed + 2) # noqa: S311
166187

167-
for _, prompt_tokens, output_tokens in zip(
168-
range(self.config.samples),
169-
prompt_tokens_sampler,
170-
output_tokens_sampler,
171-
):
172-
start_index = rand.randint(0, len(self.text_creator.words))
173-
yield {
174-
"prompt": self._create_prompt(prompt_tokens, start_index),
175-
"prompt_tokens_count": prompt_tokens,
176-
"output_tokens_count": output_tokens,
188+
for _, turns in zip(range(self.config.samples), turns_sampler):
189+
row: SyntheticDatasetRow = {
190+
"prompt": [],
191+
"prompt_tokens_count": [],
192+
"output_tokens_count": [],
177193
}
194+
for _, prompt_tokens, output_tokens in zip(
195+
range(turns),
196+
prompt_tokens_sampler,
197+
output_tokens_sampler,
198+
):
199+
start_index = rand.randint(0, len(self.text_creator.words))
200+
row["prompt"].append(self._create_prompt(prompt_tokens, start_index))
201+
row["prompt_tokens_count"].append(prompt_tokens)
202+
row["output_tokens_count"].append(output_tokens)
203+
204+
yield row
178205

179206
def _create_prompt(self, prompt_tokens: int, start_index: int) -> str:
180207
if prompt_tokens <= 0:

0 commit comments

Comments
 (0)