66from random import Random
77from typing import Any
88
9+ import numpy as np
910import yaml
10- from datasets import Features , IterableDataset , Value
11+ from datasets import DatasetInfo , Features , IterableDataset , Value
12+ from datasets .iterable_dataset import _BaseExamplesIterable
1113from faker import Faker
1214from pydantic import ConfigDict , Field , ValidationError , model_validator
1315from transformers import PreTrainedTokenizerBase
2123from guidellm .utils import IntegerRangeSampler
2224
2325__all__ = [
26+ "SyntheticTextDataset" ,
2427 "SyntheticTextDatasetConfig" ,
2528 "SyntheticTextDatasetDeserializer" ,
26- "SyntheticTextGenerator" ,
2729 "SyntheticTextPrefixBucketConfig" ,
2830]
2931
@@ -121,29 +123,34 @@ def check_prefix_options(self) -> SyntheticTextDatasetConfig:
121123 return self
122124
123125
class _SyntheticTextExamplesIterable(_BaseExamplesIterable):
    """Custom examples iterable for synthetic text generation."""

    def __init__(
        self,
        config: SyntheticTextDatasetConfig,
        processor: PreTrainedTokenizerBase,
        random_seed: int,
    ):
        """Store the generation config, tokenizer, and base RNG seed.

        :param config: Sampling parameters for prompt/output token counts.
        :param processor: Tokenizer used by prompt construction helpers.
        :param random_seed: Base seed; each pass derives its own seed from it.
        """
        super().__init__()
        self.config = config
        self.processor = processor
        self.random_seed = random_seed
        # Bumped once per __iter__ call so every pass (epoch) draws from a
        # distinct, deterministic seed derived from random_seed.
        self.iteration_count = 0
