Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/llmcompressor/args/dataset_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from dataclasses import dataclass, field
from typing import Callable

from datasets import Dataset, DatasetDict
from torch.utils.data import DataLoader


@dataclass
class DVCDatasetArguments:
Expand Down Expand Up @@ -101,12 +104,13 @@ class DatasetArguments(CustomDatasetArguments):
arguments to be able to specify them on the command line
"""

dataset: str | None = field(
dataset: str | Dataset | DatasetDict | DataLoader | None = field(
default=None,
metadata={
"help": (
"The name of the dataset to use (via the datasets library). "
"Supports input as a string or DatasetDict from HF"
"The dataset to use for calibration. Supports a dataset name "
"(str, via the datasets library), a DatasetDict from HF, or a "
"pre-built PyTorch DataLoader."
)
},
)
Expand Down
18 changes: 11 additions & 7 deletions src/llmcompressor/entrypoints/oneshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,12 @@ def __call__(self):

"""

calibration_dataloader = get_calibration_dataloader(
self.dataset_args, self.processor
)
if isinstance(self.dataset_args.dataset, DataLoader):
calibration_dataloader = self.dataset_args.dataset
else:
calibration_dataloader = get_calibration_dataloader(
self.dataset_args, self.processor
)
self.apply_recipe_modifiers(
calibration_dataloader=calibration_dataloader,
recipe_stage=self.recipe_args.stage,
Expand Down Expand Up @@ -263,7 +266,7 @@ def oneshot(
clear_sparse_session: bool = False,
stage: str | None = None,
# Dataset arguments
dataset: str | Dataset | DatasetDict | None = None,
dataset: str | Dataset | DatasetDict | DataLoader | None = None,
dataset_config_name: str | None = None,
dataset_path: str | None = None,
splits: str | list[str] | dict[str, str] | None = None,
Expand Down Expand Up @@ -340,8 +343,10 @@ def oneshot(
:param stage: The stage of the recipe to use for oneshot.

# Dataset arguments
:param dataset: The name of the dataset to use (via the datasets
library).
:param dataset: The dataset to use for calibration. Can be a dataset name
(str, via the datasets library), a HuggingFace Dataset or DatasetDict,
or a pre-built PyTorch DataLoader. When a DataLoader is passed, the
internal dataset-to-dataloader conversion is skipped.
:param dataset_config_name: The configuration name of the dataset
to use.
:param dataset_path: Path to a custom dataset. Supports json, csv, dvc.
Expand Down Expand Up @@ -393,7 +398,6 @@ def oneshot(
:param sequential_prefetch: When using the sequential pipeline, prefetch the
next batch in a background thread to overlap onload with forward. Default
False; set True for faster calibration when GPU memory allows.

# Miscellaneous arguments
:param output_dir: Path to save the output model after calibration.
Nothing is saved if None.
Expand Down
Loading