diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py
index 86bd88fd91..60705744e2 100644
--- a/src/llmcompressor/args/dataset_arguments.py
+++ b/src/llmcompressor/args/dataset_arguments.py
@@ -10,6 +10,9 @@
 from dataclasses import dataclass, field
 from typing import Callable
 
+from datasets import Dataset, DatasetDict
+from torch.utils.data import DataLoader
+
 
 @dataclass
 class DVCDatasetArguments:
@@ -101,12 +104,13 @@ class DatasetArguments(CustomDatasetArguments):
     arguments to be able to specify them on the command line
     """
 
-    dataset: str | None = field(
+    dataset: str | Dataset | DatasetDict | DataLoader | None = field(
         default=None,
         metadata={
             "help": (
-                "The name of the dataset to use (via the datasets library). "
-                "Supports input as a string or DatasetDict from HF"
+                "The dataset to use for calibration. Supports a dataset name "
+                "(str, via the datasets library), a Dataset or DatasetDict "
+                "from HF, or a pre-built PyTorch DataLoader."
             )
         },
     )
diff --git a/src/llmcompressor/datasets/utils.py b/src/llmcompressor/datasets/utils.py
index 69d1b47415..3478f08055 100644
--- a/src/llmcompressor/datasets/utils.py
+++ b/src/llmcompressor/datasets/utils.py
@@ -100,16 +100,24 @@ def _get_split_name(inp_str):
 def get_calibration_dataloader(
     dataset_args: DatasetArguments,
     processor: Processor,
-) -> torch.utils.data.DataLoader:
+) -> DataLoader | None:
     """
     Get the dataloader used for oneshot calibration.
+
+    If dataset_args.dataset is already a PyTorch DataLoader,
+    it is returned directly, bypassing dataset loading and tokenization.
+
     :param dataset_args: DatasetArguments that contains the dataset parameters.
     :param processor: Processor or the tokenizer of the model.
-    :return: PyTorch dataloader object that contains the calibration dataset.
+    :return: PyTorch dataloader object that contains the calibration
+        dataset, or None for data-free flows.
     """
     if dataset_args.dataset is None:
         # weight-only quantization or dynamic quantization
-        return
+        return None
+
+    if isinstance(dataset_args.dataset, DataLoader):
+        return dataset_args.dataset
 
     datasets = get_processed_dataset(
         dataset_args=dataset_args,
diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py
index 34fc06bd43..9ab1df68e9 100644
--- a/src/llmcompressor/entrypoints/oneshot.py
+++ b/src/llmcompressor/entrypoints/oneshot.py
@@ -263,7 +263,7 @@ def oneshot(
     clear_sparse_session: bool = False,
     stage: str | None = None,
     # Dataset arguments
-    dataset: str | Dataset | DatasetDict | None = None,
+    dataset: str | Dataset | DatasetDict | DataLoader | None = None,
     dataset_config_name: str | None = None,
     dataset_path: str | None = None,
     splits: str | list[str] | dict[str, str] | None = None,
@@ -340,8 +340,10 @@ def oneshot(
     :param stage: The stage of the recipe to use for oneshot.
 
     # Dataset arguments
-    :param dataset: The name of the dataset to use (via the datasets
-        library).
+    :param dataset: The dataset to use for calibration. Can be a dataset name
+        (str, via the datasets library), a HuggingFace Dataset or DatasetDict,
+        or a pre-built PyTorch DataLoader. When a DataLoader is passed, the
+        internal dataset-to-dataloader conversion is skipped.
     :param dataset_config_name: The configuration name of the dataset
         to use.
     :param dataset_path: Path to a custom dataset. Supports json, csv, dvc.