Skip to content
Merged
15 changes: 12 additions & 3 deletions src/llmcompressor/entrypoints/oneshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ def __init__(
level="DEBUG",
)

# Pop dataloader before parse_args (HfArgumentParser rejects unknown fields)
self.dataloader = kwargs.pop("dataloader", None)

model_args, dataset_args, recipe_args, output_dir = parse_args(**kwargs)

self.model_args = model_args
Expand All @@ -182,9 +185,12 @@ def __call__(self):

"""

calibration_dataloader = get_calibration_dataloader(
self.dataset_args, self.processor
)
if self.dataloader is not None:
calibration_dataloader = self.dataloader
else:
calibration_dataloader = get_calibration_dataloader(
self.dataset_args, self.processor
)
self.apply_recipe_modifiers(
calibration_dataloader=calibration_dataloader,
recipe_stage=self.recipe_args.stage,
Expand Down Expand Up @@ -300,6 +306,7 @@ def oneshot(
sequential_offload_device: str = "cpu",
quantization_aware_calibration: bool = True,
sequential_prefetch: bool = False,
dataloader: DataLoader | None = None,
# Miscellaneous arguments
output_dir: str | None = None,
log_dir: str | None = None,
Expand Down Expand Up @@ -393,6 +400,8 @@ def oneshot(
:param sequential_prefetch: When using the sequential pipeline, prefetch the
next batch in a background thread to overlap onload with forward. Default
False; set True for faster calibration when GPU memory allows.
:param dataloader: A pre-built PyTorch DataLoader for calibration. If provided,
skips the internal dataset-to-dataloader conversion and uses this directly.

# Miscellaneous arguments
:param output_dir: Path to save the output model after calibration.
Expand Down
Loading