
Commit eade3a2

sjmonson and MML-coder committed
Add fixed prefix option to synthetic data

Signed-off-by: Samuel Monson <[email protected]>

Add prefix before decode

Signed-off-by: Samuel Monson <[email protected]>

Add unique single-token prefix to every request

Co-authored-by: Mehul <[email protected]>
Co-authored-by: Samuel Monson <[email protected]>
Signed-off-by: Samuel Monson <[email protected]>
1 parent cd43b2c commit eade3a2

File tree

2 files changed: +56 -7 lines changed
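
As context for the change, here is a minimal sketch of how the new prefix buckets might be configured once this commit lands. PrefixBucketConfig, prefix_buckets, bucket_weight, prefix_count, prefix_tokens, and prompt_tokens all appear in the diffs below; the output_tokens and samples values are assumptions used only to keep the example self-contained.

from guidellm.dataset import PrefixBucketConfig, SyntheticDatasetConfig

# Hypothetical configuration (not part of this commit): roughly 75% of requests
# reuse one of four 32-token prefixes, the rest share a single 128-token prefix.
config = SyntheticDatasetConfig(
    prefix_buckets=[
        PrefixBucketConfig(bucket_weight=300, prefix_count=4, prefix_tokens=32),
        PrefixBucketConfig(bucket_weight=100, prefix_count=1, prefix_tokens=128),
    ],
    prompt_tokens=256,   # field shown in the diff below
    output_tokens=128,   # assumed field, not shown in this diff
    samples=1000,        # referenced as self.config.samples in the diff
)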

src/guidellm/dataset/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,7 @@
 from .hf_datasets import HFDatasetsCreator
 from .in_memory import InMemoryDatasetCreator
 from .synthetic import (
+    PrefixBucketConfig,
     SyntheticDatasetConfig,
     SyntheticDatasetCreator,
     SyntheticTextItemsGenerator,
@@ -15,6 +16,7 @@
     "FileDatasetCreator",
     "HFDatasetsCreator",
     "InMemoryDatasetCreator",
+    "PrefixBucketConfig",
     "SyntheticDatasetConfig",
     "SyntheticDatasetCreator",
     "SyntheticTextItemsGenerator",

src/guidellm/dataset/synthetic.py

Lines changed: 54 additions & 7 deletions

@@ -1,6 +1,6 @@
 import json
 import random
-from collections.abc import Iterable, Iterator
+from collections.abc import Iterable, Iterator, Sequence
 from itertools import cycle
 from pathlib import Path
 from typing import Any, Optional, TypedDict, Union
@@ -19,18 +19,36 @@
 from guidellm.utils import EndlessTextCreator, IntegerRangeSampler, check_load_processor

 __all__ = [
+    "PrefixBucketConfig",
     "SyntheticDatasetConfig",
     "SyntheticDatasetCreator",
     "SyntheticTextItemsGenerator",
 ]


-class SyntheticDatasetConfig(BaseModel):
+class PrefixBucketConfig(BaseModel):
+    bucket_weight: int = Field(
+        description="Weight of this bucket in the overall distribution.",
+        gt=0,
+        default=100,
+    )
+    prefix_count: int = Field(
+        description="The number of unique prefixes to generate for this bucket.",
+        ge=1,
+        default=1,
+    )
     prefix_tokens: int = Field(
-        description="The number of shared prefix tokens to prepend to each prompt.",
+        description="The number of prefix tokens per-prompt for this bucket.",
         ge=0,
         default=0,
     )
+
+
+class SyntheticDatasetConfig(BaseModel):
+    prefix_buckets: Optional[list[PrefixBucketConfig]] = Field(
+        description="Buckets for the prefix tokens distribution.",
+        default=None,
+    )
     prompt_tokens: int = Field(
         description="The average number of text tokens generated for prompts.",
         gt=0,
@@ -190,11 +208,9 @@ def __iter__(
         )
         # ensure diff distribution from output tokens
         rand = random.Random(self.random_seed + 2)  # noqa: S311
+        shared_prefix_iter = iter(self._create_prefixes(rand))
         unique_prefix_iter = cycle(self.processor.get_vocab().values())

-        prefix_index = rand.randint(0, len(self.text_creator.words))
-        prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index)
-
         for _, turns in zip(range(self.config.samples), turns_sampler):
             row: SyntheticDatasetRow = {
                 "prompt": [],
@@ -207,6 +223,7 @@ def __iter__(
                 output_tokens_sampler,
             ):
                 start_index = rand.randint(0, len(self.text_creator.words))
+                prefix_tokens = next(shared_prefix_iter, [])
                 # Append the prefix tokens only for the first turn
                 if i == 0:
                     prompt_text = self.processor.decode(
@@ -217,7 +234,7 @@ def __iter__(
                         skip_special_tokens=True,
                     )
                     row["prompt"].append(prompt_text)
-                    row["prompt_tokens_count"].append(self.config.prefix_tokens + prompt_tokens)
+                    row["prompt_tokens_count"].append(len(prefix_tokens) + prompt_tokens)
                     row["output_tokens_count"].append(output_tokens)
                 else:
                     prompt_text = self.processor.decode(
@@ -232,6 +249,36 @@ def __iter__(

         yield row

+    def _rand_start_index(self, rand: random.Random) -> int:
+        """Generate a random start index for text generation."""
+        return rand.randint(0, len(self.text_creator.words) - 1)
+
+    def _create_prefixes(self, rand: random.Random) -> Sequence[list[int]]:
+        """Create an iterator for shared prefix tokens."""
+        buckets = self.config.prefix_buckets
+
+        if not buckets:
+            return []
+
+        total_weight = sum(bucket.bucket_weight for bucket in buckets)
+        if total_weight <= 0:
+            raise ValueError("Total weight of prefix buckets must be greater than 0.")
+
+        prompts = []
+        for bucket in buckets:
+            for _ in range(bucket.prefix_count):
+                start_index = self._rand_start_index(rand)
+                prompt_tokens = self._create_prompt(bucket.prefix_tokens, start_index)
+                sample_percent = (
+                    bucket.bucket_weight / bucket.prefix_count / total_weight
+                )
+                sample_count = sample_percent * self.config.samples
+                for _ in range(int(round(sample_count))):
+                    prompts.append(prompt_tokens)
+
+        rand.shuffle(prompts)
+        return prompts
+
     def _create_prompt(
         self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None
     ) -> list[int]:
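
To make the allocation in _create_prefixes concrete, here is a standalone sketch with made-up bucket values (the numbers are assumptions, not from the commit): each prefix in a bucket is repeated about bucket_weight / prefix_count / total_weight * samples times, and the shuffled pool is then consumed one entry per request, with next(shared_prefix_iter, []) falling back to an empty prefix once the pool runs out.

# Illustrative allocation math only; mirrors the arithmetic in _create_prefixes.
samples = 1000
buckets = [
    {"bucket_weight": 300, "prefix_count": 4},  # four distinct prefixes
    {"bucket_weight": 100, "prefix_count": 1},  # one shared prefix
]
total_weight = sum(b["bucket_weight"] for b in buckets)  # 400

for b in buckets:
    per_prefix = b["bucket_weight"] / b["prefix_count"] / total_weight * samples
    print(int(round(per_prefix)))
# prints 188 (each of the four prefixes covers ~188 of the 1000 requests)
# prints 250 (the single prefix in the second bucket covers ~250 requests)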
