diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py
index 0be09bd06..fddcafd82 100644
--- a/src/llmcompressor/utils/helpers.py
+++ b/src/llmcompressor/utils/helpers.py
@@ -18,7 +18,7 @@
 from collections import OrderedDict
 from io import BytesIO
 from pathlib import Path
-from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple, Union
 from urllib.parse import urlparse
 
 import numpy
@@ -27,6 +27,11 @@
 from loguru import logger
 from transformers import PreTrainedModel
 
+from llmcompressor.utils import get_embeddings
+
+if TYPE_CHECKING:
+    pass
+
 __all__ = [
     "ALL_TOKEN",
     "ALL_PRUNABLE_TOKEN",
@@ -65,6 +70,7 @@
     "DisableQuantization",
     "eval_context",
     "calibration_forward_context",
+    "disable_lm_head",
     "patch_attr",
     "disable_hf_kernels",
     "DISABLE_QAC_MODIFIERS",
@@ -1049,11 +1055,45 @@ def calibration_forward_context(model: torch.nn.Module):
     - Disable the KV cache
     - Disable train mode and enable eval mode
     - Disable hf kernels which could bypass hooks
+    - Disable lm head (input and weights can still be calibrated, output will be meta)
+    """
+    with contextlib.ExitStack() as stack:
+        stack.enter_context(torch.no_grad())
+        stack.enter_context(disable_cache(model))
+        stack.enter_context(eval_context(model))
+        stack.enter_context(disable_hf_kernels(model))
+        stack.enter_context(disable_lm_head(model))
+        yield
+
+
+@contextlib.contextmanager
+def disable_lm_head(model: torch.nn.Module):
+    """
+    Disable the lm_head of a model by computing its output on the meta device. This
+    function does not untie parameters and restores the original forward upon exit.
     """
-    with torch.no_grad(), disable_cache(model), eval_context(model), disable_hf_kernels(
-        model
-    ):
+    _, lm_head = get_embeddings(model)
+    if lm_head is None:
+        logger.warning(
+            f"Attempted to disable the lm_head of {model.__class__.__name__}, "
+            "but was unable to find it. This may lead to unexpected OOM."
+        )
         yield
+        return
+
+    elif not isinstance(lm_head, torch.nn.Linear):
+        logger.warning(f"Cannot disable LM head of type {lm_head.__class__.__name__}")
+        yield
+        return
+
+    else:
+        dummy_weight = lm_head.weight.to("meta")
+
+        def dummy_forward(self, input: torch.Tensor) -> torch.Tensor:
+            return input.to("meta") @ dummy_weight.T
+
+        with patch_attr(lm_head, "forward", dummy_forward.__get__(lm_head)):
+            yield
 
 
 @contextlib.contextmanager
diff --git a/tests/llmcompressor/utils/test_helpers.py b/tests/llmcompressor/utils/test_helpers.py
index cb474c1f5..5e47bebe3 100644
--- a/tests/llmcompressor/utils/test_helpers.py
+++ b/tests/llmcompressor/utils/test_helpers.py
@@ -5,23 +5,23 @@
 from transformers import (
     AutoModelForCausalLM,
     MllamaForConditionalGeneration,
-    PretrainedConfig,
-    PreTrainedModel,
 )
 
+from llmcompressor.pipelines.sequential.helpers import dispatch_for_sequential
 from llmcompressor.utils import (
     ALL_TOKEN,
     DisableQuantization,
     calibration_forward_context,
     convert_to_bool,
     disable_cache,
+    disable_lm_head,
     flatten_iterable,
     getattr_chain,
     interpolate,
     patch_attr,
     validate_str_iterable,
 )
-from llmcompressor.utils.dev import skip_weights_download
+from llmcompressor.utils.dev import dispatch_for_generation, skip_weights_download
 from tests.testing_utils import requires_gpu
 
 
@@ -149,10 +149,8 @@ def test_DisableQuantization():
 
 @pytest.mark.unit
 def test_calibration_forward_context():
-    class DummyModel(PreTrainedModel):
-        config_class = PretrainedConfig
-
-    model = DummyModel(PretrainedConfig())
+    with skip_weights_download():
+        model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2")
 
     model.config.use_cache = True
     model.train()
@@ -160,9 +158,12 @@ class DummyModel(PreTrainedModel):
         assert not torch.is_grad_enabled()
         assert not model.config.use_cache
         assert not model.training
+        assert model.lm_head.forward.__name__ == "dummy_forward"
+
     assert torch.is_grad_enabled()
     assert model.config.use_cache
     assert model.training
+    assert model.lm_head.forward.__name__ == "forward"
 
 
 @pytest.mark.unit
@@ -203,3 +204,29 @@ def test_disable_cache(model_cls, model_stub):
 
         output = model(**inputs)
         assert output.past_key_values is not None
+
+
+@requires_gpu
+@pytest.mark.parametrize("offload", ["sequential", "basic", "none"])
+def test_disable_lm_head(offload):
+    model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2")
+    if offload == "sequential":
+        dispatch_for_sequential(model)
+    if offload == "basic":
+        dispatch_for_generation(model)
+    if offload == "none":
+        model = model.to("cuda")
+
+    lm_input_device = None
+
+    def hook(module, args):
+        nonlocal lm_input_device
+        lm_input_device = args[0].device
+
+    model.lm_head.register_forward_pre_hook(hook)
+
+    with disable_lm_head(model):
+        input = {key: value.to("cuda") for key, value in model.dummy_inputs.items()}
+        output = model(**input)
+    assert lm_input_device == torch.device("cuda:0")
+    assert output.logits.device == torch.device("meta")
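Reviewer note: a minimal usage sketch of the new behavior, not part of the PR. It reuses the tiny checkpoint from the tests above and assumes a CPU run with the model's built-in dummy_inputs; any causal LM whose lm_head is a torch.nn.Linear behaves the same way.

import torch
from transformers import AutoModelForCausalLM

from llmcompressor.utils import calibration_forward_context

# Illustrative checkpoint, same stub as in the tests above.
model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2")

with calibration_forward_context(model):
    # Grad, KV cache, train mode, and hf kernels are disabled, and lm_head.forward
    # is patched so logits land on the meta device (no vocab-sized allocation).
    outputs = model(**model.dummy_inputs)
    assert outputs.logits.device == torch.device("meta")

# On exit, patch_attr restores the original lm_head.forward, so saving and loading
# the model work as before.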