Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/llmcompressor/args/dataset_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from dataclasses import dataclass, field
from typing import Callable

from datasets import Dataset, DatasetDict
from torch.utils.data import DataLoader


@dataclass
class DVCDatasetArguments:
Expand Down Expand Up @@ -101,12 +104,13 @@ class DatasetArguments(CustomDatasetArguments):
arguments to be able to specify them on the command line
"""

dataset: str | None = field(
dataset: str | Dataset | DatasetDict | DataLoader | None = field(
default=None,
metadata={
"help": (
"The name of the dataset to use (via the datasets library). "
"Supports input as a string or DatasetDict from HF"
"The dataset to use for calibration. Supports a dataset name "
"(str, via the datasets library), a DatasetDict from HF, or a "
"pre-built PyTorch DataLoader."
)
},
)
Expand Down
14 changes: 11 additions & 3 deletions src/llmcompressor/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,24 @@ def _get_split_name(inp_str):
def get_calibration_dataloader(
dataset_args: DatasetArguments,
processor: Processor,
) -> torch.utils.data.DataLoader:
) -> DataLoader | None:
"""
Get the dataloader used for oneshot calibration.

If dataset_args.dataset is already a PyTorch DataLoader,
it is returned directly, bypassing dataset loading and tokenization.

:param dataset_args: DatasetArguments that contains the dataset parameters.
:param processor: Processor or the tokenizer of the model.
:return: PyTorch dataloader object that contains the calibration dataset.
:return: PyTorch dataloader object that contains the calibration
dataset, or None for data-free flows.
"""
if dataset_args.dataset is None:
# weight-only quantization or dynamic quantization
return
return None

if isinstance(dataset_args.dataset, DataLoader):
return dataset_args.dataset

datasets = get_processed_dataset(
dataset_args=dataset_args,
Expand Down
9 changes: 5 additions & 4 deletions src/llmcompressor/entrypoints/oneshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def oneshot(
clear_sparse_session: bool = False,
stage: str | None = None,
# Dataset arguments
dataset: str | Dataset | DatasetDict | None = None,
dataset: str | Dataset | DatasetDict | DataLoader | None = None,
dataset_config_name: str | None = None,
dataset_path: str | None = None,
splits: str | list[str] | dict[str, str] | None = None,
Expand Down Expand Up @@ -340,8 +340,10 @@ def oneshot(
:param stage: The stage of the recipe to use for oneshot.

# Dataset arguments
:param dataset: The name of the dataset to use (via the datasets
library).
:param dataset: The dataset to use for calibration. Can be a dataset name
(str, via the datasets library), a HuggingFace Dataset or DatasetDict,
or a pre-built PyTorch DataLoader. When a DataLoader is passed, the
internal dataset-to-dataloader conversion is skipped.
:param dataset_config_name: The configuration name of the dataset
to use.
:param dataset_path: Path to a custom dataset. Supports json, csv, dvc.
Expand Down Expand Up @@ -393,7 +395,6 @@ def oneshot(
:param sequential_prefetch: When using the sequential pipeline, prefetch the
next batch in a background thread to overlap onload with forward. Default
False; set True for faster calibration when GPU memory allows.

# Miscellaneous arguments
:param output_dir: Path to save the output model after calibration.
Nothing is saved if None.
Expand Down