128 changes: 111 additions & 17 deletions src/guidellm/dataset/synthetic.py
@@ -1,17 +1,23 @@
import json
import os
import random
import re
import hashlib
import time
from collections.abc import Iterable, Iterator
from itertools import cycle
from pathlib import Path
from typing import Any, Literal, Optional, Union

import yaml
import numpy as np # type: ignore[import]
from datasets import (
    Dataset,
    DatasetDict,
    IterableDataset,
    IterableDatasetDict,
)
from loguru import logger
from pydantic import BaseModel, Field
from transformers import PreTrainedTokenizerBase # type: ignore[import]

@@ -144,6 +150,15 @@ def __init__(
        self.text_creator = EndlessTextCreator(
            data=config.source,
        )
        # Pre-tokenize entire source once and cache per (tokenizer, source)
        start_time = time.perf_counter()
        self._cached_tokens: list[int] = self._load_or_build_token_cache()
        elapsed = (time.perf_counter() - start_time) * 1000.0
        logger.info(
            "Synthetic: token cache ready | tokens={} | took_ms={:.2f}",
            len(self._cached_tokens),
            elapsed,
        )

    def __iter__(
        self,
@@ -171,15 +186,16 @@ def __iter__(
        rand = random.Random(self.random_seed + 2)  # noqa: S311
        unique_prefix_iter = cycle(self.processor.get_vocab().values())

        prefix_index = rand.randint(0, len(self.text_creator.words))
        prefix_index = rand.randint(0, max(len(self._cached_tokens) - 1, 0))
        prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index)

        sample_start_time = time.perf_counter()
        for _, prompt_tokens, output_tokens in zip(
            range(self.config.samples),
            prompt_tokens_sampler,
            output_tokens_sampler,
        ):
            start_index = rand.randint(0, len(self.text_creator.words))
            start_index = rand.randint(0, max(len(self._cached_tokens) - 1, 0))
            prompt_text = self.processor.decode(
                prefix_tokens
                + self._create_prompt(
@@ -192,31 +208,109 @@
                "prompt_tokens_count": self.config.prefix_tokens + prompt_tokens,
                "output_tokens_count": output_tokens,
            }
        elapsed_samples = (time.perf_counter() - sample_start_time) * 1000.0
        logger.info(
            "Synthetic: generated_samples={} | took_ms={:.2f} | avg_ms_per_sample={:.4f}",
            self.config.samples,
            elapsed_samples,
            elapsed_samples / max(self.config.samples, 1),
        )

    def _create_prompt(
        self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None
    ) -> list[int]:
        if prompt_tokens <= 0:
            return []

        left = start_index
        right = start_index + 4 * prompt_tokens
        start_tokens = [unique_prefix] if unique_prefix else []
        # Determine how many tokens to take from cache, accounting for optional unique prefix
        remaining = prompt_tokens - (1 if unique_prefix is not None else 0)
        if remaining < 0:
            remaining = 0

        sampled = self._take_tokens(start_index, remaining)
        if unique_prefix is not None:
            return [unique_prefix] + sampled
        return sampled

    def _take_tokens(self, start_index: int, count: int) -> list[int]:
        if count <= 0:
            return []
        tokens = self._cached_tokens
        n = len(tokens)
        if n == 0:
            return []
        # Wrap-around contiguous sampling
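        # Modulo indexing wraps the span back to the start of the cached token
        # list, so any start index yields exactly `count` tokens without re-encoding.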
        result: list[int] = []
        base = start_index % n
        for offset in range(count):
            result.append(tokens[(base + offset) % n])
        return result

    def _load_or_build_token_cache(self) -> list[int]:
        # Create cache directory
        cache_dir = Path(
            os.getenv("XDG_CACHE_HOME", str(Path.home() / ".cache"))
        ) / "guidellm" / "synthetic_tokens"
        cache_dir.mkdir(parents=True, exist_ok=True)

        # Compute a stable tokenizer identifier and source digest
        tokenizer_id = self._tokenizer_identifier(self.processor)
        source_digest = hashlib.sha1(
            self.text_creator.filtered_text.encode("utf-8", errors="ignore")
        ).hexdigest()

        safe_tokenizer_id = re.sub(r"[^A-Za-z0-9_.-]", "_", tokenizer_id)
        cache_path = cache_dir / f"{safe_tokenizer_id}-{source_digest}.npy"

        if cache_path.exists():
            try:
                arr = np.load(cache_path)
                # Ensure 1-D integer array
                arr = np.asarray(arr, dtype=np.int64).reshape(-1)
                logger.debug(
                    "Synthetic: loaded token cache from {} | tokens={}",
                    str(cache_path),
                    arr.size,
                )
                return arr.astype(int).tolist()
            except Exception:
                # If loading fails, rebuild below
                pass

        # Build tokens once from full filtered text
        # Avoid adding special tokens so spans don't include BOS/EOS markers repeatedly
        build_start = time.perf_counter()
        full_tokens = self.processor.encode(
            self.text_creator.filtered_text,
            add_special_tokens=False,
        )
        build_elapsed = (time.perf_counter() - build_start) * 1000.0
        logger.info(
            "Synthetic: built token cache in {:.2f} ms | tokens={}",
            build_elapsed,
            len(full_tokens),
        )

        while left < right:
            mid = (left + right) // 2
            test_prompt = self.text_creator.create_text(start_index, mid - start_index)
            test_tokens = start_tokens + self.processor.encode(test_prompt)
        # Persist to cache
        try:
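            # int32 comfortably covers vocabulary ids and keeps the on-disk cache compact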
            np.save(cache_path, np.asarray(full_tokens, dtype=np.int32))
            logger.debug(
                "Synthetic: saved token cache to {} | bytes≈{}",
                str(cache_path),
                int(np.asarray(full_tokens, dtype=np.int32).nbytes),
            )
        except Exception:
            # Best effort; ignore cache write failures
            pass

            if len(test_tokens) == prompt_tokens:
                return test_tokens
            elif len(test_tokens) < prompt_tokens:
                left = mid + 1
            else:
                right = mid
        return full_tokens

        final_text = self.text_creator.create_text(start_index, left - start_index)
        return start_tokens + self.processor.encode(final_text)
    @staticmethod
    def _tokenizer_identifier(tokenizer: PreTrainedTokenizerBase) -> str:
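        # Combine class name, name_or_path, and vocab size so caches built with
        # different tokenizers or vocabularies never collide.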
        name_or_path = getattr(tokenizer, "name_or_path", None) or "unknown"
        vocab_size = getattr(tokenizer, "vocab_size", None)
        cls_name = tokenizer.__class__.__name__
        return f"{cls_name}-{name_or_path}-{vocab_size}"


class SyntheticDatasetCreator(DatasetCreator):
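
For orientation, the two mechanisms this diff introduces can be reduced to a short standalone sketch. The sketch below is illustrative only and is not part of the change: the helper names (cache_path_for, take_tokens) are hypothetical, and plain Python lists stand in for the tokenizer and numpy pieces, but the cache key (tokenizer identifier plus a SHA-1 digest of the source text) and the modulo wrap-around sampling mirror the implementation above.

import hashlib
import re
from pathlib import Path


def cache_path_for(cache_dir: Path, tokenizer_id: str, source_text: str) -> Path:
    # Key the cache on the tokenizer identity and a digest of the source text,
    # so different tokenizers or corpora never share cached token ids.
    digest = hashlib.sha1(source_text.encode("utf-8", errors="ignore")).hexdigest()
    safe_id = re.sub(r"[^A-Za-z0-9_.-]", "_", tokenizer_id)
    return cache_dir / f"{safe_id}-{digest}.npy"


def take_tokens(tokens: list[int], start_index: int, count: int) -> list[int]:
    # Contiguous slice that wraps past the end of the cached token list,
    # so any start index yields exactly `count` tokens.
    if count <= 0 or not tokens:
        return []
    n = len(tokens)
    base = start_index % n
    return [tokens[(base + i) % n] for i in range(count)]


if __name__ == "__main__":
    print(cache_path_for(Path("/tmp/guidellm-demo"), "GPT2TokenizerFast-gpt2-50257", "hello world"))
    print(take_tokens([10, 11, 12, 13, 14], start_index=3, count=4))  # [13, 14, 10, 11]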