Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/llmcompressor/args/dataset_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,14 @@ class DatasetArguments(CustomDatasetArguments):
"than one gpu. Default is cpu."
},
)
sequential_weight_offload_device: str = field(
default="cpu",
metadata={
"help": "Device used to offload model weights in the sequential "
"pipeline. Set to `none` to disable weight offloading and keep "
"weights on the main execution device. Default is cpu."
},
)
quantization_aware_calibration: bool = field(
default=True,
metadata={
Expand Down
4 changes: 4 additions & 0 deletions src/llmcompressor/entrypoints/oneshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def oneshot(
],
sequential_targets: list[str] | None = None,
sequential_offload_device: str = "cpu",
sequential_weight_offload_device: str = "cpu",
quantization_aware_calibration: bool = True,
sequential_prefetch: bool = False,
# Miscellaneous arguments
Expand Down Expand Up @@ -388,6 +389,9 @@ def oneshot(
:param sequential_offload_device: Device used to offload intermediate activations
between sequential layers. It is recommended to use `cuda:1` if using more
than one gpu. Default is cpu.
:param sequential_weight_offload_device: Device used to offload model weights
in the sequential pipeline. Set to `none` to disable weight offloading and
keep weights on the main execution device. Default is cpu.
:param quantization_aware_calibration: Whether to enable quantization-aware
calibration in the sequential pipeline. When True, quantization is applied
during forward pass in calibration. When False, quantization is disabled
Expand Down
9 changes: 9 additions & 0 deletions src/llmcompressor/pipelines/sequential/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,19 @@ def dispatch_for_sequential(
if available. Removes any existing hooks.

:param model: model to dispatch
:param onload_device: device used for forward passes
:param offload_device: device to offload weights to. ``"none"`` or ``None``
disables weight offloading and keeps the model on ``onload_device``.
:return: dispatched model
"""
if onload_device is None:
onload_device = get_main_device()
if isinstance(offload_device, str) and offload_device.lower() == "none":
offload_device = None
if offload_device is None:
remove_hook_from_module(model, recurse=True)
model.to(onload_device)
return model
return offload_model(model, onload_device, offload_device)


Expand Down
8 changes: 5 additions & 3 deletions src/llmcompressor/pipelines/sequential/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,10 @@ def __call__(

# prepare model for sequential onloading
onload_device = get_main_device()
offload_device = torch.device(dataset_args.sequential_offload_device)
dispatch_for_sequential(model, onload_device)
activation_offload_device = torch.device(dataset_args.sequential_offload_device)
dispatch_for_sequential(
model, onload_device, dataset_args.sequential_weight_offload_device
)

# prepare to trace subgraphs
modifiers = session.lifecycle.recipe.modifiers
Expand Down Expand Up @@ -126,7 +128,7 @@ def __call__(

# prepare intermediates cache
activations = IntermediatesCache.from_dataloader(
dataloader, onload_device, offload_device
dataloader, onload_device, activation_offload_device
)

# Populate loss_masks once from cached activations for AWQ masking support
Expand Down
21 changes: 20 additions & 1 deletion tests/llmcompressor/pipelines/sequential/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import pytest
import torch

from llmcompressor.pipelines.sequential.helpers import get_sequential_ancestors
from llmcompressor.pipelines.sequential.helpers import (
dispatch_for_sequential,
get_sequential_ancestors,
)


class DummyModel(torch.nn.Module):
Expand All @@ -22,3 +26,18 @@ def test_get_sequential_ancestors():
assert get_sequential_ancestors(model, {model.fc}) == {model}
assert get_sequential_ancestors(model, {model.seq[0]}) == {model, model.seq}
assert get_sequential_ancestors(model, {model.seq[1]}) == {model, model.seq}


@pytest.mark.parametrize("offload_device", [None, "none", "None"])
def test_dispatch_for_sequential_no_offload(offload_device):
    """Disabling weight offloading (via ``None`` or any casing of the string
    ``"none"``) must return the same model object with every parameter left
    on the onload device and no offloading hooks attached."""
    dummy = DummyModel()
    dispatched = dispatch_for_sequential(
        dummy, onload_device="cpu", offload_device=offload_device
    )
    # The helper should hand back the very same module, not a copy.
    assert dispatched is dummy
    expected_device = torch.device("cpu")
    for name, param in dummy.named_parameters():
        assert param.device == expected_device, f"{name} on {param.device}, expected cpu"
Loading