11import json
22import random
3- from collections .abc import Iterable , Iterator
3+ from collections .abc import Iterable , Iterator , Sequence
44from itertools import cycle
55from pathlib import Path
66from typing import Any , Literal , Optional , Union
2525]
2626
2727
28- class SyntheticDatasetConfig (BaseModel ):
28+ class PrefixBucketConfig (BaseModel ):
29+ bucket_weight : int = Field (
30+ description = "Weight of this bucket in the overall distribution." ,
31+ gt = 0 ,
32+ default = 100 ,
33+ )
34+ prefix_count : int = Field (
35+ description = "The number of unique prefixs to generate for this bucket." ,
36+ gt = 1 ,
37+ default = 1 ,
38+ )
2939 prefix_tokens : int = Field (
30- description = "The number of shared prefix tokens to prepend to each prompt ." ,
40+ description = "The number of prefix tokens per-prompt for this bucket ." ,
3141 ge = 0 ,
3242 default = 0 ,
3343 )
44+
45+
46+ class SyntheticDatasetConfig (BaseModel ):
47+ prefix_buckets : Optional [list [PrefixBucketConfig ]] = Field (
48+ description = "Buckets for the prefix tokens distribution." ,
49+ default = None ,
50+ )
3451 prompt_tokens : int = Field (
3552 description = "The average number of text tokens generated for prompts." ,
3653 gt = 0 ,
@@ -169,17 +186,16 @@ def __iter__(
169186 )
170187 # ensure diff distribution from output tokens
171188 rand = random .Random (self .random_seed + 2 ) # noqa: S311
189+ shared_prefix_iter = iter (self ._create_prefixes (rand ))
172190 unique_prefix_iter = cycle (self .processor .get_vocab ().values ())
173191
174- prefix_index = rand .randint (0 , len (self .text_creator .words ))
175- prefix_tokens = self ._create_prompt (self .config .prefix_tokens , prefix_index )
176-
177192 for _ , prompt_tokens , output_tokens in zip (
178193 range (self .config .samples ),
179194 prompt_tokens_sampler ,
180195 output_tokens_sampler ,
181196 ):
182- start_index = rand .randint (0 , len (self .text_creator .words ))
197+ start_index = self ._rand_start_index (rand )
198+ prefix_tokens = next (shared_prefix_iter , [])
183199 prompt_text = self .processor .decode (
184200 prefix_tokens
185201 + self ._create_prompt (
@@ -189,10 +205,40 @@ def __iter__(
189205 )
190206 yield {
191207 "prompt" : prompt_text ,
192- "prompt_tokens_count" : self . config . prefix_tokens + prompt_tokens ,
208+ "prompt_tokens_count" : len ( prefix_tokens ) + prompt_tokens ,
193209 "output_tokens_count" : output_tokens ,
194210 }
195211
212+ def _rand_start_index (self , rand : random .Random ) -> int :
213+ """Generate a random start index for text generation."""
214+ return rand .randint (0 , len (self .text_creator .words ) - 1 )
215+
216+ def _create_prefixes (self , rand : random .Random ) -> Sequence [list [int ]]:
217+ """Create an iterator for shared prefix tokens."""
218+ buckets = self .config .prefix_buckets
219+
220+ if not buckets :
221+ return []
222+
223+ total_weight = sum (bucket .bucket_weight for bucket in buckets )
224+ if total_weight <= 0 :
225+ raise ValueError ("Total weight of prefix buckets must be greater than 0." )
226+
227+ prompts = []
228+ for bucket in buckets :
229+ for _ in range (bucket .prefix_count ):
230+ start_index = self ._rand_start_index (rand )
231+ prompt_tokens = self ._create_prompt (bucket .prefix_tokens , start_index )
232+ sample_percent = (
233+ bucket .bucket_weight / bucket .prefix_count / total_weight
234+ )
235+ sample_count = sample_percent * self .config .samples
236+ for _ in range (int (round (sample_count ))):
237+ prompts .append (prompt_tokens )
238+
239+ rand .shuffle (prompts )
240+ return prompts
241+
196242 def _create_prompt (
197243 self , prompt_tokens : int , start_index : int , unique_prefix : Optional [int ] = None
198244 ) -> list [int ]:
0 commit comments