Patch summary (explanatory preamble; patch tools conventionally skip text
that appears before the first "diff --git" header).

Purpose: add a `sequential_weight_offload_device` option (default "cpu") that
selects where model weights are offloaded in the sequential pipeline. The
existing `sequential_offload_device` option keeps controlling where
intermediate activations are offloaded.

Changes by file:
  src/llmcompressor/args/dataset_arguments.py: new `DatasetArguments` field
    `sequential_weight_offload_device: str` with default "cpu" and help text.
  src/llmcompressor/entrypoints/oneshot.py: new `oneshot()` keyword argument
    of the same name and default, plus its `:param:` docstring entry.
  src/llmcompressor/pipelines/sequential/helpers.py: `dispatch_for_sequential`
    now maps an `offload_device` string equal to "none" (compared after
    `strip().lower()`) to None. When `offload_device` is None it calls
    `remove_hook_from_module(model, recurse=True)`, moves the model with
    `model.to(onload_device)` and returns it without calling `offload_model`.
    Any other value still goes through `offload_model`.
  src/llmcompressor/pipelines/sequential/pipeline.py: the local
    `offload_device` is renamed to `activation_offload_device` and is passed to
    `IntermediatesCache.from_dataloader`; `dispatch_for_sequential` now receives
    `dataset_args.sequential_weight_offload_device` as its third positional
    argument.
  tests/llmcompressor/pipelines/sequential/test_helpers.py: new test
    `test_dispatch_for_sequential_no_offload`, parametrized over None, "none"
    and "None", asserting that the same model object is returned and that every
    parameter is on cpu.

NOTE(review): the line breaks and indentation of this patch appear to have
been collapsed (hunk headers and code share the same lines); it presumably
needs to be regenerated from the repository before it can be applied.
NOTE(review): only two hunks of pipeline.py are visible here; confirm that no
other statement in that file still reads the old local name `offload_device`.
NOTE(review): the new test's docstring mentions the absence of offloading
hooks, but its assertions only check parameter devices, and the onload device
used is also cpu; consider asserting hook removal explicitly.

diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py index 60705744e2..df1b17aed5 100644 --- a/src/llmcompressor/args/dataset_arguments.py +++ b/src/llmcompressor/args/dataset_arguments.py @@ -242,6 +242,14 @@ class DatasetArguments(CustomDatasetArguments): "than one gpu. Default is cpu." }, ) + sequential_weight_offload_device: str = field( + default="cpu", + metadata={ + "help": "Device used to offload model weights in the sequential " + "pipeline. Set to `none` to disable weight offloading and keep " + "weights on the main execution device. Default is cpu." + }, + ) quantization_aware_calibration: bool = field( default=True, metadata={ diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index 9ab1df68e9..1d6e51f1fe 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -298,6 +298,7 @@ def oneshot( ], sequential_targets: list[str] | None = None, sequential_offload_device: str = "cpu", + sequential_weight_offload_device: str = "cpu", quantization_aware_calibration: bool = True, sequential_prefetch: bool = False, # Miscellaneous arguments @@ -388,6 +389,9 @@ def oneshot( :param sequential_offload_device: Device used to offload intermediate activations between sequential layers. It is recommended to use `cuda:1` if using more than one gpu. Default is cpu. + :param sequential_weight_offload_device: Device used to offload model weights + in the sequential pipeline. Set to `none` to disable weight offloading and + keep weights on the main execution device. Default is cpu. :param quantization_aware_calibration: Whether to enable quantization-aware calibration in the sequential pipeline. When True, quantization is applied during forward pass in calibration. 
When False, quantization is disabled diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index fbbd7c9d51..1658e6ea59 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -527,10 +527,19 @@ def dispatch_for_sequential( if available. Removes any existing hooks. :param model: model to dispatch + :param onload_device: device used for forward passes + :param offload_device: device to offload weights to. ``"none"`` or ``None`` + disables weight offloading and keeps the model on ``onload_device``. :return: dispatched model """ if onload_device is None: onload_device = get_main_device() + if isinstance(offload_device, str) and offload_device.strip().lower() == "none": + offload_device = None + if offload_device is None: + remove_hook_from_module(model, recurse=True) + model.to(onload_device) + return model return offload_model(model, onload_device, offload_device) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index a16693b1a0..f32bf91ee7 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -96,8 +96,10 @@ def __call__( # prepare model for sequential onloading onload_device = get_main_device() - offload_device = torch.device(dataset_args.sequential_offload_device) - dispatch_for_sequential(model, onload_device) + activation_offload_device = torch.device(dataset_args.sequential_offload_device) + dispatch_for_sequential( + model, onload_device, dataset_args.sequential_weight_offload_device + ) # prepare to trace subgraphs modifiers = session.lifecycle.recipe.modifiers @@ -126,7 +128,7 @@ def __call__( # prepare intermediates cache activations = IntermediatesCache.from_dataloader( - dataloader, onload_device, offload_device + dataloader, onload_device, activation_offload_device ) # Populate loss_masks once 
from cached activations for AWQ masking support diff --git a/tests/llmcompressor/pipelines/sequential/test_helpers.py b/tests/llmcompressor/pipelines/sequential/test_helpers.py index 7e2c84795a..d1ef0bd5b7 100644 --- a/tests/llmcompressor/pipelines/sequential/test_helpers.py +++ b/tests/llmcompressor/pipelines/sequential/test_helpers.py @@ -1,6 +1,10 @@ +import pytest import torch -from llmcompressor.pipelines.sequential.helpers import get_sequential_ancestors +from llmcompressor.pipelines.sequential.helpers import ( + dispatch_for_sequential, + get_sequential_ancestors, +) class DummyModel(torch.nn.Module): @@ -22,3 +26,18 @@ def test_get_sequential_ancestors(): assert get_sequential_ancestors(model, {model.fc}) == {model} assert get_sequential_ancestors(model, {model.seq[0]}) == {model, model.seq} assert get_sequential_ancestors(model, {model.seq[1]}) == {model, model.seq} + + +@pytest.mark.parametrize("offload_device", [None, "none", "None"]) +def test_dispatch_for_sequential_no_offload(offload_device): + """When offload_device is None or 'none', the model should stay on the + onload device without any offloading hooks.""" + model = DummyModel() + result = dispatch_for_sequential( + model, onload_device="cpu", offload_device=offload_device + ) + assert result is model + for name, param in model.named_parameters(): + assert param.device == torch.device( + "cpu" + ), f"{name} on {param.device}, expected cpu"