Commit 57df288

Speed up synthetic data sampling
Signed-off-by: David Whyte-Gray <[email protected]>
1 parent f8f6f9d commit 57df288
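
This commit replaces the per-sample binary search (repeatedly slicing the source text and re-encoding it until the token count matches) with a single pre-tokenization pass over the source; each prompt then becomes a contiguous, wrap-around slice of the cached token list. A minimal sketch of that slicing idea, using the hypothetical names cached_tokens and take_tokens in place of the real attributes in the diff below:

def take_tokens(cached_tokens: list[int], start_index: int, count: int) -> list[int]:
    # Take `count` tokens starting at `start_index`, wrapping past the end of the cache.
    if count <= 0 or not cached_tokens:
        return []
    n = len(cached_tokens)
    base = start_index % n
    return [cached_tokens[(base + offset) % n] for offset in range(count)]

# Example: a 5-token cache sliced near its end wraps back to the start.
assert take_tokens([10, 11, 12, 13, 14], start_index=3, count=4) == [13, 14, 10, 11]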

File tree

1 file changed: +111 −17 lines

src/guidellm/dataset/synthetic.py

Lines changed: 111 additions & 17 deletions
@@ -1,17 +1,23 @@
 import json
+import os
 import random
+import re
+import hashlib
+import time
 from collections.abc import Iterable, Iterator
 from itertools import cycle
 from pathlib import Path
 from typing import Any, Literal, Optional, Union

 import yaml
+import numpy as np  # type: ignore[import]
 from datasets import (
     Dataset,
     DatasetDict,
     IterableDataset,
     IterableDatasetDict,
 )
+from loguru import logger
 from pydantic import BaseModel, Field
 from transformers import PreTrainedTokenizerBase  # type: ignore[import]
@@ -144,6 +150,15 @@ def __init__(
         self.text_creator = EndlessTextCreator(
             data=config.source,
         )
+        # Pre-tokenize entire source once and cache per (tokenizer, source)
+        start_time = time.perf_counter()
+        self._cached_tokens: list[int] = self._load_or_build_token_cache()
+        elapsed = (time.perf_counter() - start_time) * 1000.0
+        logger.info(
+            "Synthetic: token cache ready | tokens={} | took_ms={:.2f}",
+            len(self._cached_tokens),
+            elapsed,
+        )

     def __iter__(
         self,
@@ -171,15 +186,16 @@ def __iter__(
         rand = random.Random(self.random_seed + 2)  # noqa: S311
         unique_prefix_iter = cycle(self.processor.get_vocab().values())

-        prefix_index = rand.randint(0, len(self.text_creator.words))
+        prefix_index = rand.randint(0, max(len(self._cached_tokens) - 1, 0))
         prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index)

+        sample_start_time = time.perf_counter()
         for _, prompt_tokens, output_tokens in zip(
             range(self.config.samples),
             prompt_tokens_sampler,
             output_tokens_sampler,
         ):
-            start_index = rand.randint(0, len(self.text_creator.words))
+            start_index = rand.randint(0, max(len(self._cached_tokens) - 1, 0))
             prompt_text = self.processor.decode(
                 prefix_tokens
                 + self._create_prompt(
@@ -192,31 +208,109 @@ def __iter__(
                 "prompt_tokens_count": self.config.prefix_tokens + prompt_tokens,
                 "output_tokens_count": output_tokens,
             }
+        elapsed_samples = (time.perf_counter() - sample_start_time) * 1000.0
+        logger.info(
+            "Synthetic: generated_samples={} | took_ms={:.2f} | avg_ms_per_sample={:.4f}",
+            self.config.samples,
+            elapsed_samples,
+            elapsed_samples / max(self.config.samples, 1),
+        )

     def _create_prompt(
         self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None
     ) -> list[int]:
         if prompt_tokens <= 0:
             return []

-        left = start_index
-        right = start_index + 4 * prompt_tokens
-        start_tokens = [unique_prefix] if unique_prefix else []
+        # Determine how many tokens to take from cache, accounting for optional unique prefix
+        remaining = prompt_tokens - (1 if unique_prefix is not None else 0)
+        if remaining < 0:
+            remaining = 0
+
+        sampled = self._take_tokens(start_index, remaining)
+        if unique_prefix is not None:
+            return [unique_prefix] + sampled
+        return sampled
+
+    def _take_tokens(self, start_index: int, count: int) -> list[int]:
+        if count <= 0:
+            return []
+        tokens = self._cached_tokens
+        n = len(tokens)
+        if n == 0:
+            return []
+        # Wrap-around contiguous sampling
+        result: list[int] = []
+        base = start_index % n
+        for offset in range(count):
+            result.append(tokens[(base + offset) % n])
+        return result
+
+    def _load_or_build_token_cache(self) -> list[int]:
+        # Create cache directory
+        cache_dir = Path(
+            os.getenv("XDG_CACHE_HOME", str(Path.home() / ".cache"))
+        ) / "guidellm" / "synthetic_tokens"
+        cache_dir.mkdir(parents=True, exist_ok=True)
+
+        # Compute a stable tokenizer identifier and source digest
+        tokenizer_id = self._tokenizer_identifier(self.processor)
+        source_digest = hashlib.sha1(
+            self.text_creator.filtered_text.encode("utf-8", errors="ignore")
+        ).hexdigest()
+
+        safe_tokenizer_id = re.sub(r"[^A-Za-z0-9_.-]", "_", tokenizer_id)
+        cache_path = cache_dir / f"{safe_tokenizer_id}-{source_digest}.npy"
+
+        if cache_path.exists():
+            try:
+                arr = np.load(cache_path)
+                # Ensure 1-D integer array
+                arr = np.asarray(arr, dtype=np.int64).reshape(-1)
+                logger.debug(
+                    "Synthetic: loaded token cache from {} | tokens={}",
+                    str(cache_path),
+                    arr.size,
+                )
+                return arr.astype(int).tolist()
+            except Exception:
+                # If loading fails, rebuild below
+                pass
+
+        # Build tokens once from full filtered text
+        # Avoid adding special tokens so spans don't include BOS/EOS markers repeatedly
+        build_start = time.perf_counter()
+        full_tokens = self.processor.encode(
+            self.text_creator.filtered_text,
+            add_special_tokens=False,
+        )
+        build_elapsed = (time.perf_counter() - build_start) * 1000.0
+        logger.info(
+            "Synthetic: built token cache in {:.2f} ms | tokens={}",
+            build_elapsed,
+            len(full_tokens),
+        )

-        while left < right:
-            mid = (left + right) // 2
-            test_prompt = self.text_creator.create_text(start_index, mid - start_index)
-            test_tokens = start_tokens + self.processor.encode(test_prompt)
+        # Persist to cache
+        try:
+            np.save(cache_path, np.asarray(full_tokens, dtype=np.int32))
+            logger.debug(
+                "Synthetic: saved token cache to {} | bytes≈{}",
+                str(cache_path),
+                int(np.asarray(full_tokens, dtype=np.int32).nbytes),
+            )
+        except Exception:
+            # Best effort; ignore cache write failures
+            pass

-            if len(test_tokens) == prompt_tokens:
-                return test_tokens
-            elif len(test_tokens) < prompt_tokens:
-                left = mid + 1
-            else:
-                right = mid
+        return full_tokens

-        final_text = self.text_creator.create_text(start_index, left - start_index)
-        return start_tokens + self.processor.encode(final_text)
+    @staticmethod
+    def _tokenizer_identifier(tokenizer: PreTrainedTokenizerBase) -> str:
+        name_or_path = getattr(tokenizer, "name_or_path", None) or "unknown"
+        vocab_size = getattr(tokenizer, "vocab_size", None)
+        cls_name = tokenizer.__class__.__name__
+        return f"{cls_name}-{name_or_path}-{vocab_size}"


 class SyntheticDatasetCreator(DatasetCreator):
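
For reference, the cache key above combines a tokenizer identifier (class name, name_or_path, vocab size) with a SHA-1 digest of the filtered source text, so each (tokenizer, source) pair maps to its own .npy file under the XDG cache directory. A standalone sketch of that path scheme, with a hypothetical helper name and no guidellm imports:

import hashlib
import os
import re
from pathlib import Path

def token_cache_path(tokenizer_id: str, filtered_text: str) -> Path:
    # Mirrors the scheme in the diff: <XDG_CACHE_HOME>/guidellm/synthetic_tokens/<tokenizer>-<sha1>.npy
    cache_dir = (
        Path(os.getenv("XDG_CACHE_HOME", str(Path.home() / ".cache")))
        / "guidellm"
        / "synthetic_tokens"
    )
    safe_id = re.sub(r"[^A-Za-z0-9_.-]", "_", tokenizer_id)
    digest = hashlib.sha1(filtered_text.encode("utf-8", errors="ignore")).hexdigest()
    return cache_dir / f"{safe_id}-{digest}.npy"

# Two different source texts for the same tokenizer map to distinct cache files.
print(token_cache_path("GPT2TokenizerFast-gpt2-50257", "the quick brown fox"))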
