Commit 61ca3ce

always disable

Signed-off-by: Kyle Sayers <[email protected]>
1 parent f8c15dc · commit 61ca3ce

File tree: 4 files changed (+51 −48 lines)

  src/llmcompressor/pipelines/basic/pipeline.py
  src/llmcompressor/pipelines/sequential/pipeline.py
  src/llmcompressor/utils/helpers.py
  tests/llmcompressor/utils/test_helpers.py
src/llmcompressor/pipelines/basic/pipeline.py (1 addition, 10 deletions)

@@ -6,15 +6,13 @@
 from compressed_tensors.utils import get_execution_device
 from torch.utils.data.dataloader import DataLoader

-from llmcompressor.core import LifecycleCallbacks, active_session
+from llmcompressor.core import LifecycleCallbacks
 from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pytorch.utils.helpers import tensors_to_device
 from llmcompressor.utils import (
     calibration_forward_context,
-    disable_lm_head,
     dispatch_for_generation,
-    requires_lm_head_calibration,
 )

 if TYPE_CHECKING:
@@ -43,20 +41,13 @@ def __call__(
         :param dataloader: loads data for calibration
         :param dataset_args: dataset arguments relevant to pipelines
         """
-        session = active_session()
-        modifiers = session.lifecycle.recipe.modifiers
-
         dispatch_for_generation(model)  # basic dispatch is identical to generation
         model_device = get_execution_device(model)

         LifecycleCallbacks.calibration_epoch_start()

         with contextlib.ExitStack() as stack:
             stack.enter_context(calibration_forward_context(model))
-            # Optionally disable lm_head
-            if not requires_lm_head_calibration(model, modifiers):
-                stack.enter_context(disable_lm_head(model))
-
             for batch in tqdm.tqdm(dataloader, desc="Calibrating"):
                 batch = apply_pad_mask_to_batch(batch)
                 batch = tensors_to_device(batch, model_device)

src/llmcompressor/pipelines/sequential/pipeline.py (0 additions, 4 deletions)

@@ -19,8 +19,6 @@
     DISABLE_QAC_MODIFIERS,
     DisableQuantization,
     calibration_forward_context,
-    disable_lm_head,
-    requires_lm_head_calibration,
 )

 if TYPE_CHECKING:
@@ -91,8 +89,6 @@ def __call__(
             if not dataset_args.quantization_aware_calibration or disable_qac:
                 stack.enter_context(DisableQuantization(model))
             # Optionally disable lm_head
-            if not requires_lm_head_calibration(model, modifiers):
-                stack.enter_context(disable_lm_head(model))

             # prepare intermediates cache
             activations = IntermediatesCache.from_dataloader(dataloader, model_device)
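
Both pipeline diffs above lean on the same mechanic: contexts pushed onto a contextlib.ExitStack are unwound in reverse order when the block exits, so anything calibration_forward_context now disables (including the lm_head after this commit) is restored automatically once calibration finishes. A minimal standalone sketch of that mechanic, stdlib only and not code from this commit:

import contextlib

@contextlib.contextmanager
def announce(name: str):
    # toy stand-in for calibration_forward_context / DisableQuantization
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")

with contextlib.ExitStack() as stack:
    stack.enter_context(announce("calibration_forward_context"))
    stack.enter_context(announce("DisableQuantization"))
    print("calibrating...")

# prints: enter calibration_forward_context, enter DisableQuantization,
#         calibrating..., exit DisableQuantization, exit calibration_forward_context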

src/llmcompressor/utils/helpers.py (16 additions, 27 deletions)

@@ -23,15 +23,14 @@

 import numpy
 import torch
-from compressed_tensors import has_offloaded_params, match_named_modules
 from compressed_tensors.quantization import disable_quantization, enable_quantization
 from loguru import logger
 from transformers import PreTrainedModel

-from llmcompressor.utils import get_embeddings, targets_embeddings
+from llmcompressor.utils import get_embeddings

 if TYPE_CHECKING:
-    from llmcompressor.modifiers import Modifier
+    pass

 __all__ = [
     "ALL_TOKEN",
@@ -72,7 +71,6 @@
     "eval_context",
     "calibration_forward_context",
     "disable_lm_head",
-    "requires_lm_head_calibration",
     "patch_attr",
     "disable_hf_kernels",
     "DISABLE_QAC_MODIFIERS",
@@ -1057,12 +1055,14 @@ def calibration_forward_context(model: torch.nn.Module):
     - Disable the KV cache
     - Disable train mode and enable eval mode
     - Disable hf kernels which could bypass hooks
+    - Disable lm head (input and weights can still be calibrated, output will be meta)
     """
     with contextlib.ExitStack() as stack:
         stack.enter_context(torch.no_grad())
         stack.enter_context(disable_cache(model))
         stack.enter_context(eval_context(model))
         stack.enter_context(disable_hf_kernels(model))
+        stack.enter_context(disable_lm_head(model))
         yield


@@ -1074,13 +1074,18 @@ def disable_lm_head(model: torch.nn.Module):
     """
     _, lm_head = get_embeddings(model)
     if lm_head is not None:
-        if has_offloaded_params(lm_head):
-            # keep weight on meta device
-            with patch_attr(lm_head._hf_hook, "offload", False):
-                yield
-        else:
-            with patch_attr(lm_head, "weight", lm_head.weight.to("meta")):
-                yield
+        if not isinstance(lm_head, torch.nn.Linear):
+            raise NotImplementedError(
+                f"Cannot disable LM head of type {lm_head.__class__.__name__}"
+            )
+
+        dummy_weight = lm_head.weight.to("meta")
+
+        def dummy_forward(self, input: torch.Tensor) -> torch.Tensor:
+            return input.to("meta") @ dummy_weight.T
+
+        with patch_attr(lm_head, "forward", dummy_forward.__get__(lm_head)):
+            yield

     else:
         logger.warning(
@@ -1090,22 +1095,6 @@ def disable_lm_head(model: torch.nn.Module):
         yield


-def requires_lm_head_calibration(
-    model: PreTrainedModel, modifiers: Iterable["Modifier"]
-) -> bool:
-    """Returns True if any of the quantization modifers target the lm_head"""
-    from llmcompressor.modifiers.quantization.quantization.mixin import (
-        QuantizationMixin,
-    )
-
-    targets = set()
-    for mod in modifiers:
-        if isinstance(mod, QuantizationMixin):
-            targets |= set(match_named_modules(model, mod.resolved_targets, mod.ignore))
-
-    return targets_embeddings(model, targets, check_input=False, check_output=True)
-
-
 @contextlib.contextmanager
 def patch_attr(base: object, attr: str, value: Any):
     """

tests/llmcompressor/utils/test_helpers.py (34 additions, 7 deletions)

@@ -5,23 +5,23 @@
 from transformers import (
     AutoModelForCausalLM,
     MllamaForConditionalGeneration,
-    PretrainedConfig,
-    PreTrainedModel,
 )

+from llmcompressor.pipelines.sequential.helpers import dispatch_for_sequential
 from llmcompressor.utils import (
     ALL_TOKEN,
     DisableQuantization,
     calibration_forward_context,
     convert_to_bool,
     disable_cache,
+    disable_lm_head,
     flatten_iterable,
     getattr_chain,
     interpolate,
     patch_attr,
     validate_str_iterable,
 )
-from llmcompressor.utils.dev import skip_weights_download
+from llmcompressor.utils.dev import dispatch_for_generation, skip_weights_download
 from tests.testing_utils import requires_gpu


@@ -149,20 +149,21 @@ def test_DisableQuantization():

 @pytest.mark.unit
 def test_calibration_forward_context():
-    class DummyModel(PreTrainedModel):
-        config_class = PretrainedConfig
-
-    model = DummyModel(PretrainedConfig())
+    with skip_weights_download():
+        model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2")
     model.config.use_cache = True
     model.train()

     with calibration_forward_context(model):
         assert not torch.is_grad_enabled()
         assert not model.config.use_cache
         assert not model.training
+        assert model.lm_head.forward.__name__ == "dummy_forward"
+
     assert torch.is_grad_enabled()
     assert model.config.use_cache
     assert model.training
+    assert model.lm_head.forward.__name__ == "forward"


 @pytest.mark.unit
@@ -203,3 +204,29 @@ def test_disable_cache(model_cls, model_stub):

     output = model(**inputs)
     assert output.past_key_values is not None
+
+
+@requires_gpu
+@pytest.mark.parametrize("offload", ["sequential", "basic", "none"])
+def test_disable_lm_head(offload):
+    model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2")
+    if offload == "sequential":
+        dispatch_for_sequential(model)
+    if offload == "basic":
+        dispatch_for_generation(model)
+    if offload == "none":
+        model = model.to("cuda")
+
+    lm_input_device = None
+
+    def hook(module, args):
+        nonlocal lm_input_device
+        lm_input_device = args[0].device
+
+    model.lm_head.register_forward_pre_hook(hook)
+
+    with disable_lm_head(model):
+        input = {key: value.to("cuda") for key, value in model.dummy_inputs.items()}
+        output = model(**input)
+    assert lm_input_device == torch.device("cuda:0")
+    assert output.logits.device == torch.device("meta")
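
To exercise the change locally, the touched tests can be selected with pytest's keyword filter; the GPU-marked test_disable_lm_head cases need a CUDA device (a standard pytest invocation, not part of the commit):

    pytest tests/llmcompressor/utils/test_helpers.py -k "calibration_forward_context or disable_lm_head" -v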
