18 | 18 | from collections import OrderedDict |
19 | 19 | from io import BytesIO |
20 | 20 | from pathlib import Path |
21 | | -from typing import Any, Callable, Dict, Iterable, List, Tuple, Union |
| 21 | +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple, Union |
22 | 22 | from urllib.parse import urlparse |
23 | 23 |
24 | 24 | import numpy |
25 | 25 | import torch |
| 26 | +from compressed_tensors import has_offloaded_params, match_named_modules |
26 | 27 | from compressed_tensors.quantization import disable_quantization, enable_quantization |
27 | 28 | from loguru import logger |
28 | 29 | from transformers import PreTrainedModel |
29 | 30 |
| 31 | +from llmcompressor.utils import get_embeddings, targets_embeddings |
| 32 | + |
| 33 | +if TYPE_CHECKING: |
| 34 | +    from llmcompressor.modifiers import Modifier |
| 35 | + |
30 | 36 | __all__ = [ |
31 | 37 |     "ALL_TOKEN", |
32 | 38 |     "ALL_PRUNABLE_TOKEN", |
|
65 | 71 |     "DisableQuantization", |
66 | 72 |     "eval_context", |
67 | 73 |     "calibration_forward_context", |
| 74 | +    "disable_lm_head", |
| 75 | +    "requires_lm_head_calibration", |
68 | 76 |     "patch_attr", |
69 | 77 |     "disable_hf_kernels", |
70 | 78 |     "DISABLE_QAC_MODIFIERS", |
@@ -1050,12 +1058,54 @@ def calibration_forward_context(model: torch.nn.Module): |
1050 | 1058 |     - Disable train mode and enable eval mode |
1051 | 1059 |     - Disable hf kernels which could bypass hooks |
1052 | 1060 |     """ |
1053 | | -    with torch.no_grad(), disable_cache(model), eval_context(model), disable_hf_kernels( |
1054 | | -        model |
1055 | | -    ): |
| 1061 | +    with contextlib.ExitStack() as stack: |
| 1062 | +        stack.enter_context(torch.no_grad()) |
| 1063 | +        stack.enter_context(disable_cache(model)) |
| 1064 | +        stack.enter_context(eval_context(model)) |
| 1065 | +        stack.enter_context(disable_hf_kernels(model)) |
1056 | 1066 |         yield |
1057 | 1067 |
1058 | 1068 |
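For reference, a minimal usage sketch of the rewritten `calibration_forward_context` (the `ExitStack` form is behaviorally equivalent to the previous nested `with`). The `model` and `dataloader` names below are placeholders, not part of this diff:

```python
# Sketch only: inside the context, gradients, the KV cache, train mode, and
# HF kernels are all disabled, so registered hooks observe every forward pass.
with calibration_forward_context(model):
    for batch in dataloader:  # hypothetical calibration dataloader
        model(**batch)
```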
| 1069 | +@contextlib.contextmanager |
| 1070 | +def disable_lm_head(model: torch.nn.Module): |
| 1071 | + """ |
| 1072 | + Disable the lm_head of a model by moving it to the meta device. This function |
| 1073 | + does not untie parameters and restores the model proper loading upon exit |
| 1074 | + """ |
| 1075 | + _, lm_head = get_embeddings(model) |
| 1076 | + if lm_head is not None: |
| 1077 | + if has_offloaded_params(lm_head): |
| 1078 | + # keep weight on meta device |
| 1079 | + with patch_attr(lm_head._hf_hook, "offload", False): |
| 1080 | + yield |
| 1081 | + else: |
| 1082 | + with patch_attr(lm_head, "weight", lm_head.weight.to("meta")): |
| 1083 | + yield |
| 1084 | + |
| 1085 | + else: |
| 1086 | + logger.warning( |
| 1087 | + f"Attempted to disable lm_head of instance {model.__class__.__name__}, " |
| 1088 | + "but was unable to to find lm_head. This may lead to unexpected OOM." |
| 1089 | + ) |
| 1090 | + yield |
| 1091 | + |
| 1092 | + |
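Both branches of `disable_lm_head` lean on `patch_attr` (defined further down in this file) to make the change strictly temporary. A toy sketch of that restore-on-exit behavior, using an illustrative `Config` object rather than a real model:

```python
# Illustrative only: patch_attr swaps an attribute for the duration of the
# context and puts the original value back when the context exits.
class Config:
    def __init__(self):
        self.mode = "train"

cfg = Config()
with patch_attr(cfg, "mode", "calibrate"):
    assert cfg.mode == "calibrate"  # patched value visible inside the context
assert cfg.mode == "train"          # original value restored on exit
```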
| 1093 | +def requires_lm_head_calibration( |
| 1094 | +    model: PreTrainedModel, modifiers: Iterable["Modifier"] |
| 1095 | +) -> bool: |
| 1096 | +    """Returns True if any of the quantization modifiers target the lm_head""" |
| 1097 | +    from llmcompressor.modifiers.quantization.quantization.mixin import ( |
| 1098 | +        QuantizationMixin, |
| 1099 | +    ) |
| 1100 | + |
| 1101 | +    targets = set() |
| 1102 | +    for mod in modifiers: |
| 1103 | +        if isinstance(mod, QuantizationMixin): |
| 1104 | +            targets |= set(match_named_modules(model, mod.resolved_targets, mod.ignore)) |
| 1105 | + |
| 1106 | +    return targets_embeddings(model, targets, check_input=True, check_output=False) |
| 1107 | + |
| 1108 | + |
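A hedged call-site sketch of how the two new helpers might compose; `modifiers` and `run_calibration` are placeholders for the calling pipeline's objects, not APIs introduced by this diff:

```python
# Hypothetical composition: keep the lm_head weight on the meta device during
# calibration only when no quantization modifier needs its input activations.
with contextlib.ExitStack() as stack:
    stack.enter_context(calibration_forward_context(model))
    if not requires_lm_head_calibration(model, modifiers):
        stack.enter_context(disable_lm_head(model))
    run_calibration(model)  # placeholder for the calibration forward passes
```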
1059 | 1109 | @contextlib.contextmanager |
1060 | 1110 | def patch_attr(base: object, attr: str, value: Any): |
1061 | 1111 |     """ |