135- def __iter__ (self ) -> Iterator [dict [str , Any ]]:
136- samples_generated = 0
141+ def __iter__ (self ) -> Iterator [tuple [int , dict [str , Any ]]]:
142+ iter_random_seed = self .random_seed + self .iteration_count
143+ self .iteration_count += 1
137144
138145 faker = Faker ()
139- faker .seed_instance (self . random_seed )
146+ faker .seed_instance (iter_random_seed )
140147 prompt_tokens_sampler = iter (
141148 IntegerRangeSampler (
142149 average = self .config .prompt_tokens ,
143150 variance = self .config .prompt_tokens_stdev ,
144151 min_value = self .config .prompt_tokens_min ,
145152 max_value = self .config .prompt_tokens_max ,
146- random_seed = self . random_seed ,
153+ random_seed = iter_random_seed ,
147154 )
148155 )
149156 output_tokens_sampler = iter (
@@ -152,27 +159,77 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
152159 variance = self .config .output_tokens_stdev ,
153160 min_value = self .config .output_tokens_min ,
154161 max_value = self .config .output_tokens_max ,
155- random_seed = self . random_seed + 1 , # ensure diff dist from prompts
162+ random_seed = iter_random_seed + 1 , # ensure diff dist from prompts
156163 )
157164 )
158165
159166 # Create a shared prefix if specified
160- rand = Random (self . random_seed + 3 )
167+ rand = Random (iter_random_seed + 3 )
161168 prefix_iter = self ._create_prefix_iter (faker , rand )
169+ samples_count = 0
162170
163171 while True :
164172 prompt_tokens_count = next (prompt_tokens_sampler )
165173 output_tokens_count = next (output_tokens_sampler )
166174
167- yield {
168- "prefix" : next (prefix_iter ),
169- "prompt" : self ._create_prompt (
170- prompt_tokens_count , faker , f"{ samples_generated } "
171- ),
172- "prompt_tokens_count" : prompt_tokens_count ,
173- "output_tokens_count" : output_tokens_count ,
175+ yield (
176+ samples_count ,
177+ {
178+ "prefix" : next (prefix_iter ),
179+ "prompt" : self ._create_prompt (
180+ prompt_tokens_count ,
181+ faker ,
182+ f"{ self .iteration_count } { samples_count } " ,
183+ ),
184+ "prompt_tokens_count" : prompt_tokens_count ,
185+ "output_tokens_count" : output_tokens_count ,
186+ },
187+ )
188+ samples_count += 1
189+
190+ @property
191+ def is_typed (self ) -> bool :
192+ return True
193+
194+ @property
195+ def features (self ) -> Features :
196+ return Features (
197+ {
198+ "prefix" : Value ("string" ),
199+ "prompt" : Value ("string" ),
200+ "prompt_tokens_count" : Value ("int32" ),
201+ "output_tokens_count" : Value ("int32" ),
174202 }
175- samples_generated += 1
203+ )
204+
205+ @property
206+ def num_shards (self ) -> int :
207+ return 1
208+
209+ def shuffle_data_sources (
210+ self ,
211+ generator : np .random .Generator , # noqa: ARG002
212+ ) -> _SyntheticTextExamplesIterable :
213+ """Return self since synthetic data doesn't have fixed sources to shuffle."""
214+ return self
215+
216+ def shard_data_sources (
217+ self ,
218+ num_shards : int , # noqa: ARG002
219+ index : int , # noqa: ARG002
220+ contiguous : bool = True , # noqa: ARG002
221+ ) -> _SyntheticTextExamplesIterable :
222+ """Return self since synthetic data generation is infinite and stateless."""
223+ return self
224+
225+ def load_state_dict (self , state_dict : dict ) -> None :
226+ """Load the state from a state dict."""
227+ self .iteration_count = state_dict .get ("iteration_count" , 0 )
228+
229+ def _init_state_dict (self ) -> dict :
230+ """Initialize the state dict for the iterable."""
231+ self ._state_dict = {"iteration_count" : self .iteration_count }
232+ return self ._state_dict
176233
177234 def _create_prompt (
178235 self , prompt_tokens_count : int , faker : Faker , unique : str = ""
@@ -226,6 +283,39 @@ def _create_prefix_iter(self, faker: Faker, rand: Random) -> Iterator[str]:
226283 yield rand .choice (prefixes )
227284
228285
class SyntheticTextDataset(IterableDataset):
    """Hugging Face ``IterableDataset`` backed by synthetic text generation.

    Wraps ``_SyntheticTextExamplesIterable`` so the infinite generator plugs
    into the standard ``datasets`` streaming machinery (features, epochs).
    """

    def __init__(
        self,
        config: SyntheticTextDatasetConfig,
        processor: PreTrainedTokenizerBase,
        random_seed: int = 42,
    ):
        """Build the dataset around a seeded synthetic examples iterable.

        :param config: Sampling parameters for prompt/output token counts.
        :param processor: Tokenizer forwarded to the examples iterable.
        :param random_seed: Base seed for reproducible generation.
        """
        self.config = config
        self.processor = processor
        self.random_seed = random_seed

        # Create the examples iterable
        ex_iterable = _SyntheticTextExamplesIterable(
            config=config,
            processor=processor,
            random_seed=random_seed,
        )

        # Initialize parent with proper ex_iterable
        super().__init__(
            ex_iterable=ex_iterable,
            info=DatasetInfo(
                description="Synthetic text dataset generator",
                features=ex_iterable.features,
            ),
        )

    def set_epoch(self, epoch: int):
        """Set the epoch for the dataset iteration."""
        # The epoch becomes the iterable's pass counter, which offsets the
        # derived seed so each epoch yields a different stream.
        if isinstance(self._ex_iterable, _SyntheticTextExamplesIterable):
            self._ex_iterable.iteration_count = epoch
318+
229319@DatasetDeserializerFactory .register ("synthetic_text" )
230320class SyntheticTextDatasetDeserializer (DatasetDeserializer ):
231321 def __call__ (
@@ -254,21 +344,10 @@ def __call__(
254344 f"got { data } "
255345 )
256346
257- return IterableDataset .from_generator (
258- SyntheticTextGenerator ,
259- gen_kwargs = {
260- "config" : data ,
261- "processor" : processor_factory (),
262- "random_seed" : random_seed ,
263- },
264- features = Features (
265- {
266- "prefix" : Value ("string" ),
267- "prompt" : Value ("string" ),
268- "prompt_tokens_count" : Value ("int32" ),
269- "output_tokens_count" : Value ("int32" ),
270- }
271- ),
347+ return SyntheticTextDataset (
348+ config = data ,
349+ processor = processor_factory (),
350+ random_seed = random_seed ,
272351 )
273352
274353 def _load_config_dict (self , data : Any ) -> SyntheticTextDatasetConfig | None :
0 commit comments