
Commit e3e009f

implement requires_lm_head_calibration, disable_lm_head
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 90d1cac commit e3e009f

3 files changed: +73 additions, -6 deletions

src/llmcompressor/pipelines/basic/pipeline.py

Lines changed: 14 additions & 2 deletions

@@ -6,11 +6,16 @@
 from compressed_tensors.utils import get_execution_device
 from torch.utils.data.dataloader import DataLoader

-from llmcompressor.core import LifecycleCallbacks
+from llmcompressor.core import LifecycleCallbacks, active_session
 from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pytorch.utils.helpers import tensors_to_device
-from llmcompressor.utils import calibration_forward_context, dispatch_for_generation
+from llmcompressor.utils import (
+    calibration_forward_context,
+    disable_lm_head,
+    dispatch_for_generation,
+    requires_lm_head_calibration,
+)

 if TYPE_CHECKING:
     from llmcompressor.args.dataset_arguments import DatasetArguments
@@ -38,13 +43,20 @@ def __call__(
         :param dataloader: loads data for calibration
         :param dataset_args: dataset arguments relevant to pipelines
         """
+        session = active_session()
+        modifiers = session.lifecycle.recipe.modifiers
+
         dispatch_for_generation(model)  # basic dispatch is identical to generation
         model_device = get_execution_device(model)

         LifecycleCallbacks.calibration_epoch_start()

         with contextlib.ExitStack() as stack:
             stack.enter_context(calibration_forward_context(model))
+            # Optionally disable lm_head
+            if not requires_lm_head_calibration(model, modifiers):
+                stack.enter_context(disable_lm_head(model))
+
             for batch in tqdm.tqdm(dataloader, desc="Calibrating"):
                 batch = apply_pad_mask_to_batch(batch)
                 batch = tensors_to_device(batch, model_device)
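For context, this change relies on the conditional-context-manager pattern: contextlib.ExitStack lets disable_lm_head be entered only when no modifier actually calibrates the lm_head, while still guaranteeing that every entered context is exited together. A minimal, self-contained sketch of that pattern (plain Python; tracing_context and calibration_step are illustrative names, not llmcompressor APIs):

```python
import contextlib

@contextlib.contextmanager
def tracing_context(name):
    # illustrative stand-in for calibration_forward_context / disable_lm_head
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")

def calibration_step(skip_lm_head: bool):
    with contextlib.ExitStack() as stack:
        stack.enter_context(tracing_context("calibration_forward_context"))
        if skip_lm_head:  # mirrors `if not requires_lm_head_calibration(...)`
            stack.enter_context(tracing_context("disable_lm_head"))
        print("calibrating batch...")
    # all entered contexts are exited here, in reverse order

calibration_step(skip_lm_head=True)
```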

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 5 additions & 0 deletions

@@ -19,6 +19,8 @@
     DISABLE_QAC_MODIFIERS,
     DisableQuantization,
     calibration_forward_context,
+    disable_lm_head,
+    requires_lm_head_calibration,
 )

 if TYPE_CHECKING:
@@ -88,6 +90,9 @@ def __call__(
             # Optionally disable quantization
             if not dataset_args.quantization_aware_calibration or disable_qac:
                 stack.enter_context(DisableQuantization(model))
+            # Optionally disable lm_head
+            if not requires_lm_head_calibration(model, modifiers):
+                stack.enter_context(disable_lm_head(model))

             # prepare intermediates cache
             activations = IntermediatesCache.from_dataloader(dataloader, model_device)

src/llmcompressor/utils/helpers.py

Lines changed: 54 additions & 4 deletions

@@ -18,15 +18,21 @@
 from collections import OrderedDict
 from io import BytesIO
 from pathlib import Path
-from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple, Union
 from urllib.parse import urlparse

 import numpy
 import torch
+from compressed_tensors import has_offloaded_params, match_named_modules
 from compressed_tensors.quantization import disable_quantization, enable_quantization
 from loguru import logger
 from transformers import PreTrainedModel

+from llmcompressor.utils import get_embeddings, targets_embeddings
+
+if TYPE_CHECKING:
+    from llmcompressor.modifiers import Modifier
+
 __all__ = [
     "ALL_TOKEN",
     "ALL_PRUNABLE_TOKEN",
@@ -65,6 +71,8 @@
     "DisableQuantization",
     "eval_context",
     "calibration_forward_context",
+    "disable_lm_head",
+    "requires_lm_head_calibration",
     "patch_attr",
     "disable_hf_kernels",
     "DISABLE_QAC_MODIFIERS",
@@ -1050,12 +1058,54 @@ def calibration_forward_context(model: torch.nn.Module):
     - Disable train mode and enable eval mode
     - Disable hf kernels which could bypass hooks
     """
-    with torch.no_grad(), disable_cache(model), eval_context(model), disable_hf_kernels(
-        model
-    ):
+    with contextlib.ExitStack() as stack:
+        stack.enter_context(torch.no_grad())
+        stack.enter_context(disable_cache(model))
+        stack.enter_context(eval_context(model))
+        stack.enter_context(disable_hf_kernels(model))
         yield


+@contextlib.contextmanager
+def disable_lm_head(model: torch.nn.Module):
+    """
+    Disable the lm_head of a model by moving it to the meta device. This function
+    does not untie parameters, and the model's proper loading is restored upon exit
+    """
+    _, lm_head = get_embeddings(model)
+    if lm_head is not None:
+        if has_offloaded_params(lm_head):
+            # keep weight on meta device
+            with patch_attr(lm_head._hf_hook, "offload", False):
+                yield
+        else:
+            with patch_attr(lm_head, "weight", lm_head.weight.to("meta")):
+                yield
+
+    else:
+        logger.warning(
+            f"Attempted to disable lm_head of instance {model.__class__.__name__}, "
+            "but was unable to find lm_head. This may lead to unexpected OOM."
+        )
+        yield
+
+
+def requires_lm_head_calibration(
+    model: PreTrainedModel, modifiers: Iterable["Modifier"]
+) -> bool:
+    """Returns True if any of the quantization modifiers target the lm_head"""
+    from llmcompressor.modifiers.quantization.quantization.mixin import (
+        QuantizationMixin,
+    )
+
+    targets = set()
+    for mod in modifiers:
+        if isinstance(mod, QuantizationMixin):
+            targets |= set(match_named_modules(model, mod.resolved_targets, mod.ignore))
+
+    return targets_embeddings(model, targets, check_input=True, check_output=False)
+
+
 @contextlib.contextmanager
 def patch_attr(base: object, attr: str, value: Any):
     """
