48 changes: 44 additions & 4 deletions src/llmcompressor/utils/helpers.py
@@ -18,7 +18,7 @@
from collections import OrderedDict
from io import BytesIO
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple, Union
from urllib.parse import urlparse

import numpy
@@ -27,6 +27,11 @@
from loguru import logger
from transformers import PreTrainedModel

from llmcompressor.utils import get_embeddings

if TYPE_CHECKING:
    pass

__all__ = [
"ALL_TOKEN",
"ALL_PRUNABLE_TOKEN",
@@ -65,6 +70,7 @@
"DisableQuantization",
"eval_context",
"calibration_forward_context",
"disable_lm_head",
"patch_attr",
"disable_hf_kernels",
"DISABLE_QAC_MODIFIERS",
@@ -1049,11 +1055,45 @@ def calibration_forward_context(model: torch.nn.Module):
    - Disable the KV cache
    - Disable train mode and enable eval mode
    - Disable hf kernels which could bypass hooks
    - Disable lm head (input and weights can still be calibrated, output will be meta)
    """
    with contextlib.ExitStack() as stack:
        stack.enter_context(torch.no_grad())
        stack.enter_context(disable_cache(model))
        stack.enter_context(eval_context(model))
        stack.enter_context(disable_hf_kernels(model))
        stack.enter_context(disable_lm_head(model))
        yield
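Not part of the diff: a minimal usage sketch of how `calibration_forward_context` behaves once `disable_lm_head` is stacked into it. The model stub and inputs are assumed only for illustration; inside the context, gradients, the KV cache, train mode, HF kernels, and the real lm_head forward are all disabled, and everything is restored on exit.

```python
import torch
from transformers import AutoModelForCausalLM

from llmcompressor.utils import calibration_forward_context

# Any small causal LM works here; this stub is assumed only for illustration.
model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2")

with calibration_forward_context(model):
    assert not torch.is_grad_enabled()   # gradients are disabled
    output = model(**model.dummy_inputs)
    # The lm_head ran on the meta device, so no logits memory was allocated.
    print(output.logits.device)  # meta

# On exit, grad mode, the KV cache, train/eval state, and lm_head are restored.
```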


@contextlib.contextmanager
def disable_lm_head(model: torch.nn.Module):
"""
Disable the lm_head of a model by moving it to the meta device. This function
does not untie parameters and restores the model proper loading upon exit
"""
with torch.no_grad(), disable_cache(model), eval_context(model), disable_hf_kernels(
model
):
_, lm_head = get_embeddings(model)
if lm_head is not None:
logger.warning(
f"Attempted to disable lm_head of instance {model.__class__.__name__}, "
"but was unable to to find lm_head. This may lead to unexpected OOM."
)
yield
return

elif not isinstance(lm_head, torch.nn.Linear):
logger.warning(f"Cannot disable LM head of type {lm_head.__class__.__name__}")
yield
return

else:
dummy_weight = lm_head.weight.to("meta")

def dummy_forward(self, input: torch.Tensor) -> torch.Tensor:
return input.to("meta") @ dummy_weight.T

with patch_attr(lm_head, "forward", dummy_forward.__get__(lm_head)):
yield
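For intuition (editor's sketch, not part of the diff), the `dummy_forward` above relies on the fact that meta-device tensors carry shape and dtype but no storage, so the vocabulary-sized logits matmul allocates nothing:

```python
import torch

# Hypothetical sizes, chosen only for illustration.
hidden = torch.randn(2, 8, 64)                   # [batch, seq_len, hidden_size]
weight = torch.randn(32_000, 64, device="meta")  # an lm_head weight "moved" to meta

# Moving the input to meta and multiplying yields a shape-only result.
logits = hidden.to("meta") @ weight.T
print(logits.shape)   # torch.Size([2, 8, 32000])
print(logits.device)  # meta: no memory was allocated for the logits
```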
A collaborator commented:
Still getting used to the code paradigms here; this nested context manager design is an interesting approach. I'm pretty sure I had similar situations in torchao and I never would have considered this.

@kylesayrs (Collaborator, Author) replied on Dec 2, 2025:
Yeah, I love the context manager pattern, as it reminds implementers to clean up any side effects which might affect downstream code.
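For readers new to the pattern being discussed, a `patch_attr`-style context manager looks roughly like the sketch below (illustrative only, not the library's exact implementation). The temporary change happens on entry and is unconditionally reverted on exit, so downstream code never observes the side effect.

```python
import contextlib


@contextlib.contextmanager
def patch_attr_sketch(obj, name, value):
    """Temporarily set `obj.name` to `value`, restoring the original state on exit."""
    had_attr = hasattr(obj, name)
    original = getattr(obj, name, None)
    setattr(obj, name, value)
    try:
        yield
    finally:
        if had_attr:
            setattr(obj, name, original)
        else:
            delattr(obj, name)
```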



@contextlib.contextmanager
41 changes: 34 additions & 7 deletions tests/llmcompressor/utils/test_helpers.py
@@ -5,23 +5,23 @@
from transformers import (
    AutoModelForCausalLM,
    MllamaForConditionalGeneration,
    PretrainedConfig,
    PreTrainedModel,
)

from llmcompressor.pipelines.sequential.helpers import dispatch_for_sequential
from llmcompressor.utils import (
    ALL_TOKEN,
    DisableQuantization,
    calibration_forward_context,
    convert_to_bool,
    disable_cache,
    disable_lm_head,
    flatten_iterable,
    getattr_chain,
    interpolate,
    patch_attr,
    validate_str_iterable,
)
from llmcompressor.utils.dev import skip_weights_download
from llmcompressor.utils.dev import dispatch_for_generation, skip_weights_download
from tests.testing_utils import requires_gpu


@@ -149,20 +149,21 @@ def test_DisableQuantization():

@pytest.mark.unit
def test_calibration_forward_context():
    class DummyModel(PreTrainedModel):
        config_class = PretrainedConfig

    model = DummyModel(PretrainedConfig())
    with skip_weights_download():
        model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2")
    model.config.use_cache = True
    model.train()

    with calibration_forward_context(model):
        assert not torch.is_grad_enabled()
        assert not model.config.use_cache
        assert not model.training
        assert model.lm_head.forward.__name__ == "dummy_forward"

    assert torch.is_grad_enabled()
    assert model.config.use_cache
    assert model.training
    assert model.lm_head.forward.__name__ == "forward"


@pytest.mark.unit
@@ -203,3 +204,29 @@ def test_disable_cache(model_cls, model_stub):

    output = model(**inputs)
    assert output.past_key_values is not None


@requires_gpu
@pytest.mark.parametrize("offload", ["sequential", "basic", "none"])
def test_disable_lm_head(offload):
    model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2")
    if offload == "sequential":
        dispatch_for_sequential(model)
    if offload == "basic":
        dispatch_for_generation(model)
    if offload == "none":
        model = model.to("cuda")

    lm_input_device = None

    def hook(module, args):
        nonlocal lm_input_device
        lm_input_device = args[0].device

    model.lm_head.register_forward_pre_hook(hook)

    with disable_lm_head(model):
        input = {key: value.to("cuda") for key, value in model.dummy_inputs.items()}
        output = model(**input)
        assert lm_input_device == torch.device("cuda:0")
        assert output.logits.device == torch.device("meta")
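As an aside (editor's sketch, not from the PR): the forward pre-hook used in this test is the same mechanism calibration observers rely on. Hidden states still reach the disabled lm_head on a real device, so input statistics can be collected even though the output lands on meta. The statistic below is a hypothetical stand-in for a real observer.

```python
import torch
from transformers import AutoModelForCausalLM

from llmcompressor.utils import disable_lm_head

model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2")

# Record a simple per-call statistic of the lm_head input, mimicking an observer.
input_maxes = []

def observe(module, args):
    input_maxes.append(args[0].detach().abs().amax().item())

handle = model.lm_head.register_forward_pre_hook(observe)

with disable_lm_head(model):
    model(**model.dummy_inputs)

handle.remove()
print(input_maxes)  # real input statistics were collected; no logits memory was used
```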