Commit ed95373

Add advanced shared prefix support

1 parent da29a71 commit ed95373

File tree

1 file changed: +54, -8 lines


src/guidellm/dataset/synthetic.py

Lines changed: 54 additions & 8 deletions
@@ -1,6 +1,6 @@
 import json
 import random
-from collections.abc import Iterable, Iterator
+from collections.abc import Iterable, Iterator, Sequence
 from itertools import cycle
 from pathlib import Path
 from typing import Any, Literal, Optional, Union
@@ -25,12 +25,29 @@
 ]


-class SyntheticDatasetConfig(BaseModel):
+class PrefixBucketConfig(BaseModel):
+    bucket_weight: int = Field(
+        description="Weight of this bucket in the overall distribution.",
+        gt=0,
+        default=100,
+    )
+    prefix_count: int = Field(
+        description="The number of unique prefixes to generate for this bucket.",
+        ge=1,
+        default=1,
+    )
     prefix_tokens: int = Field(
-        description="The number of shared prefix tokens to prepend to each prompt.",
+        description="The number of prefix tokens per-prompt for this bucket.",
         ge=0,
         default=0,
     )
+
+
+class SyntheticDatasetConfig(BaseModel):
+    prefix_buckets: Optional[list[PrefixBucketConfig]] = Field(
+        description="Buckets for the prefix tokens distribution.",
+        default=None,
+    )
     prompt_tokens: int = Field(
         description="The average number of text tokens generated for prompts.",
         gt=0,
@@ -169,17 +186,16 @@ def __iter__(
         )
         # ensure diff distribution from output tokens
         rand = random.Random(self.random_seed + 2)  # noqa: S311
+        shared_prefix_iter = iter(self._create_prefixes(rand))
         unique_prefix_iter = cycle(self.processor.get_vocab().values())

-        prefix_index = rand.randint(0, len(self.text_creator.words))
-        prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index)
-
         for _, prompt_tokens, output_tokens in zip(
             range(self.config.samples),
             prompt_tokens_sampler,
             output_tokens_sampler,
         ):
-            start_index = rand.randint(0, len(self.text_creator.words))
+            start_index = self._rand_start_index(rand)
+            prefix_tokens = next(shared_prefix_iter, [])
             prompt_text = self.processor.decode(
                 prefix_tokens
                 + self._create_prompt(
@@ -189,10 +205,40 @@ def __iter__(
             )
             yield {
                 "prompt": prompt_text,
-                "prompt_tokens_count": self.config.prefix_tokens + prompt_tokens,
+                "prompt_tokens_count": len(prefix_tokens) + prompt_tokens,
                 "output_tokens_count": output_tokens,
             }

+    def _rand_start_index(self, rand: random.Random) -> int:
+        """Generate a random start index for text generation."""
+        return rand.randint(0, len(self.text_creator.words) - 1)
+
+    def _create_prefixes(self, rand: random.Random) -> Sequence[list[int]]:
+        """Create an iterator for shared prefix tokens."""
+        buckets = self.config.prefix_buckets
+
+        if not buckets:
+            return []
+
+        total_weight = sum(bucket.bucket_weight for bucket in buckets)
+        if total_weight <= 0:
+            raise ValueError("Total weight of prefix buckets must be greater than 0.")
+
+        prompts = []
+        for bucket in buckets:
+            for _ in range(bucket.prefix_count):
+                start_index = self._rand_start_index(rand)
+                prompt_tokens = self._create_prompt(bucket.prefix_tokens, start_index)
+                sample_percent = (
+                    bucket.bucket_weight / bucket.prefix_count / total_weight
+                )
+                sample_count = sample_percent * self.config.samples
+                for _ in range(int(round(sample_count))):
+                    prompts.append(prompt_tokens)
+
+        rand.shuffle(prompts)
+        return prompts
+
     def _create_prompt(
         self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None
     ) -> list[int]:
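
For context, a minimal usage sketch of the new prefix buckets. The field values, the two-bucket split, and the import path below are illustrative assumptions, not part of this commit; the sketch only relies on the PrefixBucketConfig and SyntheticDatasetConfig fields added in the diff above plus the existing samples and output_tokens settings.

    # Hypothetical config: values chosen for illustration only.
    from guidellm.dataset.synthetic import PrefixBucketConfig, SyntheticDatasetConfig

    config = SyntheticDatasetConfig(
        prompt_tokens=256,
        output_tokens=128,  # assumed existing field, not shown in this diff
        samples=1000,       # referenced in the diff as self.config.samples
        prefix_buckets=[
            # ~75% of samples share a single long 2048-token prefix.
            PrefixBucketConfig(bucket_weight=75, prefix_count=1, prefix_tokens=2048),
            # ~25% of samples are spread across five distinct 128-token prefixes.
            PrefixBucketConfig(bucket_weight=25, prefix_count=5, prefix_tokens=128),
        ],
    )

With the weighting in _create_prefixes, total_weight is 100, so the first bucket yields round(75 / 1 / 100 * 1000) = 750 copies of its single prefix and the second yields round(25 / 5 / 100 * 1000) = 50 copies of each of its five prefixes; any shortfall left by rounding falls back to an empty prefix via next(shared_prefix_iter, []).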

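A small standalone helper, again only a sketch of the same arithmetic (the function name is invented here), can be handy for sanity-checking how many prompt samples each bucket will receive before running a benchmark:

    def expected_samples_per_bucket(
        buckets: list[tuple[int, int]], samples: int
    ) -> list[int]:
        """Mirror the commit's allocation: each (bucket_weight, prefix_count) pair
        gets round(bucket_weight / prefix_count / total_weight * samples) samples
        per unique prefix."""
        total_weight = sum(weight for weight, _ in buckets)
        return [
            prefix_count * round(weight / prefix_count / total_weight * samples)
            for weight, prefix_count in buckets
        ]

    # For the two buckets above: [750, 250], summing to the 1000 requested samples.
    print(expected_samples_per_bucket([(75, 1), (25, 5)], 1000))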