@@ -18,7 +18,7 @@
 from collections import OrderedDict
 from io import BytesIO
 from pathlib import Path
-from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, TYPE_CHECKING
 from urllib.parse import urlparse
 
 import numpy
@@ -28,6 +28,9 @@
 from loguru import logger
 from transformers import PreTrainedModel
 
+if TYPE_CHECKING:
+    from llmcompressor.modifiers.modifier import Modifier
+
 __all__ = [
     "ALL_TOKEN",
     "ALL_PRUNABLE_TOKEN",
@@ -68,6 +71,8 @@
     "calibration_forward_context",
     "patch_attr",
     "disable_hf_kernels",
+    "disable_lm_head",
+    "targets_lm_head",
     "DISABLE_QAC_MODIFIERS",
 ]
 
@@ -1042,23 +1047,20 @@ def disable_hf_kernels(module: torch.nn.Module): |
 
 
 @contextlib.contextmanager
-def calibration_forward_context(model: torch.nn.Module, skip_lm_head: bool = False):
+def calibration_forward_context(model: torch.nn.Module):
     """
     Context in which all calibration forward passes should occur.
 
     - Remove gradient calculations
     - Disable the KV cache
     - Disable train mode and enable eval mode
     - Disable hf kernels which could bypass hooks
-    - Disable lm_head of model (optional)
     """
     with contextlib.ExitStack() as stack:
         stack.enter_context(torch.no_grad())
         stack.enter_context(disable_cache(model))
         stack.enter_context(eval_context(model))
         stack.enter_context(disable_hf_kernels(model))
-        if skip_lm_head:
-            stack.enter_context(disable_lm_head(model))
 
         yield
 
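With the `skip_lm_head` flag removed, callers that relied on it can compose the two context managers themselves. A minimal migration sketch, assuming only the signatures shown in this diff (`wants_lm_head_disabled` and `run_calibration` are hypothetical stand-ins):

```python
import contextlib

# Before: with calibration_forward_context(model, skip_lm_head=True): ...
# After: stack disable_lm_head explicitly when it is wanted.
with contextlib.ExitStack() as stack:
    stack.enter_context(calibration_forward_context(model))
    if wants_lm_head_disabled:  # hypothetical caller-side decision
        stack.enter_context(disable_lm_head(model))
    run_calibration(model)  # hypothetical calibration loop
```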
@@ -1091,7 +1093,19 @@ def disable_lm_head(model: torch.nn.Module): |
         yield
 
 
-# TODO: deprecate
+def targets_lm_head(model: PreTrainedModel, modifiers: list["Modifier"]) -> bool:
+    """ Returns True if the given modifiers target the lm_head """
+    from llmcompressor.transformers.compression.compressed_tensors_utils import (
+        targets_embeddings
+    )
+
+    targets = sum(
+        (list(modifier.get_targets(model)) for modifier in modifiers), start=[]
+    )
+    return targets_embeddings(model, targets, check_input=True, check_output=False)
+
+
+
 @contextlib.contextmanager
 def patch_attr(base: object, attr: str, value: Any):
     """
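The new helper plugs naturally into the migration sketch above: when no modifier targets the lm_head, calibration can presumably skip it safely. A hedged sketch tying the two pieces together (the model name and `build_modifiers_from_recipe` are hypothetical, for illustration only):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # illustrative model
modifiers = build_modifiers_from_recipe("recipe.yaml")  # hypothetical helper

# Skip lm_head computation during calibration unless a modifier actually
# needs to observe it.
wants_lm_head_disabled = not targets_lm_head(model, modifiers)
```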