1
1
import json
2
2
import random
3
- from collections .abc import Iterable , Iterator
3
+ from collections .abc import Iterable , Iterator , Sequence
4
4
from itertools import cycle
5
5
from pathlib import Path
6
6
from typing import Any , Optional , TypedDict , Union
19
19
from guidellm .utils import EndlessTextCreator , IntegerRangeSampler , check_load_processor
20
20
21
21
__all__ = [
22
+ "PrefixBucketConfig" ,
22
23
"SyntheticDatasetConfig" ,
23
24
"SyntheticDatasetCreator" ,
24
25
"SyntheticTextItemsGenerator" ,
25
26
]
26
27
27
28
28
- class SyntheticDatasetConfig (BaseModel ):
29
class PrefixBucketConfig(BaseModel):
    """
    Configuration for one bucket of shared prompt prefixes.

    A bucket describes how many distinct prefixes to generate, how many tokens
    each of them has, and how heavily the bucket is weighted relative to the
    other buckets in the overall distribution.
    """

    # Relative weight of the bucket; must be strictly positive.
    bucket_weight: int = Field(
        description="Weight of this bucket in the overall distribution.",
        gt=0,
        default=100,
    )
    # Number of distinct prefixes generated for the bucket; at least one.
    prefix_count: int = Field(
        description="The number of unique prefixes to generate for this bucket.",
        ge=1,
        default=1,
    )
    # Token length of each prefix in the bucket; zero is allowed.
    prefix_tokens: int = Field(
        description="The number of prefix tokens per-prompt for this bucket.",
        ge=0,
        default=0,
    )
45
+
46
+
47
+ class SyntheticDatasetConfig (BaseModel ):
48
+ prefix_buckets : Optional [list [PrefixBucketConfig ]] = Field (
49
+ description = "Buckets for the prefix tokens distribution." ,
50
+ default = None ,
51
+ )
34
52
prompt_tokens : int = Field (
35
53
description = "The average number of text tokens generated for prompts." ,
36
54
gt = 0 ,
@@ -190,11 +208,9 @@ def __iter__(
190
208
)
191
209
# ensure diff distribution from output tokens
192
210
rand = random .Random (self .random_seed + 2 ) # noqa: S311
211
+ shared_prefix_iter = iter (self ._create_prefixes (rand ))
193
212
unique_prefix_iter = cycle (self .processor .get_vocab ().values ())
194
213
195
- prefix_index = rand .randint (0 , len (self .text_creator .words ))
196
- prefix_tokens = self ._create_prompt (self .config .prefix_tokens , prefix_index )
197
-
198
214
for _ , turns in zip (range (self .config .samples ), turns_sampler ):
199
215
row : SyntheticDatasetRow = {
200
216
"prompt" : [],
@@ -207,6 +223,7 @@ def __iter__(
207
223
output_tokens_sampler ,
208
224
):
209
225
start_index = rand .randint (0 , len (self .text_creator .words ))
226
+ prefix_tokens = next (shared_prefix_iter , [])
210
227
# Append the prefix tokens only for the first turn
211
228
if i == 0 :
212
229
prompt_text = self .processor .decode (
@@ -217,7 +234,7 @@ def __iter__(
217
234
skip_special_tokens = True ,
218
235
)
219
236
row ["prompt" ].append (prompt_text )
220
- row ["prompt_tokens_count" ].append (self . config . prefix_tokens + prompt_tokens )
237
+ row ["prompt_tokens_count" ].append (len ( prefix_tokens ) + prompt_tokens )
221
238
row ["output_tokens_count" ].append (output_tokens )
222
239
else :
223
240
prompt_text = self .processor .decode (
@@ -232,6 +249,36 @@ def __iter__(
232
249
233
250
yield row
234
251
252
+ def _rand_start_index (self , rand : random .Random ) -> int :
253
+ """Generate a random start index for text generation."""
254
+ return rand .randint (0 , len (self .text_creator .words ) - 1 )
255
+
256
+ def _create_prefixes (self , rand : random .Random ) -> Sequence [list [int ]]:
257
+ """Create an iterator for shared prefix tokens."""
258
+ buckets = self .config .prefix_buckets
259
+
260
+ if not buckets :
261
+ return []
262
+
263
+ total_weight = sum (bucket .bucket_weight for bucket in buckets )
264
+ if total_weight <= 0 :
265
+ raise ValueError ("Total weight of prefix buckets must be greater than 0." )
266
+
267
+ prompts = []
268
+ for bucket in buckets :
269
+ for _ in range (bucket .prefix_count ):
270
+ start_index = self ._rand_start_index (rand )
271
+ prompt_tokens = self ._create_prompt (bucket .prefix_tokens , start_index )
272
+ sample_percent = (
273
+ bucket .bucket_weight / bucket .prefix_count / total_weight
274
+ )
275
+ sample_count = sample_percent * self .config .samples
276
+ for _ in range (int (round (sample_count ))):
277
+ prompts .append (prompt_tokens )
278
+
279
+ rand .shuffle (prompts )
280
+ return prompts
281
+
235
282
def _create_prompt (
236
283
self , prompt_tokens : int , start_index : int , unique_prefix : Optional [int ] = None
237
284
) -> list [int ]:
0 commit comments