diff --git a/src/guidellm/dataset/synthetic.py b/src/guidellm/dataset/synthetic.py
index 8c30f0f7..a9c4d4a3 100644
--- a/src/guidellm/dataset/synthetic.py
+++ b/src/guidellm/dataset/synthetic.py
@@ -1,17 +1,23 @@
 import json
+import os
 import random
+import re
+import hashlib
+import time
 from collections.abc import Iterable, Iterator
 from itertools import cycle
 from pathlib import Path
 from typing import Any, Literal, Optional, Union
 
 import yaml
+import numpy as np  # type: ignore[import]
 from datasets import (
     Dataset,
     DatasetDict,
     IterableDataset,
     IterableDatasetDict,
 )
+from loguru import logger
 from pydantic import BaseModel, Field
 from transformers import PreTrainedTokenizerBase  # type: ignore[import]
 
@@ -144,6 +150,15 @@ def __init__(
         self.text_creator = EndlessTextCreator(
             data=config.source,
         )
+        # Pre-tokenize entire source once and cache per (tokenizer, source)
+        start_time = time.perf_counter()
+        self._cached_tokens: list[int] = self._load_or_build_token_cache()
+        elapsed = (time.perf_counter() - start_time) * 1000.0
+        logger.info(
+            "Synthetic: token cache ready | tokens={} | took_ms={:.2f}",
+            len(self._cached_tokens),
+            elapsed,
+        )
 
     def __iter__(
         self,
@@ -171,15 +186,16 @@
         rand = random.Random(self.random_seed + 2)  # noqa: S311
         unique_prefix_iter = cycle(self.processor.get_vocab().values())
 
-        prefix_index = rand.randint(0, len(self.text_creator.words))
+        prefix_index = rand.randint(0, max(len(self._cached_tokens) - 1, 0))
         prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index)
 
+        sample_start_time = time.perf_counter()
         for _, prompt_tokens, output_tokens in zip(
             range(self.config.samples),
             prompt_tokens_sampler,
             output_tokens_sampler,
         ):
-            start_index = rand.randint(0, len(self.text_creator.words))
+            start_index = rand.randint(0, max(len(self._cached_tokens) - 1, 0))
             prompt_text = self.processor.decode(
                 prefix_tokens
                 + self._create_prompt(
@@ -192,6 +208,13 @@
                 "prompt_tokens_count": self.config.prefix_tokens + prompt_tokens,
                 "output_tokens_count": output_tokens,
             }
+        elapsed_samples = (time.perf_counter() - sample_start_time) * 1000.0
+        logger.info(
+            "Synthetic: generated_samples={} | took_ms={:.2f} | avg_ms_per_sample={:.4f}",
+            self.config.samples,
+            elapsed_samples,
+            elapsed_samples / max(self.config.samples, 1),
+        )
 
     def _create_prompt(
         self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None
@@ -199,24 +222,95 @@ def _create_prompt(
         if prompt_tokens <= 0:
             return []
 
-        left = start_index
-        right = start_index + 4 * prompt_tokens
-        start_tokens = [unique_prefix] if unique_prefix else []
+        # Determine how many tokens to take from cache, accounting for optional unique prefix
+        remaining = prompt_tokens - (1 if unique_prefix is not None else 0)
+        if remaining < 0:
+            remaining = 0
+
+        sampled = self._take_tokens(start_index, remaining)
+        if unique_prefix is not None:
+            return [unique_prefix] + sampled
+        return sampled
+
+    def _take_tokens(self, start_index: int, count: int) -> list[int]:
+        if count <= 0:
+            return []
+        tokens = self._cached_tokens
+        n = len(tokens)
+        if n == 0:
+            return []
+        # Wrap-around contiguous sampling
+        result: list[int] = []
+        base = start_index % n
+        for offset in range(count):
+            result.append(tokens[(base + offset) % n])
+        return result
+
+    def _load_or_build_token_cache(self) -> list[int]:
+        # Create cache directory
+        cache_dir = Path(
+            os.getenv("XDG_CACHE_HOME", str(Path.home() / ".cache"))
+        ) / "guidellm" / "synthetic_tokens"
+        cache_dir.mkdir(parents=True, exist_ok=True)
+
+        # Compute a stable tokenizer identifier and source digest
+        tokenizer_id = self._tokenizer_identifier(self.processor)
+        source_digest = hashlib.sha1(
+            self.text_creator.filtered_text.encode("utf-8", errors="ignore")
+        ).hexdigest()
+
+        safe_tokenizer_id = re.sub(r"[^A-Za-z0-9_.-]", "_", tokenizer_id)
+        cache_path = cache_dir / f"{safe_tokenizer_id}-{source_digest}.npy"
+
+        if cache_path.exists():
+            try:
+                arr = np.load(cache_path)
+                # Ensure 1-D integer array
+                arr = np.asarray(arr, dtype=np.int64).reshape(-1)
+                logger.debug(
+                    "Synthetic: loaded token cache from {} | tokens={}",
+                    str(cache_path),
+                    arr.size,
+                )
+                return arr.astype(int).tolist()
+            except Exception:
+                # If loading fails, rebuild below
+                pass
+
+        # Build tokens once from full filtered text
+        # Avoid adding special tokens so spans don't include BOS/EOS markers repeatedly
+        build_start = time.perf_counter()
+        full_tokens = self.processor.encode(
+            self.text_creator.filtered_text,
+            add_special_tokens=False,
+        )
+        build_elapsed = (time.perf_counter() - build_start) * 1000.0
+        logger.info(
+            "Synthetic: built token cache in {:.2f} ms | tokens={}",
+            build_elapsed,
+            len(full_tokens),
+        )
 
-        while left < right:
-            mid = (left + right) // 2
-            test_prompt = self.text_creator.create_text(start_index, mid - start_index)
-            test_tokens = start_tokens + self.processor.encode(test_prompt)
+        # Persist to cache
+        try:
+            np.save(cache_path, np.asarray(full_tokens, dtype=np.int32))
+            logger.debug(
+                "Synthetic: saved token cache to {} | bytes≈{}",
+                str(cache_path),
+                int(np.asarray(full_tokens, dtype=np.int32).nbytes),
+            )
+        except Exception:
+            # Best effort; ignore cache write failures
+            pass
 
-            if len(test_tokens) == prompt_tokens:
-                return test_tokens
-            elif len(test_tokens) < prompt_tokens:
-                left = mid + 1
-            else:
-                right = mid
+        return full_tokens
 
-        final_text = self.text_creator.create_text(start_index, left - start_index)
-        return start_tokens + self.processor.encode(final_text)
+    @staticmethod
+    def _tokenizer_identifier(tokenizer: PreTrainedTokenizerBase) -> str:
+        name_or_path = getattr(tokenizer, "name_or_path", None) or "unknown"
+        vocab_size = getattr(tokenizer, "vocab_size", None)
+        cls_name = tokenizer.__class__.__name__
+        return f"{cls_name}-{name_or_path}-{vocab_size}"
 
 
 class SyntheticDatasetCreator(DatasetCreator):
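
For reviewers, a minimal standalone sketch of the wrap-around sampling that the new _take_tokens helper performs over the cached token list; the free function name and the example values are illustrative only and are not part of the patch.

# Illustrative sketch only -- mirrors _take_tokens, not part of the patch.
def take_tokens(cached_tokens: list[int], start_index: int, count: int) -> list[int]:
    if count <= 0 or not cached_tokens:
        return []
    n = len(cached_tokens)
    base = start_index % n
    # Walk `count` positions forward from `base`, wrapping past the end of the
    # cache so any start index yields a contiguous span of exactly `count` tokens.
    return [cached_tokens[(base + offset) % n] for offset in range(count)]

# Example: a 5-token cache sampled near its end wraps back to the start.
assert take_tokens([10, 11, 12, 13, 14], start_index=3, count=4) == [13, 14, 10, 11]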