From 0dbc613d12727a4e4eef54ed764387f99e27a48d Mon Sep 17 00:00:00 2001
From: Soren Dreano
Date: Fri, 20 Feb 2026 15:33:33 +0100
Subject: [PATCH 01/10] Add support for passing a custom DataLoader to oneshot()

Allow users to pass a pre-built PyTorch DataLoader directly via the
`dataloader` parameter, bypassing the internal dataset-to-dataloader
conversion. This is useful for custom data pipelines where users already
have a prepared DataLoader and don't need get_calibration_dataloader().
---
 src/llmcompressor/entrypoints/oneshot.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py
index 34fc06bd43..01b2a8cd3f 100644
--- a/src/llmcompressor/entrypoints/oneshot.py
+++ b/src/llmcompressor/entrypoints/oneshot.py
@@ -156,6 +156,9 @@ def __init__(
             level="DEBUG",
         )
 
+        # Pop dataloader before parse_args (HfArgumentParser rejects unknown fields)
+        self.dataloader = kwargs.pop("dataloader", None)
+
         model_args, dataset_args, recipe_args, output_dir = parse_args(**kwargs)
 
         self.model_args = model_args
@@ -182,9 +185,12 @@ def __call__(self):
         """
 
-        calibration_dataloader = get_calibration_dataloader(
-            self.dataset_args, self.processor
-        )
+        if self.dataloader is not None:
+            calibration_dataloader = self.dataloader
+        else:
+            calibration_dataloader = get_calibration_dataloader(
+                self.dataset_args, self.processor
+            )
         self.apply_recipe_modifiers(
             calibration_dataloader=calibration_dataloader,
             recipe_stage=self.recipe_args.stage,
@@ -300,6 +306,7 @@ def oneshot(
     sequential_offload_device: str = "cpu",
     quantization_aware_calibration: bool = True,
     sequential_prefetch: bool = False,
+    dataloader: DataLoader | None = None,
     # Miscellaneous arguments
     output_dir: str | None = None,
     log_dir: str | None = None,
@@ -393,6 +400,8 @@ def oneshot(
     :param sequential_prefetch: When using the sequential pipeline, prefetch the
         next batch in a background thread to overlap onload with forward.
        Default False; set True for faster calibration when GPU memory allows.
+    :param dataloader: A pre-built PyTorch DataLoader for calibration. If provided,
+        skips the internal dataset-to-dataloader conversion and uses this directly.
 
     # Miscellaneous arguments
     :param output_dir: Path to save the output model after calibration.

From 5c8fa87187800131b0edf50533d63c3bf76ad47c Mon Sep 17 00:00:00 2001
From: Soren Dreano
Date: Mon, 23 Feb 2026 19:35:16 +0100
Subject: [PATCH 02/10] use dataset to pass the dataloader

---
 src/llmcompressor/args/dataset_arguments.py |  9 ++++++---
 src/llmcompressor/entrypoints/oneshot.py    | 19 +++++++------------
 2 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py
index 86bd88fd91..dc47b54b46 100644
--- a/src/llmcompressor/args/dataset_arguments.py
+++ b/src/llmcompressor/args/dataset_arguments.py
@@ -10,6 +10,8 @@
 from dataclasses import dataclass, field
 from typing import Callable
 
+from torch.utils.data import DataLoader
+
 
 @dataclass
 class DVCDatasetArguments:
@@ -101,12 +103,13 @@ class DatasetArguments(CustomDatasetArguments):
     arguments to be able to specify them on the command line
     """
 
-    dataset: str | None = field(
+    dataset: str | DataLoader | None = field(
         default=None,
         metadata={
             "help": (
-                "The name of the dataset to use (via the datasets library). "
-                "Supports input as a string or DatasetDict from HF"
+                "The dataset to use for calibration. Supports a dataset name "
+                "(str, via the datasets library), a DatasetDict from HF, or a "
+                "pre-built PyTorch DataLoader."
             )
         },
     )
diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py
index 01b2a8cd3f..94d9030c03 100644
--- a/src/llmcompressor/entrypoints/oneshot.py
+++ b/src/llmcompressor/entrypoints/oneshot.py
@@ -156,9 +156,6 @@ def __init__(
             level="DEBUG",
         )
 
-        # Pop dataloader before parse_args (HfArgumentParser rejects unknown fields)
-        self.dataloader = kwargs.pop("dataloader", None)
-
         model_args, dataset_args, recipe_args, output_dir = parse_args(**kwargs)
 
         self.model_args = model_args
@@ -185,8 +182,8 @@ def __call__(self):
         """
 
-        if self.dataloader is not None:
-            calibration_dataloader = self.dataloader
+        if isinstance(self.dataset_args.dataset, DataLoader):
+            calibration_dataloader = self.dataset_args.dataset
         else:
             calibration_dataloader = get_calibration_dataloader(
                 self.dataset_args, self.processor
@@ -269,7 +266,7 @@ def oneshot(
     clear_sparse_session: bool = False,
     stage: str | None = None,
     # Dataset arguments
-    dataset: str | Dataset | DatasetDict | None = None,
+    dataset: str | Dataset | DatasetDict | DataLoader | None = None,
     dataset_config_name: str | None = None,
     dataset_path: str | None = None,
     splits: str | list[str] | dict[str, str] | None = None,
@@ -306,7 +303,6 @@ def oneshot(
     sequential_offload_device: str = "cpu",
     quantization_aware_calibration: bool = True,
     sequential_prefetch: bool = False,
-    dataloader: DataLoader | None = None,
     # Miscellaneous arguments
     output_dir: str | None = None,
     log_dir: str | None = None,
@@ -343,8 +343,10 @@ def oneshot(
     :param stage: The stage of the recipe to use for oneshot.
 
     # Dataset arguments
-    :param dataset: The name of the dataset to use (via the datasets
-        library).
+    :param dataset: The dataset to use for calibration. Can be a dataset name
+        (str, via the datasets library), a HuggingFace Dataset or DatasetDict,
+        or a pre-built PyTorch DataLoader. When a DataLoader is passed, the
+        internal dataset-to-dataloader conversion is skipped.
     :param dataset_config_name: The configuration name of the dataset
         to use.
     :param dataset_path: Path to a custom dataset. Supports json, csv, dvc.
@@ -398,9 +398,6 @@ def oneshot(
     :param sequential_prefetch: When using the sequential pipeline, prefetch the
         next batch in a background thread to overlap onload with forward.
         Default False; set True for faster calibration when GPU memory allows.
-    :param dataloader: A pre-built PyTorch DataLoader for calibration. If provided,
-        skips the internal dataset-to-dataloader conversion and uses this directly.
-
     # Miscellaneous arguments
     :param output_dir: Path to save the output model after calibration.
         Nothing is saved if None.

From 2182a21ff730107421751e38d4ba4f15124b489f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=B6ren=20Dr=C3=A9ano?= <71752785+SorenDreano@users.noreply.github.com>
Date: Tue, 24 Feb 2026 10:10:52 +0100
Subject: [PATCH 03/10] Update src/llmcompressor/args/dataset_arguments.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Brian Dellabetta
Signed-off-by: Sören Dréano <71752785+SorenDreano@users.noreply.github.com>
---
 src/llmcompressor/args/dataset_arguments.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py
index dc47b54b46..5b0a3b4329 100644
--- a/src/llmcompressor/args/dataset_arguments.py
+++ b/src/llmcompressor/args/dataset_arguments.py
@@ -103,7 +103,7 @@ class DatasetArguments(CustomDatasetArguments):
     arguments to be able to specify them on the command line
     """
 
-    dataset: str | DataLoader | None = field(
+    dataset: str | Dataset | DatasetDict | DataLoader | None = field(
         default=None,
         metadata={
             "help": (

From ed9ceef63a13220f885ff182ba9b7dd89bd10519 Mon Sep 17 00:00:00 2001
From: Soren Dreano
Date: Tue, 24 Feb 2026 10:23:52 +0100
Subject: [PATCH 04/10] add imports

---
 src/llmcompressor/args/dataset_arguments.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py
index 5b0a3b4329..244201689c 100644
--- a/src/llmcompressor/args/dataset_arguments.py
+++ b/src/llmcompressor/args/dataset_arguments.py
@@ -12,6 +12,8 @@
 
 from torch.utils.data import DataLoader
 
+from datasets import Dataset, DatasetDict
+
 
 @dataclass
 class DVCDatasetArguments:

From 56ea6e6ab9f6acd3656dda01bd3022ff49da8b60 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=B6ren=20Dr=C3=A9ano?= <71752785+SorenDreano@users.noreply.github.com>
Date: Tue, 24 Feb 2026 16:54:13 +0100
Subject: [PATCH 05/10] Update src/llmcompressor/args/dataset_arguments.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Brian Dellabetta
Signed-off-by: Sören Dréano <71752785+SorenDreano@users.noreply.github.com>
---
 src/llmcompressor/args/dataset_arguments.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py
index 244201689c..519bb33fdd 100644
--- a/src/llmcompressor/args/dataset_arguments.py
+++ b/src/llmcompressor/args/dataset_arguments.py
@@ -8,11 +8,12 @@
 """
 
 from dataclasses import dataclass, field
-from typing import Callable
+from __future__ import annotations
+from typing import Callable, TYPE_CHECKING
 
-from torch.utils.data import DataLoader
-
-from datasets import Dataset, DatasetDict
+if TYPE_CHECKING:
+    from torch.utils.data import DataLoader
+    from datasets import Dataset, DatasetDict
 
 
 @dataclass

From d660afb88f169e9d3e74de75fc9a7eda6791dd07 Mon Sep 17 00:00:00 2001
From: Brian Dellabetta
Date: Tue, 24 Feb 2026 11:12:36 -0500
Subject: [PATCH 06/10] isort

---
 src/llmcompressor/args/dataset_arguments.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py
index 519bb33fdd..507a8af190 100644
--- a/src/llmcompressor/args/dataset_arguments.py
+++ b/src/llmcompressor/args/dataset_arguments.py
@@ -6,11 +6,11 @@
 sources and processing pipelines. Supports various input formats including
 HuggingFace datasets, custom JSON/CSV files, and DVC-managed datasets.
 """
-
-from dataclasses import dataclass, field
 from __future__ import annotations
 from typing import Callable, TYPE_CHECKING
 
+from dataclasses import dataclass, field
+
 if TYPE_CHECKING:
     from torch.utils.data import DataLoader
     from datasets import Dataset, DatasetDict

From f3cf9f8b4e914e36a5e3a33ad6a1a297548642f3 Mon Sep 17 00:00:00 2001
From: Brian Dellabetta
Date: Tue, 24 Feb 2026 16:17:38 +0000
Subject: [PATCH 07/10] stylefix

Signed-off-by: Brian Dellabetta
---
 src/llmcompressor/args/dataset_arguments.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py
index 507a8af190..535f1fd7c9 100644
--- a/src/llmcompressor/args/dataset_arguments.py
+++ b/src/llmcompressor/args/dataset_arguments.py
@@ -6,14 +6,15 @@
 sources and processing pipelines. Supports various input formats including
 HuggingFace datasets, custom JSON/CSV files, and DVC-managed datasets.
 """
+
 from __future__ import annotations
-from typing import Callable, TYPE_CHECKING
 
 from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Callable
 
 if TYPE_CHECKING:
-    from torch.utils.data import DataLoader
     from datasets import Dataset, DatasetDict
+    from torch.utils.data import DataLoader
 
 
 @dataclass

From 34fe3e3920be89e0f2336cf033e08471852c1ec9 Mon Sep 17 00:00:00 2001
From: Brian Dellabetta
Date: Tue, 24 Feb 2026 16:37:09 +0000
Subject: [PATCH 08/10] fix

Signed-off-by: Brian Dellabetta
---
 src/llmcompressor/args/dataset_arguments.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py
index 535f1fd7c9..60705744e2 100644
--- a/src/llmcompressor/args/dataset_arguments.py
+++ b/src/llmcompressor/args/dataset_arguments.py
@@ -7,14 +7,11 @@
 HuggingFace datasets, custom JSON/CSV files, and DVC-managed datasets.
 """
 
-from __future__ import annotations
-
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Callable
+from typing import Callable
 
-if TYPE_CHECKING:
-    from datasets import Dataset, DatasetDict
-    from torch.utils.data import DataLoader
+from datasets import Dataset, DatasetDict
+from torch.utils.data import DataLoader
 
 
 @dataclass

From 0b3555eac98a3f3fec2819f696178c5bae6b74a0 Mon Sep 17 00:00:00 2001
From: Soren Dreano
Date: Tue, 24 Feb 2026 18:48:26 +0100
Subject: [PATCH 09/10] move the dataloader isinstance

---
 src/llmcompressor/datasets/utils.py      | 20 ++++++++++++------
 src/llmcompressor/entrypoints/oneshot.py |  9 +++------
 2 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/src/llmcompressor/datasets/utils.py b/src/llmcompressor/datasets/utils.py
index 69d1b47415..c0da7d4375 100644
--- a/src/llmcompressor/datasets/utils.py
+++ b/src/llmcompressor/datasets/utils.py
@@ -100,16 +100,24 @@ def _get_split_name(inp_str):
 def get_calibration_dataloader(
     dataset_args: DatasetArguments,
     processor: Processor,
-) -> torch.utils.data.DataLoader:
+) -> DataLoader | None:
     """
     Get the dataloader used for oneshot calibration.
+
+    If dataset_args.dataset is already a PyTorch DataLoader,
+    it is returned directly, bypassing dataset loading and tokenization.
+
     :param dataset_args: DatasetArguments that contains the dataset parameters.
     :param processor: Processor or the tokenizer of the model.
-    :return: PyTorch dataloader object that contains the calibration dataset.
+    :return: PyTorch dataloader object that contains the calibration
+        dataset, or None for data-free flows.
     """
     if dataset_args.dataset is None:
         # weight-only quantization or dynamic quantization
-        return
+        return None
+
+    if isinstance(dataset_args.dataset, DataLoader):
+        return dataset_args.dataset
 
     datasets = get_processed_dataset(
         dataset_args=dataset_args,
@@ -414,9 +422,9 @@ def get_rank_partition(split: str, num_samples: int) -> str:
     we give each device at least S//D samples and distribute the remaining
     samples as evenly as possible across all devices
     """
-    assert (
-        "[" not in split
-    ), "Split string should not already contain partitioning brackets"
+    assert "[" not in split, (
+        "Split string should not already contain partitioning brackets"
+    )
     start, end = _get_partition_start_end(
         num_samples, dist.get_rank(), dist.get_world_size()
     )
diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py
index 94d9030c03..9ab1df68e9 100644
--- a/src/llmcompressor/entrypoints/oneshot.py
+++ b/src/llmcompressor/entrypoints/oneshot.py
@@ -182,12 +182,9 @@ def __call__(self):
         """
 
-        if isinstance(self.dataset_args.dataset, DataLoader):
-            calibration_dataloader = self.dataset_args.dataset
-        else:
-            calibration_dataloader = get_calibration_dataloader(
-                self.dataset_args, self.processor
-            )
+        calibration_dataloader = get_calibration_dataloader(
+            self.dataset_args, self.processor
+        )
         self.apply_recipe_modifiers(
             calibration_dataloader=calibration_dataloader,
             recipe_stage=self.recipe_args.stage,

From 5140506c9af8fec433f430c60e147bf05c3c792f Mon Sep 17 00:00:00 2001
From: Soren Dreano
Date: Tue, 24 Feb 2026 18:54:04 +0100
Subject: [PATCH 10/10] fix formatting

---
 src/llmcompressor/datasets/utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/llmcompressor/datasets/utils.py b/src/llmcompressor/datasets/utils.py
index c0da7d4375..3478f08055 100644
--- a/src/llmcompressor/datasets/utils.py
+++ b/src/llmcompressor/datasets/utils.py
@@ -422,9 +422,9 @@ def get_rank_partition(split: str, num_samples: int) -> str:
     we give each device at least S//D samples and distribute the remaining
     samples as evenly as possible across all devices
     """
-    assert "[" not in split, (
-        "Split string should not already contain partitioning brackets"
-    )
+    assert (
+        "[" not in split
+    ), "Split string should not already contain partitioning brackets"
     start, end = _get_partition_start_end(
         num_samples, dist.get_rank(), dist.get_world_size()
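A minimal usage sketch of the behavior this series converges on: a pre-built
torch.utils.data.DataLoader is passed through oneshot()'s `dataset` argument and
returned unchanged by get_calibration_dataloader(). The model ID, dataset ID,
recipe path, and collate details below are illustrative placeholders, not part
of the patches.

# Hypothetical usage sketch; values marked "placeholder" are not from this series.
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot

MODEL_ID = "<your-causal-lm-checkpoint>"  # placeholder

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Custom calibration pipeline that llm-compressor does not need to know about.
raw = load_dataset("<your-calibration-dataset>", split="train[:128]")  # placeholder


def collate(batch):
    # Tokenize a list of dataset rows into a single padded batch of tensors.
    texts = [example["text"] for example in batch]
    return tokenizer(
        texts,
        max_length=512,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )


calibration_loader = DataLoader(raw, batch_size=4, collate_fn=collate)

# With the series applied, the DataLoader is used as-is for calibration instead
# of being rebuilt from a dataset name or DatasetDict.
oneshot(
    model=model,
    dataset=calibration_loader,
    recipe="recipe.yaml",  # placeholder recipe path
    output_dir="./model-compressed",
)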