import json
import random
-from collections.abc import Iterable, Iterator
+from collections.abc import Iterable, Iterator, Sequence
from itertools import cycle
from pathlib import Path
from typing import Any, Literal, Optional, Union
from guidellm.utils import EndlessTextCreator, IntegerRangeSampler, check_load_processor

__all__ = [
+    "PrefixBucketConfig",
    "SyntheticDatasetConfig",
    "SyntheticDatasetCreator",
    "SyntheticTextItemsGenerator",
]


-class SyntheticDatasetConfig(BaseModel):
+class PrefixBucketConfig(BaseModel):
+    bucket_weight: int = Field(
+        description="Weight of this bucket in the overall distribution.",
+        gt=0,
+        default=100,
+    )
+    prefix_count: int = Field(
+        description="The number of unique prefixes to generate for this bucket.",
+        ge=1,
+        default=1,
+    )
    prefix_tokens: int = Field(
-        description="The number of shared prefix tokens to prepend to each prompt.",
+        description="The number of prefix tokens per prompt for this bucket.",
        ge=0,
        default=0,
    )
+
+
+class SyntheticDatasetConfig(BaseModel):
+    prefix_buckets: Optional[list[PrefixBucketConfig]] = Field(
+        description="Buckets for the prefix tokens distribution.",
+        default=None,
+    )
    prompt_tokens: int = Field(
        description="The average number of text tokens generated for prompts.",
        gt=0,
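
As a reading aid (not part of the diff), a minimal sketch of how the new bucket config might be populated. The weights and token counts below are illustrative assumptions, and any other required SyntheticDatasetConfig fields are omitted:

# Hypothetical usage: two prefix buckets splitting the prompt traffic 80/20.
buckets = [
    PrefixBucketConfig(bucket_weight=80, prefix_count=4, prefix_tokens=256),
    PrefixBucketConfig(bucket_weight=20, prefix_count=1, prefix_tokens=32),
]
# These would then be passed as SyntheticDatasetConfig(prefix_buckets=buckets, ...)
# alongside the existing prompt/output token and sample-count settings.
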
@@ -169,17 +187,16 @@ def __iter__(
        )
        # ensure diff distribution from output tokens
        rand = random.Random(self.random_seed + 2)  # noqa: S311
+        shared_prefix_iter = iter(self._create_prefixes(rand))
        unique_prefix_iter = cycle(self.processor.get_vocab().values())

-        prefix_index = rand.randint(0, len(self.text_creator.words))
-        prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index)
-
        for _, prompt_tokens, output_tokens in zip(
            range(self.config.samples),
            prompt_tokens_sampler,
            output_tokens_sampler,
        ):
-            start_index = rand.randint(0, len(self.text_creator.words))
+            start_index = self._rand_start_index(rand)
+            prefix_tokens = next(shared_prefix_iter, [])
            prompt_text = self.processor.decode(
                prefix_tokens
                + self._create_prompt(
@@ -189,10 +206,40 @@ def __iter__(
            )
            yield {
                "prompt": prompt_text,
-                "prompt_tokens_count": self.config.prefix_tokens + prompt_tokens,
+                "prompt_tokens_count": len(prefix_tokens) + prompt_tokens,
                "output_tokens_count": output_tokens,
            }

+    def _rand_start_index(self, rand: random.Random) -> int:
+        """Generate a random start index for text generation."""
+        return rand.randint(0, len(self.text_creator.words) - 1)
+
+    def _create_prefixes(self, rand: random.Random) -> Sequence[list[int]]:
+        """Create the pool of shared prefix token sequences for the samples."""
+        buckets = self.config.prefix_buckets
+
+        if not buckets:
+            return []
+
+        total_weight = sum(bucket.bucket_weight for bucket in buckets)
+        if total_weight <= 0:
+            raise ValueError("Total weight of prefix buckets must be greater than 0.")
+
+        # Each prefix repeats in proportion to its bucket's share of the total weight.
+        prompts = []
+        for bucket in buckets:
+            for _ in range(bucket.prefix_count):
+                start_index = self._rand_start_index(rand)
+                prompt_tokens = self._create_prompt(bucket.prefix_tokens, start_index)
+                sample_percent = (
+                    bucket.bucket_weight / bucket.prefix_count / total_weight
+                )
+                sample_count = sample_percent * self.config.samples
+                for _ in range(int(round(sample_count))):
+                    prompts.append(prompt_tokens)
+
+        rand.shuffle(prompts)
+        return prompts
+
    def _create_prompt(
        self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None
    ) -> list[int]:
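
To make the allocation in _create_prefixes concrete, here is a small standalone sketch of the same arithmetic. The bucket values and the 300-sample figure are assumptions chosen for illustration, not values from this PR:

# Each prefix is repeated round(weight / prefix_count / total_weight * samples) times.
samples = 300
buckets = [(80, 4), (20, 1)]  # (bucket_weight, prefix_count)
total_weight = sum(weight for weight, _ in buckets)
for weight, count in buckets:
    per_prefix = round(weight / count / total_weight * samples)
    print(weight, count, per_prefix)  # -> 80 4 60 and 20 1 60

So the heavier bucket covers 4 * 60 = 240 of the 300 samples and the lighter one the remaining 60, before the combined list is shuffled.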