Skip to content

Commit 9f2eb86

Browse files
committed
Fix the synthetic dataset so it preserves randomness across benchmarks run in the same session by enforcing set_epoch across the data loader and iterator chain
Signed-off-by: Mark Kurtz <[email protected]>
1 parent 30c7b92 commit 9f2eb86

File tree

4 files changed

+143
-54
lines changed

4 files changed

+143
-54
lines changed

src/guidellm/data/deserializers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
InMemoryJsonStrDatasetDeserializer,
2323
)
2424
from .synthetic import (
25+
SyntheticTextDataset,
2526
SyntheticTextDatasetConfig,
2627
SyntheticTextDatasetDeserializer,
27-
SyntheticTextGenerator,
2828
SyntheticTextPrefixBucketConfig,
2929
)
3030

@@ -44,9 +44,9 @@
4444
"InMemoryJsonStrDatasetDeserializer",
4545
"JSONFileDatasetDeserializer",
4646
"ParquetFileDatasetDeserializer",
47+
"SyntheticTextDataset",
4748
"SyntheticTextDatasetConfig",
4849
"SyntheticTextDatasetDeserializer",
49-
"SyntheticTextGenerator",
5050
"SyntheticTextPrefixBucketConfig",
5151
"TarFileDatasetDeserializer",
5252
"TextFileDatasetDeserializer",

src/guidellm/data/deserializers/synthetic.py

Lines changed: 112 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from random import Random
77
from typing import Any
88

9+
import numpy as np
910
import yaml
10-
from datasets import Features, IterableDataset, Value
11+
from datasets import DatasetInfo, Features, IterableDataset, Value
12+
from datasets.iterable_dataset import _BaseExamplesIterable
1113
from faker import Faker
1214
from pydantic import ConfigDict, Field, ValidationError, model_validator
1315
from transformers import PreTrainedTokenizerBase
@@ -21,9 +23,9 @@
2123
from guidellm.utils import IntegerRangeSampler
2224

2325
__all__ = [
26+
"SyntheticTextDataset",
2427
"SyntheticTextDatasetConfig",
2528
"SyntheticTextDatasetDeserializer",
26-
"SyntheticTextGenerator",
2729
"SyntheticTextPrefixBucketConfig",
2830
]
2931

@@ -121,29 +123,34 @@ def check_prefix_options(self) -> SyntheticTextDatasetConfig:
121123
return self
122124

123125

124-
class SyntheticTextGenerator:
126+
class _SyntheticTextExamplesIterable(_BaseExamplesIterable):
127+
"""Custom examples iterable for synthetic text generation."""
128+
125129
def __init__(
126130
self,
127131
config: SyntheticTextDatasetConfig,
128132
processor: PreTrainedTokenizerBase,
129-
random_seed: int = 42,
133+
random_seed: int,
130134
):
135+
super().__init__()
131136
self.config = config
132137
self.processor = processor
133138
self.random_seed = random_seed
139+
self.iteration_count = 0
134140

135-
def __iter__(self) -> Iterator[dict[str, Any]]:
136-
samples_generated = 0
141+
def __iter__(self) -> Iterator[tuple[int, dict[str, Any]]]:
142+
iter_random_seed = self.random_seed + self.iteration_count
143+
self.iteration_count += 1
137144

138145
faker = Faker()
139-
faker.seed_instance(self.random_seed)
146+
faker.seed_instance(iter_random_seed)
140147
prompt_tokens_sampler = iter(
141148
IntegerRangeSampler(
142149
average=self.config.prompt_tokens,
143150
variance=self.config.prompt_tokens_stdev,
144151
min_value=self.config.prompt_tokens_min,
145152
max_value=self.config.prompt_tokens_max,
146-
random_seed=self.random_seed,
153+
random_seed=iter_random_seed,
147154
)
148155
)
149156
output_tokens_sampler = iter(
@@ -152,27 +159,77 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
152159
variance=self.config.output_tokens_stdev,
153160
min_value=self.config.output_tokens_min,
154161
max_value=self.config.output_tokens_max,
155-
random_seed=self.random_seed + 1, # ensure diff dist from prompts
162+
random_seed=iter_random_seed + 1, # ensure diff dist from prompts
156163
)
157164
)
158165

159166
# Create a shared prefix if specified
160-
rand = Random(self.random_seed + 3)
167+
rand = Random(iter_random_seed + 3)
161168
prefix_iter = self._create_prefix_iter(faker, rand)
169+
samples_count = 0
162170

163171
while True:
164172
prompt_tokens_count = next(prompt_tokens_sampler)
165173
output_tokens_count = next(output_tokens_sampler)
166174

