diff --git a/docs/guides/datasets.md b/docs/guides/datasets.md
index dba4f94b0..5bb511364 100644
--- a/docs/guides/datasets.md
+++ b/docs/guides/datasets.md
@@ -212,3 +212,237 @@ benchmark_generative_text(data=data, ...)
 - For lists of dictionaries, all items must have the same keys.
 - For lists of items, all elements must be of the same type.
 - A processor/tokenizer is only required if `GUIDELLM__PREFERRED_PROMPT_TOKENS_SOURCE="local"` or `GUIDELLM__PREFERRED_OUTPUT_TOKENS_SOURCE="local"` is set in the environment. In this case, the processor/tokenizer must be specified using the `--processor` argument. If not set, the processor/tokenizer will be set to the model passed in or retrieved from the server.
+
+## Preprocessing Datasets
+
+GuideLLM provides a preprocessing command that allows you to process datasets to have specific prompt and output token sizes. This is particularly useful when you need to standardize your dataset for benchmarking or when your dataset has prompts that don't match your target token requirements.
+
+The preprocessing command can:
+
+- Resize prompts to target token lengths
+- Handle prompts that are shorter or longer than the target length using various strategies
+- Map columns from your dataset to GuideLLM's expected column names
+- Generate output token counts based on your configuration
+- Save the processed dataset in various formats
+
+### Basic Usage
+
+```bash
+guidellm preprocess dataset \
+    <DATA> \
+    <OUTPUT_PATH> \
+    --processor <PROCESSOR> \
+    --config <CONFIG>
+```
+
+### Required Arguments
+
+| Argument      | Description                                                                                                                                    |
+| ------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
+| `DATA`        | Path to the input dataset or Hugging Face dataset ID. Supports all dataset formats documented in the [Dataset Configurations](../datasets.md). |
+| `OUTPUT_PATH` | Path to save the processed dataset, including file suffix (e.g., `processed_dataset.jsonl`, `output.csv`).                                     |
+| `--processor` | **Required.** Processor or tokenizer name/path for calculating token counts. Can be a Hugging Face model ID or local path.                     |
+| `--config`    | **Required.** Configuration specifying target token sizes. Can be a JSON string, key=value pairs, or file path (.json, .yaml, .yml, .config).  |
+
+### Example
+
+```bash
+guidellm preprocess dataset \
+    "path/to/input_dataset.jsonl" \
+    "path/to/processed_dataset.jsonl" \
+    --processor "gpt2" \
+    --config "prompt_tokens=512,output_tokens=256,prefix_tokens_max=100"
+```
+
+### Configuration and Processor Options
+
+The `--config` parameter accepts a `PreprocessDatasetConfig` as a JSON string, key=value pairs, or a configuration file path (.json, .yaml, .yml, .config). This configuration is similar to the synthetic data configuration but includes additional fields specific to preprocessing.
+
+**PreprocessDatasetConfig Options:**
+
+- `prompt_tokens`: Average number of tokens in prompts. If nothing else is specified, all prompts will be resized to this number of tokens.
+- `prompt_tokens_stdev`: Standard deviation for prompt tokens. If not supplied and min/max are not specified, no deviation is applied. If not supplied and min/max are specified, a uniform distribution is used.
+- `prompt_tokens_min`: Minimum number of tokens in prompts. If unset and `prompt_tokens_stdev` is set, the minimum is 1.
+- `prompt_tokens_max`: Maximum number of tokens in prompts. If unset and `prompt_tokens_stdev` is set, the maximum is 5 times the standard deviation.
+- `output_tokens`: Average number of tokens in outputs. If nothing else is specified, all outputs will have this number of tokens. +- `output_tokens_stdev`: Standard deviation for output tokens. If not supplied and min/max are not specified, no deviation is applied. If not supplied and min/max are specified, a uniform distribution is used. +- `output_tokens_min`: Minimum number of tokens in outputs. If unset and `output_tokens_stdev` is set, the minimum is 1. +- `output_tokens_max`: Maximum number of tokens in outputs. If unset and `output_tokens_stdev` is set, the maximum is 5 times the standard deviation. +- `prefix_tokens_max`: Maximum number of prefix tokens to keep. If set, prefixes will be trimmed to this maximum length. If not set, prefixes are kept as-is (unless `--include-prefix-in-token-count` is used, which disables prefix trimming). + +**Example configurations:** + +```bash +# Using key=value pairs +--config "prompt_tokens=512,output_tokens=256,prefix_tokens_max=100" + +# Using JSON string +--config '{"prompt_tokens": 512, "output_tokens": 256, "prefix_tokens_max": 100}' + +# Using a configuration file +--config "path/to/config.json" +``` + +The `--processor` argument specifies the tokenizer to use for calculating token counts. This is required because the preprocessing command needs to tokenize prompts to ensure they match the target token sizes. For information about using processors, including Hugging Face model IDs, local paths, and processor arguments, see the [Data Arguments Overview](../datasets.md#data-arguments-overview) section. + +### Column Mapping + +When your dataset uses non-standard column names, you can use `--data-column-mapper` to map your columns to GuideLLM's expected column names. This is particularly useful when: + +1. **Your dataset uses different column names** (e.g., `question` instead of `prompt`, `instruction` instead of `text_column`) +2. **You have multiple datasets** and need to specify which dataset's columns to use +3. **Your dataset has system prompts or prefixes** in a separate column + +**Column mapping format:** The `--data-column-mapper` accepts a JSON string mapping column types to column names: + +```json +{ + "text_column": "question", + "prefix_column": "system_prompt", + "prompt_tokens_count_column": "input_tokens", + "output_tokens_count_column": "completion_tokens" +} +``` + +**Supported column types:** + +- `text_column`: The main prompt text (defaults: `prompt`, `instruction`, `question`, `input`, `context`, `content`, `text`) +- `prefix_column`: System prompt or prefix (defaults: `system_prompt`, `system`, `prefix`) +- `prompt_tokens_count_column`: Column containing prompt token counts (defaults: `prompt_tokens_count`, `input_tokens_count`) +- `output_tokens_count_column`: Column containing output token counts (defaults: `output_tokens_count`, `completion_tokens_count`) +- `image_column`: Image data column +- `video_column`: Video data column +- `audio_column`: Audio data column + +**Example: Mapping custom column names** + +If your dataset has a CSV file with columns `user_query` and `system_message`: + +```csv +user_query,system_message +"What is AI?","You are a helpful assistant." +"How does ML work?","You are a technical expert." 
+```
+
+You would use:
+
+```bash
+guidellm preprocess dataset \
+    "dataset.csv" \
+    "processed.jsonl" \
+    --processor "gpt2" \
+    --config "prompt_tokens=512,output_tokens=256" \
+    --data-column-mapper '{"text_column": "user_query", "prefix_column": "system_message"}'
+```
+
+**Example: Multiple datasets**
+
+If you're working with multiple datasets and need to specify which dataset's columns to use, you can use the format `<dataset_index>.<column_name>` or `<dataset_name>.<column_name>`:
+
+```bash
+--data-column-mapper '{"text_column": "0.prompt", "prefix_column": "1.system"}'
+```
+
+### Handling Short Prompts
+
+When prompts are shorter than the target token length, you can specify how to handle them using `--short-prompt-strategy`:
+
+| Strategy      | Description                                                                     |
+| ------------- | ------------------------------------------------------------------------------- |
+| `ignore`      | Skip prompts that are shorter than the target length (default)                  |
+| `concatenate` | Concatenate multiple short prompts together until the target length is reached  |
+| `pad`         | Pad short prompts with a specified character to reach the target length         |
+| `error`       | Raise an error if a prompt is shorter than the target length                    |
+
+**Example: Concatenating short prompts**
+
+```bash
+guidellm preprocess dataset \
+    "dataset.jsonl" \
+    "processed.jsonl" \
+    --processor "gpt2" \
+    --config "prompt_tokens=512,output_tokens=256" \
+    --short-prompt-strategy "concatenate" \
+    --concat-delimiter "\n\n"
+```
+
+**Example: Padding short prompts**
+
+```bash
+guidellm preprocess dataset \
+    "dataset.jsonl" \
+    "processed.jsonl" \
+    --processor "gpt2" \
+    --config "prompt_tokens=512,output_tokens=256" \
+    --short-prompt-strategy "pad" \
+    --pad-char " "
+```
+
+### Additional Options
+
+| Option                            | Description                                                                                                                               |
+| --------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
+| `--data-args <json>`              | JSON string of arguments to pass to dataset loading. See [Data Arguments Overview](../datasets.md#data-arguments-overview) for details.  |
+| `--include-prefix-in-token-count` | Include prefix tokens in prompt token count calculation (flag). When enabled, prefix trimming is disabled and the prefix is kept as-is.   |
+| `--random-seed <int>`             | Random seed for reproducible token sampling (default: 42).                                                                                |
+| `--push-to-hub`                   | Push the processed dataset to Hugging Face Hub (flag).                                                                                    |
+| `--hub-dataset-id <id>`           | Hugging Face Hub dataset ID for upload (required if `--push-to-hub` is set). 
| + +### Complete Examples + +**Example 1: Basic preprocessing with custom column names** + +```bash +guidellm preprocess dataset \ + "my_dataset.csv" \ + "processed_dataset.jsonl" \ + --processor "gpt2" \ + --config "prompt_tokens=512,output_tokens=256" \ + --data-column-mapper '{"text_column": "user_question", "prefix_column": "system_instruction"}' +``` + +**Example 2: Preprocessing with distribution and short prompt handling** + +```bash +guidellm preprocess dataset \ + "dataset.jsonl" \ + "processed.jsonl" \ + --processor "gpt2" \ + --config "prompt_tokens=512,prompt_tokens_stdev=50,output_tokens=256,output_tokens_stdev=25" \ + --short-prompt-strategy "concatenate" \ + --concat-delimiter "\n\n" \ + --random-seed 123 +``` + +**Example 3: Preprocessing with processor arguments and prefix token limits** + +```bash +guidellm preprocess dataset \ + "dataset.jsonl" \ + "processed.jsonl" \ + --processor "gpt2" \ + --processor-args '{"use_fast": false}' \ + --config "prompt_tokens=512,output_tokens=256,prefix_tokens_max=100" \ + --include-prefix-in-token-count +``` + +**Example 4: Preprocessing and uploading to Hugging Face Hub** + +```bash +guidellm preprocess dataset \ + "my_dataset.jsonl" \ + "processed.jsonl" \ + --processor "gpt2" \ + --config "prompt_tokens=512,output_tokens=256" \ + --push-to-hub \ + --hub-dataset-id "username/processed-dataset" +``` + +### Notes + +- The `--config` parameter accepts a `PreprocessDatasetConfig` which includes all token count fields (prompt_tokens, output_tokens, etc.) plus `prefix_tokens_max` for controlling prefix length. See the [Configuration and Processor Options](#configuration-and-processor-options) section above for all available parameters. +- The processor/tokenizer is required because the preprocessing command needs to tokenize prompts to ensure they match target token sizes. See the [Data Arguments Overview](../datasets.md#data-arguments-overview) for processor usage details. +- Column mappings are only needed when your dataset uses non-standard column names. GuideLLM will automatically try common column names if no mapping is provided. +- When using `--short-prompt-strategy concatenate`, ensure your dataset has enough samples to concatenate, or some prompts may be skipped. +- The output format is determined by the file extension of `OUTPUT_PATH` (e.g., `.jsonl`, `.csv`, `.parquet`). +- The prefix handling only trims prefixes. It doesn't expand them. Use `prefix_tokens_max` in the config to set a maximum prefix length, which will trim prefixes that exceed this limit. diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py index 2b52bbc5b..8506b0845 100644 --- a/src/guidellm/__main__.py +++ b/src/guidellm/__main__.py @@ -30,6 +30,8 @@ import click from pydantic import ValidationError +from guidellm.data import ShortPromptStrategy, process_dataset + try: import uvloop except ImportError: @@ -486,6 +488,142 @@ def preprocess(): """Dataset preprocessing utilities.""" +@preprocess.command( + "dataset", + help=( + "Process a dataset to have specific prompt and output token sizes. " + "Supports multiple strategies for handling prompts and optional " + "Hugging Face Hub upload.\n\n" + "DATA: Path to the input dataset or dataset ID.\n\n" + "OUTPUT_PATH: Path to save the processed dataset, including file suffix." 
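+        # Illustrative invocation (values below are placeholders, not defaults):
+        #   guidellm preprocess dataset data.jsonl processed.jsonl \
+        #     --processor gpt2 --config "prompt_tokens=512,output_tokens=256"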
+ ), + context_settings={"auto_envvar_prefix": "GUIDELLM"}, +) +@click.argument( + "data", + type=str, + required=True, +) +@click.argument( + "output_path", + type=click.Path(file_okay=True, dir_okay=False, writable=True, resolve_path=True), + required=True, +) +@click.option( + "--processor", + type=str, + required=True, + help="Processor or tokenizer name for calculating token counts.", +) +@click.option( + "--config", + type=str, + required=True, + help=( + "PreprocessDatasetConfig as JSON string, key=value pairs, " + "or file path (.json, .yaml, .yml, .config). " + "Example: 'prompt_tokens=100,output_tokens=50,prefix_tokens_max=10'" + " or '{\"prompt_tokens\": 100, \"output_tokens\": 50, " + "\"prefix_tokens_max\": 10}'" + ), +) +@click.option( + "--processor-args", + default=None, + callback=cli_tools.parse_json, + help="JSON string of arguments to pass to the processor constructor.", +) +@click.option( + "--data-args", + callback=cli_tools.parse_json, + help="JSON string of arguments to pass to dataset creation.", +) +@click.option( + "--data-column-mapper", + default=None, + callback=cli_tools.parse_json, + help="JSON string of column mappings to apply to the dataset.", +) +@click.option( + "--short-prompt-strategy", + type=click.Choice([s.value for s in ShortPromptStrategy]), + default=ShortPromptStrategy.IGNORE.value, + show_default=True, + help="Strategy for handling prompts shorter than target length.", +) +@click.option( + "--pad-char", + type=str, + default="", + callback=decode_escaped_str, + help="Character to pad short prompts with when using 'pad' strategy.", +) +@click.option( + "--concat-delimiter", + type=str, + default="", + help=( + "Delimiter for concatenating short prompts (used with 'concatenate' strategy)." + ), +) +@click.option( + "--include-prefix-in-token-count", + is_flag=True, + default=False, + help="Include prefix tokens in prompt token count calculation.", +) +@click.option( + "--push-to-hub", + is_flag=True, + help="Push the processed dataset to Hugging Face Hub.", +) +@click.option( + "--hub-dataset-id", + type=str, + default=None, + help=("Hugging Face Hub dataset ID for upload (required if --push-to-hub is set)."), +) +@click.option( + "--random-seed", + type=int, + default=42, + show_default=True, + help="Random seed for reproducible token sampling.", +) +def dataset( + data, + output_path, + processor, + config, + processor_args, + data_args, + data_column_mapper, + short_prompt_strategy, + pad_char, + concat_delimiter, + include_prefix_in_token_count, + push_to_hub, + hub_dataset_id, + random_seed, +): + process_dataset( + data=data, + output_path=output_path, + processor=processor, + config=config, + processor_args=processor_args, + data_args=data_args, + data_column_mapper=data_column_mapper, + short_prompt_strategy=short_prompt_strategy, + pad_char=pad_char, + concat_delimiter=concat_delimiter, + include_prefix_in_token_count=include_prefix_in_token_count, + push_to_hub=push_to_hub, + hub_dataset_id=hub_dataset_id, + random_seed=random_seed, + ) + + @cli.command( "mock-server", help=( diff --git a/src/guidellm/data/__init__.py b/src/guidellm/data/__init__.py index 9adbd3c8d..482f3b8b0 100644 --- a/src/guidellm/data/__init__.py +++ b/src/guidellm/data/__init__.py @@ -4,6 +4,7 @@ DatasetDeserializer, DatasetDeserializerFactory, ) +from .entrypoints import ShortPromptStrategy, process_dataset from .loaders import DataLoader, DatasetsIterator from .preprocessors import ( DataDependentPreprocessor, @@ -27,4 +28,6 @@ "PreprocessorRegistry", 
"ProcessorFactory", "RequestFormatter", + "ShortPromptStrategy", + "process_dataset", ] diff --git a/src/guidellm/data/config.py b/src/guidellm/data/config.py new file mode 100644 index 000000000..2b0b2133a --- /dev/null +++ b/src/guidellm/data/config.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, TypeVar + +import yaml +from pydantic import ValidationError + +from guidellm.data.schemas import DataConfig, DataNotSupportedError + +ConfigT = TypeVar("ConfigT", bound=DataConfig) + + +def load_config(config: Any, config_class: type[ConfigT]) -> ConfigT | None: + # Try file path first + if (loaded_config := _load_config_file(config, config_class)) is not None: + return loaded_config + + # Try dict parsing next + if (loaded_config := _load_config_dict(config, config_class)) is not None: + return loaded_config + + # Try string parsing + if (loaded_config := _load_config_str(config, config_class)) is not None: + return loaded_config + + return None + + +def _load_config_dict(data: Any, config_class: type[ConfigT]) -> ConfigT | None: + if not isinstance(data, dict | list): + return None + + try: + return config_class.model_validate(data) + except ValidationError: + return None + + +def _load_config_file(data: Any, config_class: type[ConfigT]) -> ConfigT | None: + if (not isinstance(data, str) and not isinstance(data, Path)) or ( + not Path(data).is_file() + ): + return None + + data_path = Path(data) if isinstance(data, str) else data + error = None + + if Path(data).is_file() and data_path.suffix.lower() == ".json": + try: + return config_class.model_validate_json( + data_path.read_text() + ) + except Exception as err: # noqa: BLE001 + error = err + + if Path(data).is_file() and data_path.suffix.lower() in { + ".yaml", + ".yml", + ".config", + }: + try: + return config_class.model_validate( + yaml.safe_load(data_path.read_text()) + ) + except Exception as err: # noqa: BLE001 + error = err + + err_message = ( + f"Unsupported file {data_path} for " + f"{config_class.__name__}, expected .json, " + f".yaml, .yml, or .config" + ) + + if error is not None: + err_message += f" with error: {error}" + raise DataNotSupportedError(err_message) from error + raise DataNotSupportedError(err_message) + + +def _load_config_str(data: str, config_class: type[ConfigT]) -> ConfigT | None: + if not isinstance(data, str): + return None + + data_str = data.strip() + error = None + + if (data_str.startswith("{") and data_str.endswith("}")) or ( + data_str.startswith("[") and data_str.endswith("]") + ): + try: + return config_class.model_validate_json(data_str) + except Exception as err: # noqa: BLE001 + error = err + + if data_str.count("=") > 1: + # key=value pairs separated by commas + try: + config_dict = {} + items = data_str.split(",") + for item in items: + key, value = item.split("=") + config_dict[key.strip()] = ( + int(value.strip()) + if value.strip().isnumeric() + else value.strip() + ) + + return config_class.model_validate(config_dict) + except Exception as err: # noqa: BLE001 + error = err + + err_message = ( + f"Unsupported string data for {config_class.__name__}, " + f"expected JSON or key-value pairs, got {data}" + ) + if error is not None: + err_message += f" with error: {error}" + raise DataNotSupportedError(err_message) from error + raise DataNotSupportedError(err_message) diff --git a/src/guidellm/data/deserializers/__init__.py b/src/guidellm/data/deserializers/__init__.py index e24bfd5d7..fb22fd2a7 100644 --- 
a/src/guidellm/data/deserializers/__init__.py +++ b/src/guidellm/data/deserializers/__init__.py @@ -23,9 +23,7 @@ ) from .synthetic import ( SyntheticTextDataset, - SyntheticTextDatasetConfig, SyntheticTextDatasetDeserializer, - SyntheticTextPrefixBucketConfig, ) __all__ = [ @@ -45,9 +43,7 @@ "JSONFileDatasetDeserializer", "ParquetFileDatasetDeserializer", "SyntheticTextDataset", - "SyntheticTextDatasetConfig", "SyntheticTextDatasetDeserializer", - "SyntheticTextPrefixBucketConfig", "TarFileDatasetDeserializer", "TextFileDatasetDeserializer", ] diff --git a/src/guidellm/data/deserializers/deserializer.py b/src/guidellm/data/deserializers/deserializer.py index cddd48766..9ced2ff3a 100644 --- a/src/guidellm/data/deserializers/deserializer.py +++ b/src/guidellm/data/deserializers/deserializer.py @@ -6,20 +6,16 @@ from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict from transformers import PreTrainedTokenizerBase +from guidellm.data.schemas import DataNotSupportedError from guidellm.data.utils import resolve_dataset_split from guidellm.utils import RegistryMixin __all__ = [ - "DataNotSupportedError", "DatasetDeserializer", "DatasetDeserializerFactory", ] -class DataNotSupportedError(Exception): - """Exception raised when data format is not supported by deserializer.""" - - @runtime_checkable class DatasetDeserializer(Protocol): def __call__( diff --git a/src/guidellm/data/deserializers/file.py b/src/guidellm/data/deserializers/file.py index 9819e1732..f46e0ae52 100644 --- a/src/guidellm/data/deserializers/file.py +++ b/src/guidellm/data/deserializers/file.py @@ -64,6 +64,7 @@ def __call__( **data_kwargs: dict[str, Any], ) -> Dataset: _ = (processor_factory, random_seed) + if ( not isinstance(data, str | Path) or not (path := Path(data)).exists() @@ -72,7 +73,7 @@ def __call__( ): raise DataNotSupportedError( "Unsupported data for CSVFileDatasetDeserializer, " - f"expected str or Path to a local .csv file, got {data}" + f"expected str or Path to a valid local .csv file, got {data}" ) return load_dataset("csv", data_files=str(path), **data_kwargs) diff --git a/src/guidellm/data/deserializers/synthetic.py b/src/guidellm/data/deserializers/synthetic.py index c8ec2831c..068ec78e2 100644 --- a/src/guidellm/data/deserializers/synthetic.py +++ b/src/guidellm/data/deserializers/synthetic.py @@ -2,127 +2,30 @@ import math from collections.abc import Callable, Iterator -from pathlib import Path from random import Random from typing import Any import numpy as np -import yaml from datasets import DatasetInfo, Features, IterableDataset, Value from datasets.iterable_dataset import _BaseExamplesIterable from faker import Faker -from pydantic import ConfigDict, Field, ValidationError, model_validator from transformers import PreTrainedTokenizerBase +from guidellm.data.config import load_config from guidellm.data.deserializers.deserializer import ( DataNotSupportedError, DatasetDeserializer, DatasetDeserializerFactory, ) -from guidellm.schemas import StandardBaseModel +from guidellm.data.schemas import SyntheticTextDatasetConfig from guidellm.utils import IntegerRangeSampler __all__ = [ "SyntheticTextDataset", - "SyntheticTextDatasetConfig", "SyntheticTextDatasetDeserializer", - "SyntheticTextPrefixBucketConfig", ] -class SyntheticTextPrefixBucketConfig(StandardBaseModel): - bucket_weight: int = Field( - description="Weight of this bucket in the overall distribution.", - gt=0, - default=100, - ) - prefix_count: int = Field( - description="The number of unique prefixes to generate for 
this bucket.", - ge=1, - default=1, - ) - prefix_tokens: int = Field( - description="The number of prefix tokens per-prompt for this bucket.", - ge=0, - default=0, - ) - - -class SyntheticTextDatasetConfig(StandardBaseModel): - model_config = ConfigDict( - extra="allow", - ) - - prefix_buckets: list[SyntheticTextPrefixBucketConfig] | None = Field( - description="Buckets for the prefix tokens distribution.", - default=None, - ) - prompt_tokens: int = Field( - description="The average number of text tokens generated for prompts.", - gt=0, - ) - prompt_tokens_stdev: int | None = Field( - description="The standard deviation of the tokens generated for prompts.", - gt=0, - default=None, - ) - prompt_tokens_min: int | None = Field( - description="The minimum number of text tokens generated for prompts.", - gt=0, - default=None, - ) - prompt_tokens_max: int | None = Field( - description="The maximum number of text tokens generated for prompts.", - gt=0, - default=None, - ) - output_tokens: int = Field( - description="The average number of text tokens generated for outputs.", - gt=0, - ) - output_tokens_stdev: int | None = Field( - description="The standard deviation of the tokens generated for outputs.", - gt=0, - default=None, - ) - output_tokens_min: int | None = Field( - description="The minimum number of text tokens generated for outputs.", - gt=0, - default=None, - ) - output_tokens_max: int | None = Field( - description="The maximum number of text tokens generated for outputs.", - gt=0, - default=None, - ) - source: str = Field( - description="The source of the text data to be used for generation.", - default="data:prideandprejudice.txt.gz", - ) - - @model_validator(mode="after") - def check_prefix_options(self) -> SyntheticTextDatasetConfig: - if self.__pydantic_extra__ is not None: - prefix_count = self.__pydantic_extra__.get("prefix_count", None) # type: ignore[attr-defined] - prefix_tokens = self.__pydantic_extra__.get("prefix_tokens", None) # type: ignore[attr-defined] - - if prefix_count is not None or prefix_tokens is not None: - if self.prefix_buckets: - raise ValueError( - "prefix_buckets is mutually exclusive" - " with prefix_count and prefix_tokens" - ) - - self.prefix_buckets = [ - SyntheticTextPrefixBucketConfig( - prefix_count=prefix_count or 1, - prefix_tokens=prefix_tokens or 0, - ) - ] - - return self - - class _SyntheticTextExamplesIterable(_BaseExamplesIterable): """Custom examples iterable for synthetic text generation.""" @@ -325,16 +228,8 @@ def __call__( random_seed: int, **data_kwargs: dict[str, Any], ) -> IterableDataset: - # Config file pathways, deserialize and call self again - if (config := self._load_config_file(data)) is not None: - return self(config, processor_factory, random_seed, **data_kwargs) - - # Config str pathways, deserialize and call self again - if (config := self._load_config_str(data)) is not None: - return self(config, processor_factory, random_seed, **data_kwargs) - - # Try to parse dict-like data directly - if (config := self._load_config_dict(data)) is not None: + # Config file and string pathways; deserialize and call self again + if (config := load_config(data, SyntheticTextDatasetConfig)) is not None: return self(config, processor_factory, random_seed, **data_kwargs) if not isinstance(data, SyntheticTextDatasetConfig): @@ -349,93 +244,3 @@ def __call__( processor=processor_factory(), random_seed=random_seed, ) - - def _load_config_dict(self, data: Any) -> SyntheticTextDatasetConfig | None: - if not isinstance(data, dict | list): - return 
None - - try: - return SyntheticTextDatasetConfig.model_validate(data) - except ValidationError: - return None - - def _load_config_file(self, data: Any) -> SyntheticTextDatasetConfig | None: - if (not isinstance(data, str) and not isinstance(data, Path)) or ( - not Path(data).is_file() - ): - return None - - data_path = Path(data) if isinstance(data, str) else data - error = None - - if Path(data).is_file() and data_path.suffix.lower() == ".json": - try: - return SyntheticTextDatasetConfig.model_validate_json( - data_path.read_text() - ) - except Exception as err: # noqa: BLE001 - error = err - - if Path(data).is_file() and data_path.suffix.lower() in { - ".yaml", - ".yml", - ".config", - }: - try: - return SyntheticTextDatasetConfig.model_validate( - yaml.safe_load(data_path.read_text()) - ) - except Exception as err: # noqa: BLE001 - error = err - - err_message = ( - f"Unsupported file {data_path} for " - f"SyntheticTextDatasetDeserializer, expected .json, " - f".yaml, .yml, or .config" - ) - - if error is not None: - err_message += f" with error: {error}" - raise DataNotSupportedError(err_message) from error - raise DataNotSupportedError(err_message) - - def _load_config_str(self, data: str) -> SyntheticTextDatasetConfig | None: - if not isinstance(data, str): - return None - - data_str = data.strip() - error = None - - if (data_str.startswith("{") and data_str.endswith("}")) or ( - data_str.startswith("[") and data_str.endswith("]") - ): - try: - return SyntheticTextDatasetConfig.model_validate_json(data_str) - except Exception as err: # noqa: BLE001 - error = err - - if data_str.count("=") > 1: - # key=value pairs separated by commas - try: - config_dict = {} - items = data_str.split(",") - for item in items: - key, value = item.split("=") - config_dict[key.strip()] = ( - int(value.strip()) - if value.strip().isnumeric() - else value.strip() - ) - - return SyntheticTextDatasetConfig.model_validate(config_dict) - except Exception as err: # noqa: BLE001 - error = err - - err_message = ( - "Unsupported string data for SyntheticTextDatasetDeserializer, " - f"expected JSON or key-value pairs, got {data}" - ) - if error is not None: - err_message += f" with error: {error}" - raise DataNotSupportedError(err_message) from error - raise DataNotSupportedError(err_message) diff --git a/src/guidellm/data/entrypoints.py b/src/guidellm/data/entrypoints.py new file mode 100644 index 000000000..98cf215d0 --- /dev/null +++ b/src/guidellm/data/entrypoints.py @@ -0,0 +1,560 @@ +import os +from collections.abc import Callable, Iterator +from enum import Enum +from pathlib import Path +from typing import Any, cast + +from datasets import Dataset +from loguru import logger +from transformers import PreTrainedTokenizerBase + +from guidellm.data.config import load_config +from guidellm.data.deserializers import ( + DatasetDeserializerFactory, +) +from guidellm.data.preprocessors import GenerativeColumnMapper +from guidellm.data.schemas import PreprocessDatasetConfig +from guidellm.utils import IntegerRangeSampler, check_load_processor +from guidellm.utils.hf_datasets import SUPPORTED_TYPES, save_dataset_to_file + + +class PromptTooShortError(Exception): + pass + + +class ShortPromptStrategy(str, Enum): + IGNORE = "ignore" + CONCATENATE = "concatenate" + PAD = "pad" + ERROR = "error" + + +class ShortPromptStrategyHandler: + """Handler class for short prompt strategies.""" + + @staticmethod + def handle_ignore( + current_prompt: str, + min_prompt_tokens: int, + tokenizer: PreTrainedTokenizerBase, + **_kwargs, 
+ ) -> str | None: + """ + Ignores prompts that are shorter than the required minimum token length. + + :param current_prompt: The input prompt string. + :param min_prompt_tokens: Minimum required token count. + :param tokenizer: Tokenizer used to count tokens. + :return: The prompt if it meets the length, otherwise None. + """ + + if len(tokenizer.encode(current_prompt)) < min_prompt_tokens: + logger.warning("Prompt too short, ignoring") + return None + return current_prompt + + @staticmethod + def handle_concatenate( + current_prompt: str, + min_prompt_tokens: int, + dataset_iterator: Iterator[dict[str, Any]], + prompt_column: str, + tokenizer: PreTrainedTokenizerBase, + concat_delimiter: str, + **_kwargs, + ) -> str | None: + """ + Concatenates prompts until the minimum token requirement is met. + + :param current_prompt: The initial prompt. + :param min_prompt_tokens: Target minimum token length. + :param dataset_iterator: Iterator to fetch more prompts. + :param prompt_column: Column key for prompt extraction. + :param tokenizer: Tokenizer used to count tokens. + :param concat_delimiter: Delimiter to use between prompts. + :return: Concatenated prompt or None if not enough data. + """ + + tokens_len = len(tokenizer.encode(current_prompt)) + while tokens_len < min_prompt_tokens: + try: + next_row = next(dataset_iterator) + except StopIteration: + logger.warning( + "Could not concatenate enough prompts to reach minimum " + "length, ignoring" + ) + return None + current_prompt += concat_delimiter + next_row[prompt_column] + tokens_len = len(tokenizer.encode(current_prompt)) + return current_prompt + + @staticmethod + def handle_pad( + current_prompt: str, + min_prompt_tokens: int, + tokenizer: PreTrainedTokenizerBase, + pad_char: str, + pad_multiplier: int = 2, + **_kwargs, + ) -> str: + """ + Pads the prompt with a character until it reaches the minimum token length. + + :param current_prompt: The input prompt. + :param min_prompt_tokens: Desired minimum token count. + :param tokenizer: Tokenizer used to count tokens. + :param pad_char: Character used for padding. + :param pad_multiplier: Multiplier for padding character length. + :return: Padded prompt string. + """ + tokens = tokenizer.encode(current_prompt) + pad_count = 1 + prompt = current_prompt + while len(tokens) < min_prompt_tokens: + prompt += pad_char * pad_count + tokens = tokenizer.encode(prompt) + pad_count *= pad_multiplier + return prompt + + @staticmethod + def handle_error( + current_prompt: str, + min_prompt_tokens: int, + tokenizer: PreTrainedTokenizerBase, + **_kwargs, + ) -> str | None: + """ + Raises an error if the prompt is too short. + + :param current_prompt: The input prompt. + :param min_prompt_tokens: Required token count. + :param tokenizer: Tokenizer used to count tokens. + :return: The input prompt if valid. + :raises PromptTooShortError: If the prompt is too short. + """ + + prompt_len = len(tokenizer.encode(current_prompt)) + if prompt_len < min_prompt_tokens: + raise PromptTooShortError( + f"Found too short prompt: {current_prompt}, with length: {prompt_len}. " + f"Minimum length required: {min_prompt_tokens}.", + ) + return current_prompt + + @classmethod + def get_strategy_handler(cls, strategy: ShortPromptStrategy) -> Callable[..., Any]: + """ + Get the handler for a specific strategy. + + :param strategy: The short prompt strategy to get the handler for. + :return: The handler callable for the specified strategy. 
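+
+        Illustrative sketch (assumes "tokenizer" is an already-loaded
+        PreTrainedTokenizerBase; the prompt and values are placeholders):
+
+            handler = ShortPromptStrategyHandler.get_strategy_handler(
+                ShortPromptStrategy.PAD
+            )
+            padded = handler(
+                current_prompt="short prompt",
+                min_prompt_tokens=16,
+                tokenizer=tokenizer,
+                pad_char=" ",
+            )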
+ """ + return cast("Callable[..., Any]", STRATEGY_HANDLERS[strategy]) + + +# Initialize STRATEGY_HANDLERS after class definition to allow method references +STRATEGY_HANDLERS = { + ShortPromptStrategy.IGNORE: ShortPromptStrategyHandler.handle_ignore, + ShortPromptStrategy.CONCATENATE: ShortPromptStrategyHandler.handle_concatenate, + ShortPromptStrategy.PAD: ShortPromptStrategyHandler.handle_pad, + ShortPromptStrategy.ERROR: ShortPromptStrategyHandler.handle_error, +} + + +def _validate_output_suffix(output_path: str | Path) -> None: + output_path = Path(output_path) + suffix = output_path.suffix.lower() + if suffix not in SUPPORTED_TYPES: + raise ValueError( + f"Unsupported file suffix '{suffix}' in output_path '{output_path}'. " + f"Only {SUPPORTED_TYPES} are supported." + ) + + +def parse_synthetic_config( + config_input: str | Path, +) -> PreprocessDatasetConfig: + """ + Parse PreprocessDatasetConfig from string or file path. + + Reuses SyntheticTextDatasetDeserializer's parsing logic to support: + - JSON strings + - Key=value pairs + - File paths (.json, .yaml, .yml, .config) + + :param config_input: String or path to config. + :return: Parsed PreprocessDatasetConfig instance. + :raises ValueError: If the format is not recognized or parsing fails. + """ + config = load_config(config_input, PreprocessDatasetConfig) + + if config is not None: + return config + + raise ValueError( + f"Could not parse config from input: {config_input}. " + "Expected JSON string, key=value pairs, or file path " + "(.json, .yaml, .yml, .config)" + ) + + +def process_dataset( + data: str | Path, + output_path: str | Path, + processor: str | Path | PreTrainedTokenizerBase, + config: str | Path, + processor_args: dict[str, Any] | None = None, + data_args: dict[str, Any] | None = None, + data_column_mapper: dict[str, str] | None = None, + short_prompt_strategy: ShortPromptStrategy = ShortPromptStrategy.IGNORE, + pad_char: str | None = None, + concat_delimiter: str | None = None, + include_prefix_in_token_count: bool = False, + push_to_hub: bool = False, + hub_dataset_id: str | None = None, + random_seed: int = 42, +) -> None: + """ + Main method to process and save a dataset with sampled prompt/output token counts. + + :param data: Path or identifier for dataset input. + :param output_path: File path to save the processed dataset. + :param processor: Tokenizer object or its config. + :param config: PreprocessDatasetConfig string or file path. + :param processor_args: Optional processor arguments. + :param data_args: Optional data loading arguments. + :param data_column_mapper: Optional column mapping dictionary. + :param short_prompt_strategy: Strategy for handling short prompts. + :param pad_char: Character used when padding short prompts. + :param concat_delimiter: Delimiter for concatenation strategy. + :param include_prefix_in_token_count: + Whether to include prefix in prompt token count, simplifying the token counts. + When True, prefix trimming is disabled and the prefix is kept as-is. The prefix + token count is subtracted from the prompt token budget instead. + :param push_to_hub: Whether to push to Hugging Face Hub. + :param hub_dataset_id: Dataset ID on Hugging Face Hub. + :param random_seed: Seed for random sampling. + :raises ValueError: If the output path is invalid or pushing conditions unmet. 
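+
+    Illustrative call (paths, processor, and token counts are placeholder
+    values, not defaults):
+
+        process_dataset(
+            data="dataset.jsonl",
+            output_path="processed.jsonl",
+            processor="gpt2",
+            config="prompt_tokens=512,output_tokens=256",
+        )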
+ """ + _validate_output_suffix(output_path) + logger.info( + f"Starting dataset conversion | Input: {data} | Output: {output_path}" + ) + + # Parse config + config_obj = parse_synthetic_config(config) + + # Load tokenizer + tokenizer = check_load_processor( + processor, + processor_args, + "dataset conversion.", + ) + + # Load dataset + dataset = DatasetDeserializerFactory.deserialize( + data=data, + processor_factory=lambda: tokenizer, + random_seed=random_seed, + **(data_args or {}), + ) + + # Setup column mapper + column_mapper = GenerativeColumnMapper( + column_mappings=data_column_mapper # type: ignore[arg-type] + ) + column_mapper.setup_data( + datasets=[dataset], + data_args=[data_args or {}], + ) + + # Extract column names from mapper + prompt_column, prefix_column, output_column = _extract_column_names(column_mapper) + + # Create token samplers + prompt_token_sampler, output_token_sampler, prefix_tokens_max = ( + _create_token_samplers( + config_obj, + random_seed, + ) + ) + + # Process dataset + dataset_iterator = iter(dataset) + processed_prompts = [] + prompt_handler = ShortPromptStrategyHandler.get_strategy_handler( + short_prompt_strategy + ) + + for row in dataset_iterator: + processed_row = _process_single_row( + row=row, + prompt_column=prompt_column, + prefix_column=prefix_column, + prompt_token_sampler=prompt_token_sampler, + output_token_sampler=output_token_sampler, + tokenizer=tokenizer, + prompt_handler=prompt_handler, + dataset_iterator=dataset_iterator, + include_prefix_in_token_count=include_prefix_in_token_count, + pad_char=pad_char, + concat_delimiter=concat_delimiter, + output_column=output_column, + prefix_tokens_max=prefix_tokens_max, + ) + if processed_row is not None: + processed_prompts.append(processed_row) + + # Finalize + _finalize_processed_dataset( + processed_prompts, + output_path, + push_to_hub, + hub_dataset_id, + ) + + +def _extract_column_names( + column_mapper: GenerativeColumnMapper, +) -> tuple[str, str | None, str]: + """ + Extract column names for prompt, prefix, and output from column mapper. + + :param column_mapper: Initialized column mapper. + :return: Tuple of (prompt_column, prefix_column, output_column). + :raises ValueError: If column mapper is not properly initialized. + """ + if column_mapper.datasets_column_mappings is None: + raise ValueError("Column mapper not properly initialized") + + text_mappings = column_mapper.datasets_column_mappings.get("text_column", []) + if not text_mappings: + raise ValueError("Could not find text column in dataset") + prompt_column = text_mappings[0][1] + + prefix_mappings = column_mapper.datasets_column_mappings.get("prefix_column", []) + prefix_column = prefix_mappings[0][1] if prefix_mappings else None + + output_mappings = column_mapper.datasets_column_mappings.get( + "output_tokens_count_column", [] + ) + output_column = ( + output_mappings[0][1] if output_mappings else "output_tokens_count" + ) + + return prompt_column, prefix_column, output_column + + +def _create_token_samplers( + config_obj: PreprocessDatasetConfig, + random_seed: int, +) -> tuple[Iterator[int], Iterator[int], int | None]: + """ + Create token samplers for prompt, output, and prefix tokens. + + :param config_obj: Configuration object with token settings. + :param prefix_tokens: Optional single prefix token count. + :param random_seed: Seed for random sampling. + :return: Tuple of (prompt_sampler, output_sampler, prefix_tokens_max). + prefix_sampler is None when prefix_tokens is not provided. 
+ prefix_tokens_max is the maximum prefix token limit from config. + """ + prompt_token_sampler = iter( + IntegerRangeSampler( + average=config_obj.prompt_tokens, + variance=config_obj.prompt_tokens_stdev, + min_value=config_obj.prompt_tokens_min, + max_value=config_obj.prompt_tokens_max, + random_seed=random_seed, + ) + ) + + output_token_sampler = iter( + IntegerRangeSampler( + average=config_obj.output_tokens, + variance=config_obj.output_tokens_stdev, + min_value=config_obj.output_tokens_min, + max_value=config_obj.output_tokens_max, + random_seed=random_seed, + ) + ) + + return prompt_token_sampler, output_token_sampler, config_obj.prefix_tokens_max + + +def _process_dataset_row( + row: dict[str, Any], + prompt_column: str, + prefix_column: str | None, + output_column: str, + target_output_len: int, + prompt_text: str, + prefix_text: str | None, + tokens: list[int], +) -> dict[str, Any]: + """ + Create a processed row from the processed prompt/prefix data. + + :param row: Original dataset row. + :param prompt_column: Name of prompt column. + :param prefix_column: Name of prefix column or None. + :param output_column: Name of output tokens count column. + :param target_prompt_len: Target prompt token length. + :param target_output_len: Target output token length. + :param prompt_text: Processed prompt text. + :param prefix_text: Processed prefix text or None. + :param tokens: Tokenized prompt. + :return: Processed row dictionary. + """ + processed_row = row.copy() + processed_row[prompt_column] = prompt_text + if prefix_column and prefix_text: + processed_row[prefix_column] = prefix_text + processed_row["prompt_tokens_count"] = len(tokens) + processed_row[output_column] = target_output_len + return processed_row + + +def _process_single_row( + row: dict[str, Any], + prompt_column: str, + prefix_column: str | None, + prompt_token_sampler: Iterator[int], + output_token_sampler: Iterator[int], + tokenizer: PreTrainedTokenizerBase, + prompt_handler: Callable, + dataset_iterator: Iterator[dict[str, Any]], + include_prefix_in_token_count: bool, + pad_char: str | None, + concat_delimiter: str | None, + output_column: str, + prefix_tokens_max: int | None, +) -> dict[str, Any] | None: + """ + Process a single row from the dataset. + + :param include_prefix_in_token_count: When True, includes prefix tokens in the + prompt token count calculation. When False, prefix tokens are not counted + toward prompt tokens. + :param prefix_tokens_max: Maximum prefix token limit. If set, the prefix will be + trimmed if it exceeds this limit. + :return: Processed row dictionary or None if row should be skipped. 
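+
+    Summary of the steps implemented below: trim the prefix to prefix_tokens_max
+    when it is set, optionally subtract the prefix token count from the sampled
+    prompt budget, apply the configured short-prompt strategy, trim prompts that
+    exceed the target length, then sample the output token count.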
+ """ + # Extract prompt and prefix + prompt_text = row.get(prompt_column, "") + prefix_text = row.get(prefix_column) if prefix_column else None + + # Sample target prompt token count + target_prompt_len = next(prompt_token_sampler) + count_adjustment = 0 + + # Handle prefix + if prefix_text: + # Apply prefix_tokens_max limit if set (strict maximum) + if prefix_tokens_max is not None: + prefix_tokens_list = tokenizer.encode(prefix_text) + if len(prefix_tokens_list) > prefix_tokens_max: + prefix_text = tokenizer.decode( + prefix_tokens_list[:prefix_tokens_max] + ) + + # Count prefix tokens toward prompt if enabled + if include_prefix_in_token_count: + count_adjustment = len(tokenizer.encode(prefix_text)) + + if target_prompt_len == 0: + logger.warning("zero prompt size requested; skipping row") + return None + elif count_adjustment > 0: + adjusted_prompt_len = target_prompt_len - count_adjustment + if adjusted_prompt_len <= 0: + logger.warning("The prefix exceeds target output length with " + "--include-prefix-in-token-count enabled; Using prompt size" + "of 1; skipping row") + return None + target_prompt_len = adjusted_prompt_len + + # Handle short prompts + prompt_text = prompt_handler( + current_prompt=prompt_text, + min_prompt_tokens=target_prompt_len, + dataset_iterator=dataset_iterator, + prompt_column=prompt_column, + tokenizer=tokenizer, + pad_char=pad_char, + concat_delimiter=concat_delimiter, + ) + if prompt_text is None: + return None + + # Trim long prompts + tokens = tokenizer.encode(prompt_text) + if len(tokens) > target_prompt_len: + prompt_text = tokenizer.decode(tokens[:target_prompt_len]) + tokens = tokenizer.encode(prompt_text) + + # Sample output token count + target_output_len = next(output_token_sampler) + + # Create processed row + return _process_dataset_row( + row=row, + prompt_column=prompt_column, + prefix_column=prefix_column, + output_column=output_column, + target_output_len=target_output_len, + prompt_text=prompt_text, + prefix_text=prefix_text, + tokens=tokens, + ) + + +def _finalize_processed_dataset( + processed_prompts: list[dict[str, Any]], + output_path: str | Path, + push_to_hub: bool, + hub_dataset_id: str | None, +) -> None: + """ + Finalize the processed dataset by saving and optionally pushing to hub. + + :param processed_prompts: List of processed row dictionaries. + :param output_path: Path to save the dataset. + :param push_to_hub: Whether to push to Hugging Face Hub. + :param hub_dataset_id: Dataset ID on Hugging Face Hub. + """ + if not processed_prompts: + logger.error("No prompts remained after processing") + return + + logger.info(f"Generated processed dataset with {len(processed_prompts)} prompts") + + processed_dataset = Dataset.from_list(processed_prompts) + save_dataset_to_file(processed_dataset, output_path) + logger.info(f"Conversion completed. Dataset saved to: {output_path}") + + if push_to_hub: + push_dataset_to_hub(hub_dataset_id, processed_dataset) + logger.info(f"Pushed dataset to: {hub_dataset_id}") + + +def push_dataset_to_hub( + hub_dataset_id: str | None, + processed_dataset: Dataset, +) -> None: + """ + Pushes the processed dataset to Hugging Face Hub using HF_TOKEN. + + :param hub_dataset_id: Identifier on the Hub to push to. + :param processed_dataset: HuggingFace Dataset object. + :raises ValueError: If hub_dataset_id or HF_TOKEN is not available. 
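+
+    Illustrative usage (the dataset id is a placeholder; HF_TOKEN must be set
+    in the environment):
+
+        push_dataset_to_hub("username/processed-dataset", processed_dataset)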
+ """ + + hf_token = os.environ.get("HF_TOKEN") + if not hub_dataset_id or not hf_token: + raise ValueError( + "hub_dataset_id and HF_TOKEN env var must be provided when push_to_hub" + " is True" + ) + processed_dataset.push_to_hub(hub_dataset_id, token=hf_token) diff --git a/src/guidellm/data/schemas.py b/src/guidellm/data/schemas.py index c4421e073..16af56dff 100644 --- a/src/guidellm/data/schemas.py +++ b/src/guidellm/data/schemas.py @@ -1,6 +1,19 @@ +from __future__ import annotations + from typing import Literal -__all__ = ["GenerativeDatasetColumnType"] +from pydantic import ConfigDict, Field, model_validator + +from guidellm.schemas import StandardBaseModel + +__all__ = [ + "DataConfig", + "DataNotSupportedError", + "GenerativeDatasetColumnType", + "SyntheticTextDatasetConfig", + "SyntheticTextPrefixBucketConfig", +] + GenerativeDatasetColumnType = Literal[ "prompt_tokens_count_column", @@ -11,3 +24,152 @@ "video_column", "audio_column", ] + +class DataNotSupportedError(Exception): + """ + Exception raised when the data format is not supported by deserializer or config. + """ + +class DataConfig(StandardBaseModel): + """ + A generic parent class for various configs for the data package + that can be passed in as key-value pairs or JSON. + """ + +class PreprocessDatasetConfig(DataConfig): + + prompt_tokens: int = Field( + description="The average number of text tokens retained or added to prompts.", + gt=0, + ) + prompt_tokens_stdev: int | None = Field( + description="The standard deviation of the number of tokens retained in or " + "added to prompts.", + gt=0, + default=None, + ) + prompt_tokens_min: int | None = Field( + description="The minimum number of text tokens retained or added to prompts.", + gt=0, + default=None, + ) + prompt_tokens_max: int | None = Field( + description="The maximum number of text tokens retained or added to prompts.", + gt=0, + default=None, + ) + output_tokens: int = Field( + description="The average number of text tokens retained or added to outputs.", + gt=0, + ) + output_tokens_stdev: int | None = Field( + description="The standard deviation of the number of tokens retained or " + "added to outputs.", + gt=0, + default=None, + ) + output_tokens_min: int | None = Field( + description="The minimum number of text tokens retained or added to outputs.", + gt=0, + default=None, + ) + output_tokens_max: int | None = Field( + description="The maximum number of text tokens retained or added to outputs.", + gt=0, + default=None, + ) + prefix_tokens_max: int | None = Field( + description="The maximum number of text tokens left in the prefixes.", + gt=0, + default=None, + ) + +class SyntheticTextPrefixBucketConfig(StandardBaseModel): + bucket_weight: int = Field( + description="Weight of this bucket in the overall distribution.", + gt=0, + default=100, + ) + prefix_count: int = Field( + description="The number of unique prefixes to generate for this bucket.", + ge=1, + default=1, + ) + prefix_tokens: int = Field( + description="The number of prefix tokens per-prompt for this bucket.", + ge=0, + default=0, + ) + + +class SyntheticTextDatasetConfig(DataConfig): + prompt_tokens: int = Field( + description="The average number of text tokens generated for prompts.", + gt=0, + ) + prompt_tokens_stdev: int | None = Field( + description="The standard deviation of the tokens generated for prompts.", + gt=0, + default=None, + ) + prompt_tokens_min: int | None = Field( + description="The minimum number of text tokens generated for prompts.", + gt=0, + default=None, + ) + 
prompt_tokens_max: int | None = Field( + description="The maximum number of text tokens generated for prompts.", + gt=0, + default=None, + ) + output_tokens: int = Field( + description="The average number of text tokens generated for outputs.", + gt=0, + ) + output_tokens_stdev: int | None = Field( + description="The standard deviation of the tokens generated for outputs.", + gt=0, + default=None, + ) + output_tokens_min: int | None = Field( + description="The minimum number of text tokens generated for outputs.", + gt=0, + default=None, + ) + output_tokens_max: int | None = Field( + description="The maximum number of text tokens generated for outputs.", + gt=0, + default=None, + ) + + model_config = ConfigDict( + extra="allow", + ) + + prefix_buckets: list[SyntheticTextPrefixBucketConfig] | None = Field( + description="Buckets for the prefix tokens distribution.", + default=None, + ) + + + @model_validator(mode="after") + def check_prefix_options(self) -> SyntheticTextDatasetConfig: + if self.__pydantic_extra__ is not None: + prefix_count = self.__pydantic_extra__.get("prefix_count", None) # type: ignore[attr-defined] + prefix_tokens = self.__pydantic_extra__.get("prefix_tokens", None) # type: ignore[attr-defined] + + if prefix_count is not None or prefix_tokens is not None: + if self.prefix_buckets: + raise ValueError( + "prefix_buckets is mutually exclusive" + " with prefix_count and prefix_tokens" + ) + + self.prefix_buckets = [ + SyntheticTextPrefixBucketConfig( + prefix_count=prefix_count or 1, + prefix_tokens=prefix_tokens or 0, + ) + ] + + return self diff --git a/src/guidellm/preprocess/__init__.py b/src/guidellm/preprocess/__init__.py deleted file mode 100644 index 95d01e5f3..000000000 --- a/src/guidellm/preprocess/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .dataset import ShortPromptStrategy, process_dataset - -__all__ = ["ShortPromptStrategy", "process_dataset"] diff --git a/src/guidellm/preprocess/dataset.py b/src/guidellm/preprocess/dataset.py deleted file mode 100644 index 033bf106b..000000000 --- a/src/guidellm/preprocess/dataset.py +++ /dev/null @@ -1,371 +0,0 @@ -import json -import os -from collections.abc import Callable, Iterator -from enum import Enum -from pathlib import Path -from typing import Any - -import yaml -from datasets import Dataset -from loguru import logger -from pydantic import BaseModel, Field -from transformers import PreTrainedTokenizerBase - -from guidellm.utils import IntegerRangeSampler, check_load_processor -from guidellm.utils.hf_datasets import SUPPORTED_TYPES, save_dataset_to_file - - -class PromptTooShortError(Exception): - pass - - -class ShortPromptStrategy(str, Enum): - IGNORE = "ignore" - CONCATENATE = "concatenate" - PAD = "pad" - ERROR = "error" - - -def handle_ignore_strategy( - current_prompt: str, - min_prompt_tokens: int, - tokenizer: PreTrainedTokenizerBase, - **_kwargs, -) -> str | None: - """ - Ignores prompts that are shorter than the required minimum token length. - - :param current_prompt: The input prompt string. - :param min_prompt_tokens: Minimum required token count. - :param tokenizer: Tokenizer used to count tokens. - :return: The prompt if it meets the length, otherwise None. 
- """ - - if len(tokenizer.encode(current_prompt)) < min_prompt_tokens: - logger.warning("Prompt too short, ignoring") - return None - return current_prompt - - -def handle_concatenate_strategy( - current_prompt: str, - min_prompt_tokens: int, - dataset_iterator: Iterator[dict[str, Any]], - prompt_column: str, - tokenizer: PreTrainedTokenizerBase, - concat_delimiter: str, - **_kwargs, -) -> str | None: - """ - Concatenates prompts until the minimum token requirement is met. - - :param current_prompt: The initial prompt. - :param min_prompt_tokens: Target minimum token length. - :param dataset_iterator: Iterator to fetch more prompts. - :param prompt_column: Column key for prompt extraction. - :param tokenizer: Tokenizer used to count tokens. - :param concat_delimiter: Delimiter to use between prompts. - :return: Concatenated prompt or None if not enough data. - """ - - tokens_len = len(tokenizer.encode(current_prompt)) - while tokens_len < min_prompt_tokens: - try: - next_row = next(dataset_iterator) - except StopIteration: - logger.warning( - "Could not concatenate enough prompts to reach minimum length, ignoring" - ) - return None - current_prompt += concat_delimiter + next_row[prompt_column] - tokens_len = len(tokenizer.encode(current_prompt)) - return current_prompt - - -def handle_pad_strategy( - current_prompt: str, - min_prompt_tokens: int, - tokenizer: PreTrainedTokenizerBase, - pad_char: str, - pad_multiplier: int = 2, - **_kwargs, -) -> str: - """ - Pads the prompt with a character until it reaches the minimum token length. - - :param current_prompt: The input prompt. - :param min_prompt_tokens: Desired minimum token count. - :param tokenizer: Tokenizer used to count tokens. - :param pad_char: Character used for padding. - :param pad_multiplier: Multiplier for padding character length. - :return: Padded prompt string. - """ - - tokens = tokenizer.encode(current_prompt) - pad_count = 1 - prompt = current_prompt - while len(tokens) < min_prompt_tokens: - prompt += pad_char * pad_count - tokens = tokenizer.encode(prompt) - pad_count *= pad_multiplier - return prompt - - -def handle_error_strategy( - current_prompt: str, - min_prompt_tokens: int, - tokenizer: PreTrainedTokenizerBase, - **_kwargs, -) -> str | None: - """ - Raises an error if the prompt is too short. - - :param current_prompt: The input prompt. - :param min_prompt_tokens: Required token count. - :param tokenizer: Tokenizer used to count tokens. - :return: The input prompt if valid. - :raises PromptTooShortError: If the prompt is too short. - """ - - prompt_len = len(tokenizer.encode(current_prompt)) - if prompt_len < min_prompt_tokens: - raise PromptTooShortError( - f"Found too short prompt: {current_prompt}, with length: {prompt_len}. 
" - f"Minimum length required: {min_prompt_tokens}.", - ) - return current_prompt - - -STRATEGY_HANDLERS: dict[ShortPromptStrategy, Callable] = { - ShortPromptStrategy.IGNORE: handle_ignore_strategy, - ShortPromptStrategy.CONCATENATE: handle_concatenate_strategy, - ShortPromptStrategy.PAD: handle_pad_strategy, - ShortPromptStrategy.ERROR: handle_error_strategy, -} - - -class TokensConfig(BaseModel): - average: int = Field( - description="The average number of tokens.", - gt=0, - ) - stdev: int | None = Field( - description="The standard deviation of the tokens.", - gt=0, - default=None, - ) - min: int | None = Field( - description="The minimum number of tokens.", - gt=0, - default=None, - ) - max: int | None = Field( - description="The maximum number of tokens.", - gt=0, - default=None, - ) - - @staticmethod - def parse_str(data: str | Path) -> "TokensConfig": - """ - Parses a string or path into a TokensConfig object. Supports: - - JSON string - - key=value pairs - - file path to .yaml/.config - - :param data: String or path containing configuration. - :return: Parsed TokensConfig instance. - :raises ValueError: If the format is not recognized. - """ - - if ( - isinstance(data, Path) - or data.strip().endswith(".config") - or data.strip().endswith(".yaml") - ): - return TokensConfig.parse_config_file(data) - - if data.strip().startswith("{"): - return TokensConfig.parse_json(data) - - if data.count("=") > 1: - return TokensConfig.parse_key_value_pairs(data) - - raise ValueError( - f"Unsupported data format. Expected JSON or key-value pairs, got {data}" - ) - - @staticmethod - def parse_json(data: str) -> "TokensConfig": - config_dict = json.loads(data.strip()) - - return TokensConfig(**config_dict) - - @staticmethod - def parse_key_value_pairs(data: str) -> "TokensConfig": - config_dict = {} - items = data.strip().split(",") - for item in items: - key, value = item.split("=") - config_dict[key.strip()] = ( - int(value.strip()) if value.strip().isnumeric() else value.strip() - ) - - return TokensConfig(**config_dict) # type: ignore[arg-type] - - @staticmethod - def parse_config_file(data: str | Path) -> "TokensConfig": - with Path(data).open("r") as file: - config_dict = yaml.safe_load(file) - - return TokensConfig(**config_dict) - - -def _validate_output_suffix(output_path: str | Path) -> None: - output_path = Path(output_path) - suffix = output_path.suffix.lower() - if suffix not in SUPPORTED_TYPES: - raise ValueError( - f"Unsupported file suffix '{suffix}' in output_path '{output_path}'. " - f"Only {SUPPORTED_TYPES} are supported." - ) - - -def process_dataset( - data: str | Path, - output_path: str | Path, - processor: str | Path | PreTrainedTokenizerBase, - prompt_tokens: str | Path, - output_tokens: str | Path, - processor_args: dict[str, Any] | None = None, - data_args: dict[str, Any] | None = None, # noqa: ARG001 - short_prompt_strategy: ShortPromptStrategy = ShortPromptStrategy.IGNORE, - pad_char: str | None = None, - concat_delimiter: str | None = None, - push_to_hub: bool = False, - hub_dataset_id: str | None = None, - random_seed: int = 42, -) -> None: - """ - Main method to process and save a dataset with sampled prompt/output token counts. - - :param data: Path or identifier for dataset input. - :param output_path: File path to save the processed dataset. - :param processor: Tokenizer object or its config. - :param prompt_tokens: Prompt token config string or file. - :param output_tokens: Output token config string or file. 
- :param processor_args: Optional processor arguments. - :param data_args: Optional data loading arguments. - :param short_prompt_strategy: Strategy for handling short prompts. - :param pad_char: Character used when padding short prompts. - :param concat_delimiter: Delimiter for concatenation strategy. - :param push_to_hub: Whether to push to Hugging Face Hub. - :param hub_dataset_id: Dataset ID on Hugging Face Hub. - :param random_seed: Seed for random sampling. - :raises ValueError: If output path is invalid or pushing conditions unmet. - """ - - _validate_output_suffix(output_path) - logger.info( - f"Starting dataset conversion | Input: {data} | Output directory: {output_path}" - ) - - dataset, column_mappings = None, None - tokenizer = check_load_processor( - processor, - processor_args, - "dataset conversion.", - ) - prompt_column = column_mappings.get("prompt_column") # type: ignore[attr-defined] - output_column = column_mappings.get( # type: ignore[attr-defined] - "output_tokens_count_column", "output_tokens_count" - ) - - prompt_tokens_cfg = TokensConfig.parse_str(prompt_tokens) - output_tokens_cfg = TokensConfig.parse_str(output_tokens) - - prompt_token_sampler = iter( - IntegerRangeSampler( - average=prompt_tokens_cfg.average, - variance=prompt_tokens_cfg.stdev, - min_value=prompt_tokens_cfg.min, - max_value=prompt_tokens_cfg.max, - random_seed=random_seed, - ) - ) - - output_token_sampler = iter( - IntegerRangeSampler( - average=output_tokens_cfg.average, - variance=output_tokens_cfg.stdev, - min_value=output_tokens_cfg.min, - max_value=output_tokens_cfg.max, - random_seed=random_seed, - ) - ) - - dataset_iterator = iter(dataset) # type: ignore[call-overload] - processed_prompts = [] - prompt_handler = STRATEGY_HANDLERS[short_prompt_strategy] - - for prompt_row in dataset_iterator: - prompt_text = prompt_row[prompt_column] - target_prompt_len = next(prompt_token_sampler) - - prompt_text = prompt_handler( - current_prompt=prompt_text, - min_prompt_tokens=target_prompt_len, - dataset_iterator=dataset_iterator, - prompt_column=prompt_column, - tokenizer=tokenizer, - pad_char=pad_char, - concat_delimiter=concat_delimiter, - ) - if prompt_text is None: - continue - - tokens = tokenizer.encode(prompt_text) - if len(tokens) > target_prompt_len: - prompt_text = tokenizer.decode(tokens[:target_prompt_len]) - - processed_prompt = prompt_row.copy() - processed_prompt[prompt_column] = prompt_text - processed_prompt["prompt_tokens_count"] = target_prompt_len - processed_prompt[output_column] = next(output_token_sampler) - - processed_prompts.append(processed_prompt) - - if not processed_prompts: - logger.error("No prompts remained after processing") - return - - logger.info(f"Generated processed dataset with {len(processed_prompts)} prompts") - - processed_dataset = Dataset.from_list(processed_prompts) - save_dataset_to_file(processed_dataset, output_path) - logger.info(f"Conversion completed. Dataset saved to: {output_path}") - - if push_to_hub: - push_dataset_to_hub(hub_dataset_id, processed_dataset) - logger.info(f"Pushed dataset to: {hub_dataset_id}") - - -def push_dataset_to_hub( - hub_dataset_id: str | None, - processed_dataset: Dataset, -) -> None: - """ - Pushes the processed dataset to Hugging Face Hub using HF_TOKEN. - - :param hub_dataset_id: Identifier on the Hub to push to. - :param processed_dataset: HuggingFace Dataset object. - :raises ValueError: If hub_dataset_id or HF_TOKEN is not available. 
- """ - - hf_token = os.environ.get("HF_TOKEN") - if not hub_dataset_id or not hf_token: - raise ValueError( - "hub_dataset_id and HF_TOKEN env var must be provided when push_to_hub" - " is True" - ) - processed_dataset.push_to_hub(hub_dataset_id, token=hf_token) diff --git a/tests/unit/data/deserializers/test_synthetic.py b/tests/unit/data/deserializers/test_synthetic.py index 468c4c8e9..eda02ef58 100644 --- a/tests/unit/data/deserializers/test_synthetic.py +++ b/tests/unit/data/deserializers/test_synthetic.py @@ -11,11 +11,14 @@ import yaml from datasets import IterableDataset -from guidellm.data.deserializers.deserializer import DataNotSupportedError +from guidellm.data import config as config_module from guidellm.data.deserializers.synthetic import ( SyntheticTextDataset, - SyntheticTextDatasetConfig, SyntheticTextDatasetDeserializer, +) +from guidellm.data.schemas import ( + DataNotSupportedError, + SyntheticTextDatasetConfig, SyntheticTextPrefixBucketConfig, ) @@ -409,13 +412,14 @@ def test_load_config_file_yaml(self): yaml_path = f.name try: - deserializer = SyntheticTextDatasetDeserializer() - config = deserializer._load_config_file(yaml_path) + loaded_config = config_module._load_config_file( + yaml_path, SyntheticTextDatasetConfig, + ) - assert config.prompt_tokens == 60 - assert config.output_tokens == 15 - assert config.source == "yaml_test.txt" - assert config.prefix_buckets[0].prefix_tokens == 3 # type: ignore [index] + assert loaded_config.prompt_tokens == 60 + assert loaded_config.output_tokens == 15 + assert loaded_config.source == "yaml_test.txt" + assert loaded_config.prefix_buckets[0].prefix_tokens == 3 # type: ignore [index] finally: Path(yaml_path).unlink() @@ -438,12 +442,13 @@ def test_load_config_file_config_extension(self): config_path = f.name try: - deserializer = SyntheticTextDatasetDeserializer() - config = deserializer._load_config_file(config_path) + loaded_config = config_module._load_config_file( + config_path, SyntheticTextDatasetConfig, + ) - assert config.prompt_tokens == 90 - assert config.output_tokens == 35 - assert config.prefix_buckets[0].prefix_tokens == 2 # type: ignore [index] + assert loaded_config.prompt_tokens == 90 + assert loaded_config.output_tokens == 35 + assert loaded_config.prefix_buckets[0].prefix_tokens == 2 # type: ignore [index] finally: Path(config_path).unlink() @@ -454,11 +459,12 @@ def test_load_config_str_json(self): ### WRITTEN BY AI ### """ json_str = '{"prompt_tokens": 50, "output_tokens": 25}' - deserializer = SyntheticTextDatasetDeserializer() - config = deserializer._load_config_str(json_str) + loaded_config = config_module._load_config_str( + json_str, SyntheticTextDatasetConfig, + ) - assert config.prompt_tokens == 50 - assert config.output_tokens == 25 + assert loaded_config.prompt_tokens == 50 + assert loaded_config.output_tokens == 25 @pytest.mark.smoke def test_load_config_str_key_value(self): @@ -467,11 +473,12 @@ def test_load_config_str_key_value(self): ### WRITTEN BY AI ### """ kv_str = "prompt_tokens=50,output_tokens=25" - deserializer = SyntheticTextDatasetDeserializer() - config = deserializer._load_config_str(kv_str) + loaded_config = config_module._load_config_str( + kv_str, SyntheticTextDatasetConfig, + ) - assert config.prompt_tokens == 50 - assert config.output_tokens == 25 + assert loaded_config.prompt_tokens == 50 + assert loaded_config.output_tokens == 25 @pytest.mark.sanity def test_load_config_str_invalid_format(self): @@ -479,9 +486,10 @@ def test_load_config_str_invalid_format(self): ### WRITTEN 
BY AI ### """ - deserializer = SyntheticTextDatasetDeserializer() with pytest.raises(DataNotSupportedError, match="Unsupported string data"): - deserializer._load_config_str("invalid_format_string") + config_module._load_config_str( + "invalid_format_string", SyntheticTextDatasetConfig, + ) @pytest.mark.regression def test_load_config_file_non_existent(self): @@ -489,9 +497,10 @@ def test_load_config_file_non_existent(self): ### WRITTEN BY AI ### """ - deserializer = SyntheticTextDatasetDeserializer() - config = deserializer._load_config_file("/non/existent/path.config") - assert config is None + loaded_config = config_module._load_config_file( + "/non/existent/path.config", SyntheticTextDatasetConfig, + ) + assert loaded_config is None @pytest.mark.regression def test_load_config_str_non_string(self): @@ -499,9 +508,8 @@ def test_load_config_str_non_string(self): ### WRITTEN BY AI ### """ - deserializer = SyntheticTextDatasetDeserializer() - config = deserializer._load_config_str(123) - assert config is None + loaded_config = config_module._load_config_str(123, SyntheticTextDatasetConfig) + assert loaded_config is None @pytest.mark.smoke def test_call_with_config_object(self, mock_tokenizer): @@ -509,11 +517,11 @@ def test_call_with_config_object(self, mock_tokenizer): ### WRITTEN BY AI ### """ - config = SyntheticTextDatasetConfig(prompt_tokens=50, output_tokens=25) + config_input = SyntheticTextDatasetConfig(prompt_tokens=50, output_tokens=25) deserializer = SyntheticTextDatasetDeserializer() result = deserializer( - data=config, + data=config_input, data_kwargs={}, processor_factory=lambda: mock_tokenizer, random_seed=42, diff --git a/tests/unit/data/test_entrypoints.py b/tests/unit/data/test_entrypoints.py new file mode 100644 index 000000000..49f35e9c7 --- /dev/null +++ b/tests/unit/data/test_entrypoints.py @@ -0,0 +1,1950 @@ +""" +Unit tests for guidellm.data.entrypoints module, specifically process_dataset function. +""" + +import json +import os +from typing import TYPE_CHECKING +from unittest.mock import MagicMock, patch + +if TYPE_CHECKING: + from collections.abc import Iterator + +import pytest +import yaml +from datasets import Dataset +from transformers import PreTrainedTokenizerBase + +from guidellm.data.entrypoints import ( + PromptTooShortError, + ShortPromptStrategy, + ShortPromptStrategyHandler, + process_dataset, + push_dataset_to_hub, +) + + +@pytest.fixture +def tokenizer_mock(): + """Fixture to provide a mocked tokenizer.""" + tokenizer = MagicMock(spec=PreTrainedTokenizerBase) + + # Simple tokenizer: each character is a token + def encode_side_effect(text): + if not text: + return [] + # Count tokens as roughly one per character for simplicity + return list(range(len(text))) + + def decode_side_effect(tokens, skip_special_tokens=False): + if not tokens: + return "" + # Simple decode: return a string representation + return "".join(chr(65 + (t % 26)) for t in tokens[:100]) + + tokenizer.encode.side_effect = encode_side_effect + tokenizer.decode.side_effect = decode_side_effect + return tokenizer + + +@pytest.fixture +def sample_dataset_default_columns(): + """Sample dataset with default column names.""" + return Dataset.from_dict({ + "prompt": [ + ( + "This is a very long prompt that should be sufficient for " + "testing purposes. " + ) * 10, + "Short.", + ( + "Another very long prompt for testing the dataset processing " + "functionality. 
" + ) * 10, + ], + }) + + +@pytest.fixture +def sample_dataset_custom_columns(): + """Sample dataset with custom column names requiring mapping.""" + return Dataset.from_dict({ + "question": [ + ( + "What is the meaning of life? This is a longer question that " + "should work for testing. " + ) * 10, + ( + "How does this work? Let me explain in detail how this system " + "functions. " + ) * 10, + ( + "Tell me about machine learning. Machine learning is a " + "fascinating field. " + ) * 10, + ], + }) + + +@pytest.fixture +def sample_dataset_with_prefix(): + """Sample dataset with prefix column.""" + return Dataset.from_dict({ + "prompt": [ + ( + "This is a long prompt that should be sufficient for testing " + "purposes. " + ) * 10, + "Another long prompt here that will work for testing. " * 10, + "Yet another long prompt for testing purposes. " * 10, + ], + "system_prompt": [ + "You are a helpful assistant.", + "You are a helpful assistant.", + "You are a helpful assistant.", + ], + }) + + +@pytest.fixture +def sample_config_json(): + """Sample config as JSON string.""" + return '{"prompt_tokens": 50, "output_tokens": 30}' + + +@pytest.fixture +def sample_config_key_value(): + """Sample config as key-value pairs.""" + return "prompt_tokens=50,output_tokens=30" + + +@pytest.fixture +def temp_output_path(tmp_path): + """Temporary file path for output.""" + return tmp_path / "output.json" + + +class TestProcessDatasetShortPromptStrategies: + """Test cases for different ShortPromptStrategy types.""" + + @pytest.mark.smoke + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_ignore_strategy( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + sample_dataset_default_columns, + sample_config_json, + temp_output_path, + ): + """ + Test IGNORE strategy filters out short prompts. + ## WRITTEN BY AI ## + """ + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + sample_dataset_default_columns + ) + + # Run process_dataset + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=sample_config_json, + short_prompt_strategy=ShortPromptStrategy.IGNORE, + ) + + # Verify save_dataset_to_file was called + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify that short prompts were filtered out + # The second prompt "Short." is only 6 characters, which is less than 50 tokens + # So it should be filtered out + assert len(saved_dataset) <= 2 # At most 2 prompts should remain + + @pytest.mark.smoke + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_concatenate_strategy( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + sample_config_json, + temp_output_path, + ): + """ + Test that the CONCATENATE strategy merges short prompts with subsequent rows. 
+ ## WRITTEN BY AI ## + """ + # Create a dataset with short prompts that can be concatenated to reach target + # Use a lower target (15 tokens) so concatenation is achievable + short_config = '{"prompt_tokens": 15, "output_tokens": 10}' + short_prompts_dataset = Dataset.from_dict({ + "prompt": [ + "A", # 1 char = 1 token + "B", # 1 char = 1 token + "C", # 1 char = 1 token + "D", # 1 char = 1 token + "E", # 1 char = 1 token + "F", # 1 char = 1 token + "G", # 1 char = 1 token + "H", # 1 char = 1 token + "I", # 1 char = 1 token + "J", # 1 char = 1 token + "K", # 1 char = 1 token + "L", # 1 char = 1 token + "M", # 1 char = 1 token + "N", # 1 char = 1 token + "O", # 1 char = 1 token + "P", # 1 char = 1 token + "Q", # 1 char = 1 token + "R", # 1 char = 1 token + "S", # 1 char = 1 token + "T", # 1 char = 1 token + ], + }) + + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = short_prompts_dataset + + # Run process_dataset with the `concatenate` strategy + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=short_config, + short_prompt_strategy=ShortPromptStrategy.CONCATENATE, + concat_delimiter="\n", + ) + + # Verify save_dataset_to_file was called + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify dataset was processed + assert len(saved_dataset) > 0 + + # Verify concatenation occurred: check for delimiter and that prompts + # meet minimum token count + concatenated_found = False + for row in saved_dataset: + prompt_text = row["prompt"] + # Check that delimiter is present (indicating concatenation) + if "\n" in prompt_text: + concatenated_found = True + # Verify that multiple single-character prompts are present + # The concatenated prompt should contain multiple letters + # separated by newlines + parts = prompt_text.split("\n") + assert len(parts) >= 2, ( + f"Concatenated prompt should contain multiple parts " + f"separated by delimiter, got: {prompt_text[:100]}..." + ) + # Verify token counts meet minimum requirements + actual_tokens = len(tokenizer_mock.encode(prompt_text)) + assert actual_tokens >= 15, ( + f"Concatenated prompt should have at least 15 tokens, " + f"got {actual_tokens}" + ) + assert row["prompt_tokens_count"] == actual_tokens + assert row["prompt_tokens_count"] >= 15 + + # Verify that at least some concatenation occurred + # (Short single-character prompts should have been concatenated with + # subsequent rows) + assert concatenated_found, ( + "Expected to find concatenated prompts with delimiter" + ) + + @pytest.mark.smoke + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_pad_strategy( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + sample_dataset_default_columns, + sample_config_json, + temp_output_path, + ): + """ + Test PAD strategy adds padding to short prompts. 
+ ## WRITTEN BY AI ## + """ + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + sample_dataset_default_columns + ) + + # Run process_dataset with pad strategy + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=sample_config_json, + short_prompt_strategy=ShortPromptStrategy.PAD, + pad_char="X", + ) + + # Verify save_dataset_to_file was called + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify dataset was processed + assert len(saved_dataset) > 0 + + # Get original prompts for comparison + sample_dataset_default_columns["prompt"] + + # Check that prompts have been padded (they should be longer) + for row in saved_dataset: + assert "prompt" in row + assert len(row["prompt"]) > 0 + + # Verify that prompts meet minimum token count requirements + actual_tokens = len(tokenizer_mock.encode(row["prompt"])) + assert actual_tokens >= 50, \ + f"Padded prompt should have at least 50 tokens, got {actual_tokens}" + assert row["prompt_tokens_count"] == actual_tokens + + # For the "Short." prompt (index 1), verify it was padded + # The original "Short." is only 6 characters, so if it was + # processed, it should have been padded to meet the minimum token + # requirement + if "Short." in row["prompt"] or len(row["prompt"]) > 10: + # If this is the short prompt, verify it was padded + assert actual_tokens >= 50, ( + f"Short prompt should have been padded to at least 50 " + f"tokens, got {actual_tokens}" + ) + + @pytest.mark.sanity + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_error_strategy( + self, + mock_check_processor, + mock_deserializer_factory_class, + tokenizer_mock, + sample_dataset_default_columns, + sample_config_json, + temp_output_path, + ): + """ + Test that the `ERROR` strategy raises PromptTooShortError for short prompts. + ## WRITTEN BY AI ## + """ + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + sample_dataset_default_columns + ) + + # Run process_dataset with error strategy - should raise exception + with pytest.raises(PromptTooShortError): + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=sample_config_json, + short_prompt_strategy=ShortPromptStrategy.ERROR, + ) + + +class TestProcessDatasetColumnNames: + """Test cases for different column name scenarios.""" + + @pytest.mark.smoke + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_default_columns( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + sample_dataset_default_columns, + sample_config_json, + temp_output_path, + ): + """ + Test process_dataset works with default column names (no mapping required). 
+ ## WRITTEN BY AI ## + """ + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + sample_dataset_default_columns + ) + + # Run process_dataset without column mapping + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=sample_config_json, + ) + + # Verify save_dataset_to_file was called + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify dataset was processed correctly + assert len(saved_dataset) > 0 + for row in saved_dataset: + assert "prompt" in row + assert "prompt_tokens_count" in row + assert "output_tokens_count" in row + + @pytest.mark.smoke + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_custom_columns_with_mapping( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + sample_dataset_custom_columns, + sample_config_json, + temp_output_path, + ): + """ + Test process_dataset works with custom column names via explicit mapping. + ## WRITTEN BY AI ## + """ + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + sample_dataset_custom_columns + ) + + # Run process_dataset with column mapping + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=sample_config_json, + data_column_mapper={"text_column": "question"}, + ) + + # Verify save_dataset_to_file was called + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify dataset was processed correctly + assert len(saved_dataset) > 0 + for row in saved_dataset: + assert "question" in row + assert "prompt_tokens_count" in row + assert "output_tokens_count" in row + + @pytest.mark.sanity + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_with_prefix_column( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + sample_dataset_with_prefix, + sample_config_json, + temp_output_path, + ): + """ + Test process_dataset handles prefix column correctly. 
+ ## WRITTEN BY AI ## + """ + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + sample_dataset_with_prefix + ) + + # Run process_dataset + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=sample_config_json, + ) + + # Verify save_dataset_to_file was called + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify dataset was processed correctly + assert len(saved_dataset) > 0 + for row in saved_dataset: + assert "prompt" in row + assert "system_prompt" in row + assert "prompt_tokens_count" in row + assert "output_tokens_count" in row + + @pytest.mark.regression + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_with_instruction_column( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + sample_config_json, + temp_output_path, + ): + """ + Test process_dataset works with 'instruction' column (default text_column). + ## WRITTEN BY AI ## + """ + # Create dataset with 'instruction' column (one of the default + # text_column names) + dataset = Dataset.from_dict({ + "instruction": [ + "Follow these instructions carefully. " * 20, + "Complete the task as described. " * 20, + ], + }) + + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = dataset + + # Run process_dataset without column mapping (should auto-detect 'instruction') + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=sample_config_json, + ) + + # Verify save_dataset_to_file was called + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify dataset was processed correctly + assert len(saved_dataset) > 0 + + +class TestProcessDatasetConfigFormats: + """Test cases for different config format inputs.""" + + @pytest.mark.smoke + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_config_json( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + sample_dataset_default_columns, + sample_config_json, + temp_output_path, + ): + """ + Test process_dataset accepts config as JSON string. 
+ ## WRITTEN BY AI ## + """ + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + sample_dataset_default_columns + ) + # Run process_dataset with JSON config + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=sample_config_json, + ) + + # Verify save_dataset_to_file was called + assert mock_save_to_file.called + + @pytest.mark.smoke + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_config_key_value( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + sample_dataset_default_columns, + sample_config_key_value, + temp_output_path, + ): + """ + Test process_dataset accepts config as key-value pairs. + ## WRITTEN BY AI ## + """ + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + sample_dataset_default_columns + ) + + # Run process_dataset with key-value config + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=sample_config_key_value, + ) + + # Verify save_dataset_to_file was called + assert mock_save_to_file.called + + @pytest.mark.sanity + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_config_file_json( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + sample_dataset_default_columns, + tmp_path, + ): + """ + Test process_dataset accepts config from JSON file. + ## WRITTEN BY AI ## + """ + # Create a temporary JSON config file + config_file = tmp_path / "config.json" + config_data = {"prompt_tokens": 50, "output_tokens": 30} + config_file.write_text(json.dumps(config_data)) + + output_path = tmp_path / "output.json" + + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + sample_dataset_default_columns + ) + + # Run process_dataset with JSON file config + process_dataset( + data="test_data", + output_path=output_path, + processor=tokenizer_mock, + config=str(config_file), + ) + + # Verify save_dataset_to_file was called + assert mock_save_to_file.called + + @pytest.mark.sanity + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_config_file_yaml( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + sample_dataset_default_columns, + sample_config_json, + tmp_path, + ): + """ + Test process_dataset accepts config from YAML file. 
+ ## WRITTEN BY AI ## + """ + # Create a temporary YAML config file + config_file = tmp_path / "config.yaml" + config_data = {"prompt_tokens": 50, "output_tokens": 30} + config_file.write_text(yaml.dump(config_data)) + + output_path = tmp_path / "output.json" + + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + sample_dataset_default_columns + ) + + # Run process_dataset with YAML file config + process_dataset( + data="test_data", + output_path=output_path, + processor=tokenizer_mock, + config=str(config_file), + ) + + # Verify save_dataset_to_file was called + assert mock_save_to_file.called + + @pytest.mark.regression + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_config_file_config_extension( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + sample_dataset_default_columns, + tmp_path, + ): + """ + Test process_dataset accepts config from .config file. + ## WRITTEN BY AI ## + """ + # Create a temporary .config file + config_file = tmp_path / "config.config" + config_data = {"prompt_tokens": 50, "output_tokens": 30} + config_file.write_text(yaml.dump(config_data)) + + output_path = tmp_path / "output.json" + + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + sample_dataset_default_columns + ) + + # Run process_dataset with .config file + process_dataset( + data="test_data", + output_path=output_path, + processor=tokenizer_mock, + config=str(config_file), + ) + + # Verify save_dataset_to_file was called + assert mock_save_to_file.called + + +class TestProcessDatasetIntegration: + """Integration tests for process_dataset function.""" + + @pytest.mark.smoke + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_successful_processing( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + sample_dataset_default_columns, + sample_config_json, + temp_output_path, + ): + """ + Test successful processing with valid dataset. 
+ ## WRITTEN BY AI ## + """ + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + sample_dataset_default_columns + ) + + # Run process_dataset + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=sample_config_json, + ) + + # Verify all expected calls were made + mock_check_processor.assert_called_once() + mock_deserializer_factory_class.deserialize.assert_called_once() + assert mock_save_to_file.called + + # Verify the saved dataset structure + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + assert len(saved_dataset) > 0 + + # Verify each row has required fields + for row in saved_dataset: + assert "prompt" in row + assert "prompt_tokens_count" in row + assert "output_tokens_count" in row + assert isinstance(row["prompt_tokens_count"], int) + assert isinstance(row["output_tokens_count"], int) + + @pytest.mark.sanity + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_empty_after_filtering( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + sample_config_json, + temp_output_path, + ): + """ + Test handling of empty dataset after filtering. + ## WRITTEN BY AI ## + """ + # Create dataset with only very short prompts that will be filtered out + dataset = Dataset.from_dict({ + # Very short prompts (1 char each, less than 50 tokens) + "prompt": ["A", "B", "C"], + }) + + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = dataset + + # Run process_dataset with IGNORE strategy + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=sample_config_json, + short_prompt_strategy=ShortPromptStrategy.IGNORE, + ) + + # Verify all expected calls were made (even though dataset is empty) + mock_check_processor.assert_called_once() + mock_deserializer_factory_class.deserialize.assert_called_once() + # When all prompts are filtered out, save_dataset_to_file is not called + # (the function returns early in _finalize_processed_dataset) + # This is expected behavior - the function handles empty datasets gracefully + assert not mock_save_to_file.called + + @pytest.mark.smoke + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_with_prefix_tokens( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + sample_dataset_with_prefix, + temp_output_path, + ): + """ + Test process_dataset handles trimming prefix tokens correctly. 
+ ## WRITTEN BY AI ## + """ + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + sample_dataset_with_prefix + ) + config = '{"prompt_tokens": 50, "output_tokens": 30, "prefix_tokens_max": 10}' + + # Run process_dataset with prefix_tokens + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=config, + ) + + # Verify save_dataset_to_file was called + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify dataset was processed + assert len(saved_dataset) > 0 + + # Verify that prefix lengths are correct (trimmed to 10 tokens) + for row in saved_dataset: + assert "system_prompt" in row + prefix_text = row["system_prompt"] + + # Verify prefix is trimmed to exactly 10 tokens + prefix_tokens = len(tokenizer_mock.encode(prefix_text)) + assert prefix_tokens == 10, ( + f"Prefix should be trimmed to 10 tokens, got {prefix_tokens} " + f"for prefix: {prefix_text[:50]}..." + ) + + # Verify prompt and output token counts are present + assert "prompt_tokens_count" in row + assert "output_tokens_count" in row + + @pytest.mark.sanity + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_include_prefix_in_token_count( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + sample_dataset_with_prefix, + sample_config_json, + temp_output_path, + ): + """Test process_dataset with include_prefix_in_token_count flag.""" + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + sample_dataset_with_prefix + ) + + # Run process_dataset with include_prefix_in_token_count + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=sample_config_json, + include_prefix_in_token_count=True, + ) + + # Verify save_dataset_to_file was called + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify dataset was processed + assert len(saved_dataset) > 0 + + # Verify that the token count accounts for the prefix + # When include_prefix_in_token_count=True, the prefix tokens are subtracted from + # the target prompt length, so prompt_tokens_count is just the prompt part, + # but the total effective tokens (prefix + prompt) should equal the target + for row in saved_dataset: + assert "system_prompt" in row + assert "prompt" in row + + prefix_text = row["system_prompt"] + prompt_text = row["prompt"] + + # Calculate token counts + prefix_tokens = len(tokenizer_mock.encode(prefix_text)) + prompt_tokens = len(tokenizer_mock.encode(prompt_text)) + stored_count = row["prompt_tokens_count"] + + # Verify stored count matches actual prompt token count + assert stored_count == prompt_tokens, ( + f"prompt_tokens_count should match actual prompt tokens. " + f"Expected {prompt_tokens}, got {stored_count}" + ) + + # Verify that the prompt was adjusted to account for prefix + # The total effective tokens (prefix + prompt) should be close to + # the target (50). 
The prompt should have been reduced by the + # prefix token count + total_effective_tokens = prefix_tokens + prompt_tokens + # Allow some variance due to sampling, but total should be around + # target + assert 40 <= total_effective_tokens <= 60, ( + f"Total effective tokens (prefix: {prefix_tokens} + prompt: " + f"{prompt_tokens} = {total_effective_tokens}) should be close " + f"to target of 50 when include_prefix_in_token_count=True" + ) + + @pytest.mark.smoke + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_process_dataset_with_different_config_values( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + sample_dataset_default_columns, + temp_output_path, + ): + """ + Test process_dataset with different config values (min, max, stdev). + ## WRITTEN BY AI ## + """ + # Create config with min, max, and stdev + config = ( + '{"prompt_tokens": 100, "prompt_tokens_min": 50, ' + '"prompt_tokens_max": 150, "prompt_tokens_stdev": 10, ' + '"output_tokens": 50, "output_tokens_min": 25, ' + '"output_tokens_max": 75, "output_tokens_stdev": 5}' + ) + + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + sample_dataset_default_columns + ) + + # Run process_dataset + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=config, + ) + + # Verify save_dataset_to_file was called + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify dataset was processed + assert len(saved_dataset) > 0 + + +@pytest.fixture +def large_dataset_for_validation(): + """Large dataset with many rows for statistical validation.""" + # Create 20 rows with long prompts to ensure they pass filtering + prompts = [ + f"This is a very long prompt number {i} for testing purposes. " * 15 + for i in range(20) + ] + return Dataset.from_dict({"prompt": prompts}) + + +class TestProcessDatasetConfigValidation: + """Test cases for validating config settings by verifying actual token counts.""" + + @pytest.mark.smoke + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_fixed_prompt_token_count( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + large_dataset_for_validation, + temp_output_path, + ): + """ + Test that fixed prompt token counts (min=max) are respected. 
+ ## WRITTEN BY AI ## + """ + # Config with fixed prompt tokens (min=max=100) + config = ( + '{"prompt_tokens": 100, "prompt_tokens_min": 100, ' + '"prompt_tokens_max": 100, "output_tokens": 50}' + ) + + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + large_dataset_for_validation + ) + + # Run process_dataset + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=config, + ) + + # Extract saved dataset + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify all prompts have exactly 100 tokens + for row in saved_dataset: + assert row["prompt_tokens_count"] == 100 + # Verify actual tokenized length matches + actual_tokens = len(tokenizer_mock.encode(row["prompt"])) + assert actual_tokens == 100, f"Expected 100 tokens, got {actual_tokens}" + + @pytest.mark.smoke + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_fixed_output_token_count( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + large_dataset_for_validation, + temp_output_path, + ): + """ + Test that fixed output token counts (min=max) are respected. + ## WRITTEN BY AI ## + """ + # Config with fixed output tokens (min=max=75) + config = ( + '{"prompt_tokens": 100, "output_tokens": 75, ' + '"output_tokens_min": 75, "output_tokens_max": 75}' + ) + + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + large_dataset_for_validation + ) + + # Run process_dataset + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=config, + ) + + # Extract saved dataset + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify all outputs have exactly 75 tokens + for row in saved_dataset: + assert row["output_tokens_count"] == 75 + + @pytest.mark.sanity + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_prompt_min_max_constraints( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + large_dataset_for_validation, + temp_output_path, + ): + """ + Test that prompt token counts respect min/max constraints. 
+ ## WRITTEN BY AI ## + """ + # Config with prompt min=80, max=120 + config = ( + '{"prompt_tokens": 100, "prompt_tokens_min": 80, ' + '"prompt_tokens_max": 120, "output_tokens": 50}' + ) + + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + large_dataset_for_validation + ) + + # Run process_dataset + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=config, + ) + + # Extract saved dataset + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify all prompt token counts are within bounds + prompt_counts = [row["prompt_tokens_count"] for row in saved_dataset] + assert len(prompt_counts) > 0 + assert min(prompt_counts) >= 80, ( + f"Found prompt count {min(prompt_counts)} below min 80" + ) + assert max(prompt_counts) <= 120, ( + f"Found prompt count {max(prompt_counts)} above max 120" + ) + + # Verify actual tokenized lengths match stored counts + for row in saved_dataset: + actual_tokens = len(tokenizer_mock.encode(row["prompt"])) + assert actual_tokens == row["prompt_tokens_count"], ( + f"Stored count {row['prompt_tokens_count']} doesn't match " + f"actual {actual_tokens}" + ) + + @pytest.mark.sanity + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_output_min_max_constraints( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + large_dataset_for_validation, + temp_output_path, + ): + """ + Test that output token counts respect min/max constraints. + ## WRITTEN BY AI ## + """ + # Config with output min=40, max=60 + config = ( + '{"prompt_tokens": 100, "output_tokens": 50, ' + '"output_tokens_min": 40, "output_tokens_max": 60}' + ) + + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + large_dataset_for_validation + ) + + # Run process_dataset + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=config, + ) + + # Extract saved dataset + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify all output token counts are within bounds + output_counts = [row["output_tokens_count"] for row in saved_dataset] + assert len(output_counts) > 0 + assert min(output_counts) >= 40, ( + f"Found output count {min(output_counts)} below min 40" + ) + assert max(output_counts) <= 60, ( + f"Found output count {max(output_counts)} above max 60" + ) + + @pytest.mark.regression + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_prompt_stdev_distribution( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + large_dataset_for_validation, + temp_output_path, + ): + """ + Test that prompt token counts follow expected distribution with stdev. 
+ ## WRITTEN BY AI ## + """ + # Config with prompt average=100, stdev=10, min=70, max=130 + config = ( + '{"prompt_tokens": 100, "prompt_tokens_stdev": 10, ' + '"prompt_tokens_min": 70, "prompt_tokens_max": 130, ' + '"output_tokens": 50}' + ) + + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + large_dataset_for_validation + ) + + # Run process_dataset + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=config, + random_seed=42, # Fixed seed for reproducibility + ) + + # Extract saved dataset + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify distribution properties + prompt_counts = [row["prompt_tokens_count"] for row in saved_dataset] + assert len(prompt_counts) > 0 + + # Check bounds + assert min(prompt_counts) >= 70 + assert max(prompt_counts) <= 130 + + # Check mean is close to average (within 2 stdev of the mean of means) + mean_count = sum(prompt_counts) / len(prompt_counts) + # With enough samples, mean should be close to 100 + assert 90 <= mean_count <= 110, f"Mean {mean_count} not close to expected 100" + + # Verify actual tokenized lengths match stored counts + for row in saved_dataset: + actual_tokens = len(tokenizer_mock.encode(row["prompt"])) + assert actual_tokens == row["prompt_tokens_count"] + + @pytest.mark.regression + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_output_stdev_distribution( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + large_dataset_for_validation, + temp_output_path, + ): + """ + Test that output token counts follow expected distribution with stdev. + ## WRITTEN BY AI ## + """ + # Config with output average=50, stdev=5, min=35, max=65 + config = ( + '{"prompt_tokens": 100, "output_tokens": 50, ' + '"output_tokens_stdev": 5, "output_tokens_min": 35, ' + '"output_tokens_max": 65}' + ) + + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + large_dataset_for_validation + ) + + # Run process_dataset + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=config, + random_seed=42, # Fixed seed for reproducibility + ) + + # Extract saved dataset + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify distribution properties + output_counts = [row["output_tokens_count"] for row in saved_dataset] + assert len(output_counts) > 0 + + # Check bounds + assert min(output_counts) >= 35 + assert max(output_counts) <= 65 + + # Check mean is close to average + mean_count = sum(output_counts) / len(output_counts) + assert 45 <= mean_count <= 55, f"Mean {mean_count} not close to expected 50" + + @pytest.mark.sanity + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_token_count_accuracy( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + large_dataset_for_validation, + temp_output_path, + ): + """ + Test that stored token counts match actual tokenized lengths. 
+ ## WRITTEN BY AI ## + """ + config = ( + '{"prompt_tokens": 100, "prompt_tokens_min": 80, ' + '"prompt_tokens_max": 120, "output_tokens": 50}' + ) + + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + large_dataset_for_validation + ) + + # Run process_dataset + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=config, + ) + + # Extract saved dataset + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify stored counts match actual tokenized lengths + for row in saved_dataset: + prompt_text = row["prompt"] + stored_count = row["prompt_tokens_count"] + actual_count = len(tokenizer_mock.encode(prompt_text)) + + assert actual_count == stored_count, ( + f"Stored count {stored_count} doesn't match actual tokenized " + f"length {actual_count} for prompt: {prompt_text[:50]}..." + ) + + @pytest.mark.sanity + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_prompt_trimming_accuracy( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + large_dataset_for_validation, + temp_output_path, + ): + """ + Test that prompts exceeding target length are trimmed correctly. + ## WRITTEN BY AI ## + """ + # Use a small max to force trimming + config = ( + '{"prompt_tokens": 50, "prompt_tokens_min": 50, ' + '"prompt_tokens_max": 50, "output_tokens": 30}' + ) + + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = ( + large_dataset_for_validation + ) + + # Run process_dataset + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=config, + ) + + # Extract saved dataset + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify all prompts are trimmed to exactly 50 tokens + for row in saved_dataset: + actual_tokens = len(tokenizer_mock.encode(row["prompt"])) + assert actual_tokens == 50, \ + f"Prompt not trimmed correctly: expected 50 tokens, got {actual_tokens}" + + @pytest.mark.sanity + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_prompt_padding_accuracy( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + temp_output_path, + ): + """ + Test that prompts below target length are padded correctly with PAD strategy. 
+ ## WRITTEN BY AI ## + """ + # Create dataset with short prompts + short_prompts = ["Short", "Tiny", "Small prompt"] * 5 + dataset = Dataset.from_dict({"prompt": short_prompts}) + + # Use a large target to force padding + config = ( + '{"prompt_tokens": 100, "prompt_tokens_min": 100, ' + '"prompt_tokens_max": 100, "output_tokens": 30}' + ) + + # Setup mocks + mock_check_processor.return_value = tokenizer_mock + mock_deserializer_factory_class.deserialize.return_value = dataset + + # Run process_dataset with PAD strategy + process_dataset( + data="test_data", + output_path=temp_output_path, + processor=tokenizer_mock, + config=config, + short_prompt_strategy=ShortPromptStrategy.PAD, + pad_char="X", + ) + + # Extract saved dataset + assert mock_save_to_file.called + call_args = mock_save_to_file.call_args + saved_dataset = call_args[0][0] + + # Verify all prompts are padded to exactly 100 tokens + pad_char_found = False + for row in saved_dataset: + prompt_text = row["prompt"] + actual_tokens = len(tokenizer_mock.encode(prompt_text)) + assert actual_tokens == 100, \ + f"Prompt not padded correctly: expected 100 tokens, got {actual_tokens}" + assert row["prompt_tokens_count"] == 100 + + # Verify that pad_char "X" appears in the padded prompts + # Since original prompts are short ("Short", "Tiny", "Small prompt"), + # they should all be padded with "X" characters + if "X" in prompt_text: + pad_char_found = True + # Verify that X characters appear at the end (where padding would be) + # or verify that the prompt contains X characters indicating + # padding + assert prompt_text.count("X") > 0, ( + f"Expected pad_char 'X' in padded prompt, but not found: " + f"{prompt_text[:100]}..." + ) + + # Verify that at least some prompts contain the pad character + assert pad_char_found, ( + "Expected to find pad_char 'X' in at least some padded prompts, " + "but none were found" + ) + + @pytest.mark.regression + @patch("guidellm.data.entrypoints.save_dataset_to_file") + @patch("guidellm.data.entrypoints.DatasetDeserializerFactory") + @patch("guidellm.data.entrypoints.check_load_processor") + def test_combined_config_constraints( + self, + mock_check_processor, + mock_deserializer_factory_class, + mock_save_to_file, + tokenizer_mock, + large_dataset_for_validation, + temp_output_path, + ): + """ + Test config with all parameters set (average, min, max, stdev) for + both prompt and output. 
+        ## WRITTEN BY AI ##
+        """
+        # Config with all parameters
+        config = (
+            '{"prompt_tokens": 100, "prompt_tokens_min": 80, '
+            '"prompt_tokens_max": 120, "prompt_tokens_stdev": 10, '
+            '"output_tokens": 50, "output_tokens_min": 40, '
+            '"output_tokens_max": 60, "output_tokens_stdev": 5}'
+        )
+
+        # Setup mocks
+        mock_check_processor.return_value = tokenizer_mock
+        mock_deserializer_factory_class.deserialize.return_value = (
+            large_dataset_for_validation
+        )
+
+        # Run process_dataset
+        process_dataset(
+            data="test_data",
+            output_path=temp_output_path,
+            processor=tokenizer_mock,
+            config=config,
+            random_seed=42,
+        )
+
+        # Extract saved dataset
+        assert mock_save_to_file.called
+        call_args = mock_save_to_file.call_args
+        saved_dataset = call_args[0][0]
+
+        # Verify prompt constraints
+        prompt_counts = [row["prompt_tokens_count"] for row in saved_dataset]
+        assert min(prompt_counts) >= 80
+        assert max(prompt_counts) <= 120
+
+        # Verify output constraints
+        output_counts = [row["output_tokens_count"] for row in saved_dataset]
+        assert min(output_counts) >= 40
+        assert max(output_counts) <= 60
+
+        # Verify token count accuracy
+        for row in saved_dataset:
+            actual_tokens = len(tokenizer_mock.encode(row["prompt"]))
+            assert actual_tokens == row["prompt_tokens_count"]
+
+    @pytest.mark.regression
+    @patch("guidellm.data.entrypoints.save_dataset_to_file")
+    @patch("guidellm.data.entrypoints.DatasetDeserializerFactory")
+    @patch("guidellm.data.entrypoints.check_load_processor")
+    def test_edge_cases_token_counts(
+        self,
+        mock_check_processor,
+        mock_deserializer_factory_class,
+        mock_save_to_file,
+        tokenizer_mock,
+        large_dataset_for_validation,
+        temp_output_path,
+    ):
+        """
+        Test edge cases: very small token counts and min=max=1 token counts.
+        ## WRITTEN BY AI ##
+        """
+        # Test 1: Very small token counts (use PAD strategy to ensure prompts
+        # are processed)
+        config_small = (
+            '{"prompt_tokens": 7, "prompt_tokens_min": 5, '
+            '"prompt_tokens_max": 10, "output_tokens": 5, '
+            '"output_tokens_min": 3, "output_tokens_max": 8}'
+        )
+
+        mock_check_processor.return_value = tokenizer_mock
+        mock_deserializer_factory_class.deserialize.return_value = (
+            large_dataset_for_validation
+        )
+
+        process_dataset(
+            data="test_data",
+            output_path=temp_output_path,
+            processor=tokenizer_mock,
+            config=config_small,
+            short_prompt_strategy=ShortPromptStrategy.PAD,
+            pad_char="X",
+        )
+
+        assert mock_save_to_file.called
+        call_args = mock_save_to_file.call_args
+        saved_dataset = call_args[0][0]
+
+        prompt_counts = [row["prompt_tokens_count"] for row in saved_dataset]
+        output_counts = [row["output_tokens_count"] for row in saved_dataset]
+
+        assert min(prompt_counts) >= 5
+        assert max(prompt_counts) <= 10
+        assert min(output_counts) >= 3
+        assert max(output_counts) <= 8
+
+        # Test 2: min=max=1 (minimum valid value) - use PAD strategy to
+        # ensure processing
+        config_min = (
+            '{"prompt_tokens": 1, "prompt_tokens_min": 1, '
+            '"prompt_tokens_max": 1, "output_tokens": 1, '
+            '"output_tokens_min": 1, "output_tokens_max": 1}'
+        )
+
+        mock_save_to_file.reset_mock()
+        # Create a dataset with very short prompts for this test
+        short_dataset = Dataset.from_dict({"prompt": ["A"] * 5})
+        mock_deserializer_factory_class.deserialize.return_value = short_dataset
+        process_dataset(
+            data="test_data",
+            output_path=temp_output_path,
+            processor=tokenizer_mock,
+            config=config_min,
+            short_prompt_strategy=ShortPromptStrategy.PAD,
+            pad_char="X",
+        )
+
+        assert mock_save_to_file.called
+        call_args = mock_save_to_file.call_args
+        saved_dataset = call_args[0][0]
+
+        for row in saved_dataset:
+            assert row["prompt_tokens_count"] == 1
+            assert row["output_tokens_count"] == 1
+
+    @pytest.mark.regression
+    @patch("guidellm.data.entrypoints.save_dataset_to_file")
+    @patch("guidellm.data.entrypoints.DatasetDeserializerFactory")
+    @patch("guidellm.data.entrypoints.check_load_processor")
+    def test_no_stdev_behavior(
+        self,
+        mock_check_processor,
+        mock_deserializer_factory_class,
+        mock_save_to_file,
+        tokenizer_mock,
+        large_dataset_for_validation,
+        temp_output_path,
+    ):
+        """
+        Test that when stdev is not specified, values are uniformly
+        distributed within min/max.
+        ## WRITTEN BY AI ##
+        """
+        # Config without stdev (omitted entirely) - should use uniform
+        # distribution
+        config = (
+            '{"prompt_tokens": 100, "prompt_tokens_min": 90, '
+            '"prompt_tokens_max": 110, "output_tokens": 50, '
+            '"output_tokens_min": 45, "output_tokens_max": 55}'
+        )
+
+        # Setup mocks
+        mock_check_processor.return_value = tokenizer_mock
+        mock_deserializer_factory_class.deserialize.return_value = (
+            large_dataset_for_validation
+        )
+
+        # Run process_dataset
+        process_dataset(
+            data="test_data",
+            output_path=temp_output_path,
+            processor=tokenizer_mock,
+            config=config,
+            random_seed=42,
+        )
+
+        # Extract saved dataset
+        assert mock_save_to_file.called
+        call_args = mock_save_to_file.call_args
+        saved_dataset = call_args[0][0]
+
+        # Without stdev, values should be uniformly distributed within min/max
+        prompt_counts = [row["prompt_tokens_count"] for row in saved_dataset]
+        output_counts = [row["output_tokens_count"] for row in saved_dataset]
+
+        assert min(prompt_counts) >= 90
+        assert max(prompt_counts) <= 110
+        assert min(output_counts) >= 45
+        assert max(output_counts) <= 55
+
+
+class TestShortPromptStrategyHandlers:
+    """Unit tests for individual short prompt strategy handler functions."""
+
+    @pytest.mark.sanity
+    def test_handle_ignore_strategy_too_short(self, tokenizer_mock):
+        """Test handle_ignore returns None for short prompts."""
+        result = ShortPromptStrategyHandler.handle_ignore("short", 10, tokenizer_mock)
+        assert result is None
+        tokenizer_mock.encode.assert_called_with("short")
+
+    @pytest.mark.sanity
+    def test_handle_ignore_strategy_sufficient_length(self, tokenizer_mock):
+        """Test handle_ignore returns prompt for sufficient length."""
+        result = ShortPromptStrategyHandler.handle_ignore(
+            "long prompt", 5, tokenizer_mock
+        )
+        assert result == "long prompt"
+        tokenizer_mock.encode.assert_called_with("long prompt")
+
+    @pytest.mark.sanity
+    def test_handle_concatenate_strategy_enough_prompts(self, tokenizer_mock):
+        """Test handle_concatenate with enough prompts."""
+        dataset_iter = iter([{"prompt": "longer"}])
+        result = ShortPromptStrategyHandler.handle_concatenate(
+            "short", 10, dataset_iter, "prompt", tokenizer_mock, "\n"
+        )
+        assert result == "short\nlonger"
+
+    @pytest.mark.sanity
+    def test_handle_concatenate_strategy_not_enough_prompts(self, tokenizer_mock):
+        """Test handle_concatenate without enough prompts."""
+        dataset_iter: Iterator = iter([])
+        result = ShortPromptStrategyHandler.handle_concatenate(
+            "short", 10, dataset_iter, "prompt", tokenizer_mock, ""
+        )
+        assert result is None
+
+    @pytest.mark.sanity
+    def test_handle_pad_strategy(self, tokenizer_mock):
+        """Test handle_pad pads short prompts."""
+        result = ShortPromptStrategyHandler.handle_pad("short", 10, tokenizer_mock, "p")
+        assert result.startswith("shortppppp")
+
+    @pytest.mark.sanity
+    def test_handle_error_strategy_valid_prompt(self, tokenizer_mock):
+        """Test handle_error returns prompt for valid length."""
+        result = ShortPromptStrategyHandler.handle_error(
+            "valid prompt", 5, tokenizer_mock
+        )
+        assert result == "valid prompt"
+        tokenizer_mock.encode.assert_called_with("valid prompt")
+
+    @pytest.mark.sanity
+    def test_handle_error_strategy_too_short_prompt(self, tokenizer_mock):
+        """Test handle_error raises error for short prompts."""
+        with pytest.raises(PromptTooShortError):
+            ShortPromptStrategyHandler.handle_error("short", 10, tokenizer_mock)
+
+
+class TestProcessDatasetPushToHub:
+    """Test cases for push_to_hub functionality."""
+
+    @pytest.mark.smoke
+    @patch("guidellm.data.entrypoints.push_dataset_to_hub")
+    @patch("guidellm.data.entrypoints.save_dataset_to_file")
+    @patch("guidellm.data.entrypoints.DatasetDeserializerFactory")
+    @patch("guidellm.data.entrypoints.check_load_processor")
+    def test_process_dataset_push_to_hub_called(
+        self,
+        mock_check_processor,
+        mock_deserializer_factory_class,
+        mock_save_to_file,
+        mock_push,
+        tokenizer_mock,
+        tmp_path,
+    ):
+        """Test that push_to_hub is called when push_to_hub=True."""
+        # Create a dataset with prompts long enough to be processed
+        sample_dataset = Dataset.from_dict({
+            "prompt": ["abc " * 50],  # Long enough
+        })
+
+        mock_check_processor.return_value = tokenizer_mock
+        mock_deserializer_factory_class.deserialize.return_value = sample_dataset
+
+        output_path = tmp_path / "output.json"
+        config = '{"prompt_tokens": 10, "output_tokens": 5}'
+
+        process_dataset(
+            data="input",
+            output_path=output_path,
+            processor=tokenizer_mock,
+            config=config,
+            push_to_hub=True,
+            hub_dataset_id="id123",
+        )
+
+        # Verify push_to_hub was called with the correct arguments
+        assert mock_push.called
+        call_args = mock_push.call_args
+        assert call_args[0][0] == "id123"
+        assert isinstance(call_args[0][1], Dataset)
+
+    @pytest.mark.sanity
+    @patch("guidellm.data.entrypoints.push_dataset_to_hub")
+    @patch("guidellm.data.entrypoints.save_dataset_to_file")
+    @patch("guidellm.data.entrypoints.DatasetDeserializerFactory")
+    @patch("guidellm.data.entrypoints.check_load_processor")
+    def test_process_dataset_push_to_hub_not_called(
+        self,
+        mock_check_processor,
+        mock_deserializer_factory_class,
+        mock_save_to_file,
+        mock_push,
+        tokenizer_mock,
+        tmp_path,
+    ):
+        """Test that push_to_hub is not called when push_to_hub=False."""
+        # Create a dataset with prompts long enough to be processed
+        sample_dataset = Dataset.from_dict({
+            "prompt": ["abc " * 50],  # Long enough
+        })
+
+        mock_check_processor.return_value = tokenizer_mock
+        mock_deserializer_factory_class.deserialize.return_value = sample_dataset
+
+        output_path = tmp_path / "output.json"
+        config = '{"prompt_tokens": 10, "output_tokens": 5}'
+
+        process_dataset(
+            data="input",
+            output_path=output_path,
+            processor=tokenizer_mock,
+            config=config,
+            push_to_hub=False,
+        )
+
+        # Verify push_to_hub was not called
+        mock_push.assert_not_called()
+
+    @pytest.mark.regression
+    def test_push_dataset_to_hub_success(self):
+        """Test push_dataset_to_hub success case."""
+        os.environ["HF_TOKEN"] = "token"
+        mock_dataset = MagicMock(spec=Dataset)
+        push_dataset_to_hub("dataset_id", mock_dataset)
+        mock_dataset.push_to_hub.assert_called_once_with("dataset_id", token="token")
+
+    @pytest.mark.regression
+    def test_push_dataset_to_hub_error_no_env(self):
+        """Test push_dataset_to_hub raises error when HF_TOKEN is missing."""
+        if "HF_TOKEN" in os.environ:
+            del os.environ["HF_TOKEN"]
+        mock_dataset = MagicMock(spec=Dataset)
+        with pytest.raises(ValueError, match="hub_dataset_id and HF_TOKEN"):
+            push_dataset_to_hub("dataset_id", mock_dataset)
+
+    @pytest.mark.regression
+    def test_push_dataset_to_hub_error_no_id(self):
+        """Test push_dataset_to_hub raises error when hub_dataset_id is missing."""
+        os.environ["HF_TOKEN"] = "token"
+        mock_dataset = MagicMock(spec=Dataset)
+        with pytest.raises(ValueError, match="hub_dataset_id and HF_TOKEN"):
+            push_dataset_to_hub(None, mock_dataset)
+
+
+class TestProcessDatasetStrategyHandlerIntegration:
+    """Test cases for strategy handler integration with process_dataset."""
+
+    @pytest.mark.smoke
+    @patch("guidellm.data.entrypoints.save_dataset_to_file")
+    @patch("guidellm.data.entrypoints.DatasetDeserializerFactory")
+    @patch("guidellm.data.entrypoints.check_load_processor")
+    def test_strategy_handler_called(
+        self,
+        mock_check_processor,
+        mock_deserializer_factory_class,
+        mock_save_to_file,
+        tokenizer_mock,
+        tmp_path,
+    ):
+        """Test that strategy handlers are called during dataset processing."""
+        from guidellm.data.entrypoints import STRATEGY_HANDLERS
+        mock_handler = MagicMock(return_value="processed_prompt")
+        with patch.dict(STRATEGY_HANDLERS, {ShortPromptStrategy.IGNORE: mock_handler}):
+            # Create a dataset with prompts that need processing
+            sample_dataset = Dataset.from_dict({
+                "prompt": [
+                    "abc" * 20,  # Long enough to pass
+                    "def" * 20,  # Long enough to pass
+                ],
+            })
+
+            mock_check_processor.return_value = tokenizer_mock
+            mock_deserializer_factory_class.deserialize.return_value = sample_dataset
+
+            output_path = tmp_path / "output.json"
+            config = '{"prompt_tokens": 10, "output_tokens": 5}'
+
+            process_dataset(
+                data="input",
+                output_path=output_path,
+                processor=tokenizer_mock,
+                config=config,
+                short_prompt_strategy=ShortPromptStrategy.IGNORE,
+            )
+
+            # Verify that the handler was called during processing
+            # The handler is called for each row that needs processing
+            mock_deserializer_factory_class.deserialize.assert_called_once()
+            mock_check_processor.assert_called_once()
+            assert mock_save_to_file.called
+            # Verify handler was called (at least once if there are rows to process)
+            if len(sample_dataset) > 0:
+                assert mock_handler.called
diff --git a/tests/unit/preprocess/__init__.py b/tests/unit/preprocess/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/tests/unit/preprocess/test_dataset.py b/tests/unit/preprocess/test_dataset.py
deleted file mode 100644
index d7014e222..000000000
--- a/tests/unit/preprocess/test_dataset.py
+++ /dev/null
@@ -1,292 +0,0 @@
-import os
-from typing import TYPE_CHECKING
-from unittest.mock import MagicMock, patch
-
-if TYPE_CHECKING:
-    from collections.abc import Iterator
-
-import pytest
-from datasets import Dataset
-from transformers import PreTrainedTokenizerBase
-
-from guidellm.preprocess.dataset import (
-    STRATEGY_HANDLERS,
-    PromptTooShortError,
-    ShortPromptStrategy,
-    handle_concatenate_strategy,
-    handle_error_strategy,
-    handle_ignore_strategy,
-    handle_pad_strategy,
-    process_dataset,
-    push_dataset_to_hub,
-)
-
-
-@pytest.fixture
-def tokenizer_mock():
-    tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
-    tokenizer.encode.side_effect = lambda x: [1] * len(x)
-    tokenizer.decode.side_effect = lambda x, *args, **kwargs: "".join(
-        str(item) for item in x
-    )
-    return tokenizer
-
-
-@pytest.mark.xfail(reason="old and broken", run=False)
-@pytest.mark.smoke
-@patch(f"{process_dataset.__module__}.guidellm_load_dataset") -@patch(f"{process_dataset.__module__}.check_load_processor") -@patch(f"{process_dataset.__module__}.Dataset") -@patch(f"{process_dataset.__module__}.IntegerRangeSampler") -def test_strategy_handler_called( - mock_sampler, - mock_dataset_class, - mock_check_processor, - mock_load_dataset, - tokenizer_mock, -): - mock_handler = MagicMock(return_value="processed_prompt") - with patch.dict(STRATEGY_HANDLERS, {ShortPromptStrategy.IGNORE: mock_handler}): - mock_dataset = [{"prompt": "abc"}, {"prompt": "def"}] - mock_load_dataset.return_value = (mock_dataset, {"prompt_column": "prompt"}) - mock_check_processor.return_value = tokenizer_mock - mock_sampler.side_effect = lambda **kwargs: [10, 10] - - mock_dataset_obj = MagicMock(spec=Dataset) - mock_dataset_class.from_list.return_value = mock_dataset_obj - - process_dataset( - data="input", - output_path="output_dir/data.json", - processor=tokenizer_mock, - prompt_tokens="average=10,min=1", - output_tokens="average=10,min=1", - short_prompt_strategy=ShortPromptStrategy.IGNORE, - ) - - assert mock_handler.call_count == 2 - mock_load_dataset.assert_called_once() - mock_check_processor.assert_called_once() - - -@pytest.mark.sanity -def test_handle_ignore_strategy_too_short(tokenizer_mock): - result = handle_ignore_strategy("short", 10, tokenizer_mock) - assert result is None - tokenizer_mock.encode.assert_called_with("short") - - -@pytest.mark.sanity -def test_handle_ignore_strategy_sufficient_length(tokenizer_mock): - result = handle_ignore_strategy("long prompt", 5, tokenizer_mock) - assert result == "long prompt" - tokenizer_mock.encode.assert_called_with("long prompt") - - -@pytest.mark.sanity -def test_handle_concatenate_strategy_enough_prompts(tokenizer_mock): - dataset_iter = iter([{"prompt": "longer"}]) - result = handle_concatenate_strategy( - "short", 10, dataset_iter, "prompt", tokenizer_mock, "\n" - ) - assert result == "short\nlonger" - - -@pytest.mark.sanity -def test_handle_concatenate_strategy_not_enough_prompts(tokenizer_mock): - dataset_iter: Iterator = iter([]) - result = handle_concatenate_strategy( - "short", 10, dataset_iter, "prompt", tokenizer_mock, "" - ) - assert result is None - - -@pytest.mark.sanity -def test_handle_pad_strategy(tokenizer_mock): - result = handle_pad_strategy("short", 10, tokenizer_mock, "p") - assert result.startswith("shortppppp") - - -@pytest.mark.sanity -def test_handle_error_strategy_valid_prompt(tokenizer_mock): - result = handle_error_strategy("valid prompt", 5, tokenizer_mock) - assert result == "valid prompt" - tokenizer_mock.encode.assert_called_with("valid prompt") - - -@pytest.mark.sanity -def test_handle_error_strategy_too_short_prompt(tokenizer_mock): - with pytest.raises(PromptTooShortError): - handle_error_strategy("short", 10, tokenizer_mock) - - -@pytest.mark.xfail(reason="old and broken", run=False) -@pytest.mark.smoke -@patch(f"{process_dataset.__module__}.save_dataset_to_file") -@patch(f"{process_dataset.__module__}.Dataset") -@patch(f"{process_dataset.__module__}.guidellm_load_dataset") -@patch(f"{process_dataset.__module__}.check_load_processor") -@patch(f"{process_dataset.__module__}.IntegerRangeSampler") -def test_process_dataset_non_empty( - mock_sampler, - mock_check_processor, - mock_load_dataset, - mock_dataset_class, - mock_save_to_file, - tokenizer_mock, -): - mock_dataset = [{"prompt": "Hello"}, {"prompt": "How are you?"}] - mock_load_dataset.return_value = (mock_dataset, {"prompt_column": "prompt"}) - 
mock_check_processor.return_value = tokenizer_mock - mock_sampler.side_effect = lambda **kwargs: [3, 3, 3] - - mock_dataset_obj = MagicMock(spec=Dataset) - mock_dataset_class.from_list.return_value = mock_dataset_obj - - output_path = "output_dir/data.json" - process_dataset( - data="input", - output_path=output_path, - processor=tokenizer_mock, - prompt_tokens="average=10,min=1", - output_tokens="average=10,min=1", - ) - - mock_load_dataset.assert_called_once() - mock_check_processor.assert_called_once() - mock_dataset_class.from_list.assert_called_once() - mock_save_to_file.assert_called_once_with(mock_dataset_obj, output_path) - - args, _ = mock_dataset_class.from_list.call_args - processed_list = args[0] - assert len(processed_list) == 2 - for item in processed_list: - assert "prompt" in item - assert "prompt_tokens_count" in item - assert "output_tokens_count" in item - assert len(tokenizer_mock.encode(item["prompt"])) <= 3 - - -@pytest.mark.xfail(reason="old and broken", run=False) -@pytest.mark.sanity -@patch(f"{process_dataset.__module__}.Dataset") -@patch(f"{process_dataset.__module__}.guidellm_load_dataset") -@patch(f"{process_dataset.__module__}.check_load_processor") -@patch(f"{process_dataset.__module__}.IntegerRangeSampler") -def test_process_dataset_empty_after_processing( - mock_sampler, - mock_check_processor, - mock_load_dataset, - mock_dataset_class, - tokenizer_mock, -): - mock_dataset = [{"prompt": ""}] - mock_load_dataset.return_value = (mock_dataset, {"prompt_column": "prompt"}) - mock_check_processor.return_value = tokenizer_mock - mock_sampler.side_effect = lambda **kwargs: [10] - - process_dataset( - data="input", - output_path="output_dir/data.json", - processor=tokenizer_mock, - prompt_tokens="average=10,min=1", - output_tokens="average=10,min=1", - ) - - mock_load_dataset.assert_called_once() - mock_check_processor.assert_called_once() - mock_dataset_class.from_list.assert_not_called() - - -@pytest.mark.xfail(reason="old and broken", run=False) -@pytest.mark.smoke -@patch(f"{process_dataset.__module__}.push_dataset_to_hub") -@patch(f"{process_dataset.__module__}.Dataset") -@patch(f"{process_dataset.__module__}.guidellm_load_dataset") -@patch(f"{process_dataset.__module__}.check_load_processor") -@patch(f"{process_dataset.__module__}.IntegerRangeSampler") -def test_process_dataset_push_to_hub_called( - mock_sampler, - mock_check_processor, - mock_load_dataset, - mock_dataset_class, - mock_push, - tokenizer_mock, -): - mock_dataset = [{"prompt": "abc"}] - mock_load_dataset.return_value = (mock_dataset, {"prompt_column": "prompt"}) - mock_check_processor.return_value = tokenizer_mock - mock_sampler.side_effect = lambda **kwargs: [3] - - mock_dataset_obj = MagicMock(spec=Dataset) - mock_dataset_class.from_list.return_value = mock_dataset_obj - - process_dataset( - data="input", - output_path="output_dir/data.json", - processor=tokenizer_mock, - prompt_tokens="average=10,min=1", - output_tokens="average=10,min=1", - push_to_hub=True, - hub_dataset_id="id123", - ) - mock_push.assert_called_once_with("id123", mock_dataset_obj) - - -@pytest.mark.xfail(reason="old and broken", run=False) -@pytest.mark.sanity -@patch(f"{process_dataset.__module__}.push_dataset_to_hub") -@patch(f"{process_dataset.__module__}.Dataset") -@patch(f"{process_dataset.__module__}.guidellm_load_dataset") -@patch(f"{process_dataset.__module__}.check_load_processor") -@patch(f"{process_dataset.__module__}.IntegerRangeSampler") -def test_process_dataset_push_to_hub_not_called( - mock_sampler, - 
mock_check_processor, - mock_load_dataset, - mock_dataset_class, - mock_push, - tokenizer_mock, -): - mock_dataset = [{"prompt": "abc"}] - mock_load_dataset.return_value = (mock_dataset, {"prompt_column": "prompt"}) - mock_check_processor.return_value = tokenizer_mock - mock_sampler.side_effect = lambda **kwargs: [3] - - mock_dataset_obj = MagicMock(spec=Dataset) - mock_dataset_class.from_list.return_value = mock_dataset_obj - - process_dataset( - data="input", - output_path="output_dir/data.json", - processor=tokenizer_mock, - prompt_tokens="average=10,min=1", - output_tokens="average=10,min=1", - push_to_hub=False, - ) - mock_push.assert_not_called() - - -@pytest.mark.regression -def test_push_dataset_to_hub_success(): - os.environ["HF_TOKEN"] = "token" - mock_dataset = MagicMock(spec=Dataset) - push_dataset_to_hub("dataset_id", mock_dataset) - mock_dataset.push_to_hub.assert_called_once_with("dataset_id", token="token") - - -@pytest.mark.regression -def test_push_dataset_to_hub_error_no_env(): - if "HF_TOKEN" in os.environ: - del os.environ["HF_TOKEN"] - mock_dataset = MagicMock(spec=Dataset) - with pytest.raises(ValueError, match="hub_dataset_id and HF_TOKEN"): - push_dataset_to_hub("dataset_id", mock_dataset) - - -@pytest.mark.regression -def test_push_dataset_to_hub_error_no_id(): - os.environ["HF_TOKEN"] = "token" - mock_dataset = MagicMock(spec=Dataset) - with pytest.raises(ValueError, match="hub_dataset_id and HF_TOKEN"): - push_dataset_to_hub(None, mock_dataset)