167-
yield {
168-
"prefix": next(prefix_iter),
169-
"prompt": self._create_prompt(
170-
prompt_tokens_count, faker, f"{samples_generated} "
171-
),
172-
"prompt_tokens_count": prompt_tokens_count,
173-
"output_tokens_count": output_tokens_count,
175+
yield (
176+
samples_count,
177+
{
178+
"prefix": next(prefix_iter),
179+
"prompt": self._create_prompt(
180+
prompt_tokens_count,
181+
faker,
182+
f"{self.iteration_count} {samples_count} ",
183+
),
184+
"prompt_tokens_count": prompt_tokens_count,
185+
"output_tokens_count": output_tokens_count,
186+
},
187+
)
188+
samples_count += 1
189+
190+
@property
191+
def is_typed(self) -> bool:
192+
return True
193+
194+
@property
195+
def features(self) -> Features:
196+
return Features(
197+
{
198+
"prefix": Value("string"),
199+
"prompt": Value("string"),
200+
"prompt_tokens_count": Value("int32"),
201+
"output_tokens_count": Value("int32"),
174202
}
175-
samples_generated += 1
203+
)
204+
205+
@property
206+
def num_shards(self) -> int:
207+
return 1
208+
209+
def shuffle_data_sources(
210+
self,
211+
generator: np.random.Generator, # noqa: ARG002
212+
) -> _SyntheticTextExamplesIterable:
213+
"""Return self since synthetic data doesn't have fixed sources to shuffle."""
214+
return self
215+
216+
def shard_data_sources(
217+
self,
218+
num_shards: int, # noqa: ARG002
219+
index: int, # noqa: ARG002
220+
contiguous: bool = True, # noqa: ARG002
221+
) -> _SyntheticTextExamplesIterable:
222+
"""Return self since synthetic data generation is infinite and stateless."""
223+
return self
224+
225+
def load_state_dict(self, state_dict: dict) -> None:
226+
"""Load the state from a state dict."""
227+
self.iteration_count = state_dict.get("iteration_count", 0)
228+
229+
def _init_state_dict(self) -> dict:
230+
"""Initialize the state dict for the iterable."""
231+
self._state_dict = {"iteration_count": self.iteration_count}
232+
return self._state_dict
176233

177234
def _create_prompt(
178235
self, prompt_tokens_count: int, faker: Faker, unique: str = ""
@@ -226,6 +283,39 @@ def _create_prefix_iter(self, faker: Faker, rand: Random) -> Iterator[str]:
226283
yield rand.choice(prefixes)
227284

228285

286+
class SyntheticTextDataset(IterableDataset):
287+
def __init__(
288+
self,
289+
config: SyntheticTextDatasetConfig,
290+
processor: PreTrainedTokenizerBase,
291+
random_seed: int = 42,
292+
):
293+
self.config = config
294+
self.processor = processor
295+
self.random_seed = random_seed
296+
297+
# Create the examples iterable
298+
ex_iterable = _SyntheticTextExamplesIterable(
299+
config=config,
300+
processor=processor,
301+
random_seed=random_seed,
302+
)
303+
304+
# Initialize parent with proper ex_iterable
305+
super().__init__(
306+
ex_iterable=ex_iterable,
307+
info=DatasetInfo(
308+
description="Synthetic text dataset generator",
309+
features=ex_iterable.features,
310+
),
311+
)
312+
313+
def set_epoch(self, epoch: int):
314+
"""Set the epoch for the dataset iteration."""
315+
if isinstance(self._ex_iterable, _SyntheticTextExamplesIterable):
316+
self._ex_iterable.iteration_count = epoch
317+
318+
229319
@DatasetDeserializerFactory.register("synthetic_text")
230320
class SyntheticTextDatasetDeserializer(DatasetDeserializer):
231321
def __call__(
@@ -254,21 +344,10 @@ def __call__(
254344
f"got {data}"
255345
)
256346

257-
return IterableDataset.from_generator(
258-
SyntheticTextGenerator,
259-
gen_kwargs={
260-
"config": data,
261-
"processor": processor_factory(),
262-
"random_seed": random_seed,
263-
},
264-
features=Features(
265-
{
266-
"prefix": Value("string"),
267-
"prompt": Value("string"),
268-
"prompt_tokens_count": Value("int32"),
269-
"output_tokens_count": Value("int32"),
270-
}
271-
),
347+
return SyntheticTextDataset(
348+
config=data,
349+
processor=processor_factory(),
350+
random_seed=random_seed,
272351
)
273352

274353
def _load_config_dict(self, data: Any) -> SyntheticTextDatasetConfig | None:

src/guidellm/data/loaders.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __init__(
6363
self.precache: list[Any] | None = (
6464
list(self.generator(data_samples)) if data_samples else None
6565
)
66+
self.epoch = 0
6667

6768
def __iter__(self) -> Iterator[DataT]:
6869
worker_info = torch.utils.data.get_worker_info()
@@ -74,18 +75,29 @@ def __iter__(self) -> Iterator[DataT]:
7475
if (index + worker_index) % worker_modulus == 0:
7576
yield item
7677
else:
77-
yield from self.generator(modulus=worker_modulus, offset=worker_index)
78+
yield from self.generator(
79+
modulus=worker_modulus, offset=worker_index, epoch=self.epoch
80+
)
81+
82+
def set_epoch(self, epoch: int):
83+
self.epoch = epoch
7884

7985
def generator(
8086
self,
8187
max_items: int | None = None,
8288
modulus: int | None = None,
8389
offset: int | None = None,
90+
epoch: int = 0,
8491
) -> Iterator[DataT]:
8592
gen_count = 0
8693

8794
with contextlib.suppress(StopIteration):
88-
dataset_iters = [iter(dataset) for dataset in self.datasets]
95+
dataset_iters = []
96+
for dataset in self.datasets:
97+
if hasattr(dataset, "set_epoch"):
98+
with contextlib.suppress(Exception):
99+
dataset.set_epoch(epoch)
100+
dataset_iters.append(iter(dataset))
89101

90102
while max_items is None or gen_count < max_items:
91103
try:
@@ -152,6 +164,7 @@ def __init__(
152164
"num_workers": num_workers,
153165
"random_seed": random_seed,
154166
}
167+
self.epoch = 0
155168

156169
super().__init__(
157170
dataset=iterator,
@@ -163,6 +176,13 @@ def __init__(
163176
**kwargs,
164177
)
165178

179+
def __iter__(self):
180+
if isinstance(self.dataset, DatasetsIterator):
181+
self.dataset.set_epoch(self.epoch)
182+
self.epoch += 1
183+
184+
return super().__iter__()
185+
166186
@property
167187
def info(self) -> dict[str, Any]:
168188
return self._info

tests/unit/data/deserializers/test_synthetic.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313

1414
from guidellm.data.deserializers.deserializer import DataNotSupportedError
1515
from guidellm.data.deserializers.synthetic import (
16+
SyntheticTextDataset,
1617
SyntheticTextDatasetConfig,
1718
SyntheticTextDatasetDeserializer,
18-
SyntheticTextGenerator,
1919
SyntheticTextPrefixBucketConfig,
2020
)
2121

@@ -264,9 +264,7 @@ def test_generator_initialization(self, simple_config, mock_tokenizer):
264264
265265
### WRITTEN BY AI ###
266266
"""
267-
generator = SyntheticTextGenerator(
268-
simple_config, mock_tokenizer, random_seed=42
269-
)
267+
generator = SyntheticTextDataset(simple_config, mock_tokenizer, random_seed=42)
270268

271269
assert generator.config == simple_config
272270
assert generator.processor == mock_tokenizer
@@ -278,9 +276,7 @@ def test_basic_iteration(self, simple_config, mock_tokenizer):
278276
279277
### WRITTEN BY AI ###
280278
"""
281-
generator = SyntheticTextGenerator(
282-
simple_config, mock_tokenizer, random_seed=42
283-
)
279+
generator = SyntheticTextDataset(simple_config, mock_tokenizer, random_seed=42)
284280

285281
items = []
286282
for i, item in enumerate(generator):
@@ -310,9 +306,7 @@ def test_create_prompt_method(self, simple_config, mock_tokenizer):
310306
"""
311307
from faker import Faker
312308

313-
generator = SyntheticTextGenerator(
314-
simple_config, mock_tokenizer, random_seed=42
315-
)
309+
generator = SyntheticTextDataset(simple_config, mock_tokenizer, random_seed=42)
316310
faker = Faker()
317311
faker.seed_instance(42)
318312

@@ -332,7 +326,7 @@ def test_prefix_tokens_integration(self, config_with_prefix, mock_tokenizer):
332326
333327
### WRITTEN BY AI ###
334328
"""
335-
generator = SyntheticTextGenerator(
329+
generator = SyntheticTextDataset(
336330
config_with_prefix, mock_tokenizer, random_seed=42
337331
)
338332

@@ -353,12 +347,8 @@ def test_random_seeding_consistency(self, simple_config, mock_tokenizer):
353347
### WRITTEN BY AI ###
354348
"""
355349
# Create two generators with same seed
356-
generator1 = SyntheticTextGenerator(
357-
simple_config, mock_tokenizer, random_seed=42
358-
)
359-
generator2 = SyntheticTextGenerator(
360-
simple_config, mock_tokenizer, random_seed=42
361-
)
350+
generator1 = SyntheticTextDataset(simple_config, mock_tokenizer, random_seed=42)
351+
generator2 = SyntheticTextDataset(simple_config, mock_tokenizer, random_seed=42)
362352

363353
items1 = []
364354
items2 = []

0 commit comments

Comments (0)