-
Notifications
You must be signed in to change notification settings - Fork 453
Expand file tree
/
Copy pathcalibration.py
More file actions
316 lines (253 loc) · 10.7 KB
/
calibration.py
File metadata and controls
316 lines (253 loc) · 10.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
from typing import Any, Optional
import torch
from compressed_tensors.quantization import (
DynamicType,
QuantizationArgs,
QuantizationStatus,
QuantizationStrategy,
)
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
from compressed_tensors.utils import (
align_module_device,
getattr_chain,
update_offload_parameter,
)
from loguru import logger
from torch.nn import Module
from llmcompressor.observers import Observer
from llmcompressor.observers.base import update_module_qparams_from_observer
# Public API: the calibration-lifecycle helpers used by quantization
# modifiers to initialize observers, collect activation statistics during
# forward passes, and finalize/freeze quantization parameters.
__all__ = [
    "initialize_observer",
    "update_weight_zp_scale",
    "calibrate_input_hook",
    "calibrate_output_hook",
    "freeze_module_quantization",
    "apply_calibration_status",
    "reset_quantization_status",
    "update_weight_global_scale",
    "calibrate_query_hook",
    "calibrate_key_hook",
    "calibrate_value_hook",
    "write_activation_qparams",
]

# Activation observer base names used across calibration and quantization code
# ("q"/"k"/"v" are attention states that share the input_activations args)
ACTIVATION_BASE_NAMES = ("input", "output", "q", "k", "v")
def initialize_observer(
    module: Module,
    base_name: str,
):
    """
    Initialize observer module and attach as submodule.

    The name of the observer is fetched from the quantization_args.
    The name is then used to load the observer from the registry and attach
    it to the module. The attribute name of the observer uses the base_name
    provided (``f"{base_name}_observer"``).

    This function always initializes memoryless observers for weights.
    No observer is attached when the module has no quantization args for the
    given base_name, or when quantization is fully dynamic (qparams are then
    computed on the fly and nothing needs observing).

    :param module: torch.nn.Module that the observer is being attached to
    :param base_name: str used to name the observer attribute
        ("weight", "input", "output", "q", "k", or "v")
    """
    if base_name == "weight":
        arg_name = "weights"
    elif base_name == "output":
        arg_name = "output_activations"
    else:  # input, q, k, v
        arg_name = "input_activations"

    args: Optional[QuantizationArgs] = getattr_chain(
        module, f"quantization_scheme.{arg_name}", None
    )
    # Fix: bail out before dereferencing args. Previously args.observer was
    # read unconditionally, raising AttributeError when no args were present.
    if args is None or args.dynamic is True:
        return

    observer = args.observer
    # training is no longer supported: always use memoryless observers for
    # weights to lower memory usage
    memoryless_overrides = {
        "static_minmax": "memoryless_minmax",
        "minmax": "memoryless_minmax",
        "mse": "memoryless_mse",
    }
    if base_name == "weight" and args.observer in memoryless_overrides:
        observer = memoryless_overrides[args.observer]
        logger.warning(
            "Overriding weight observer for lower memory usage "
            f"({args.observer} -> {observer})",
            log_once=True,
        )

    observer = Observer.load_from_registry(
        observer, base_name=base_name, args=args, module=module
    )
    module.register_module(f"{base_name}_observer", observer)
def call_observer(
    module: Module,
    base_name: str,
    value: Optional[torch.Tensor] = None,
    should_calculate_gparam: bool = False,
    should_calculate_qparams: bool = True,
):
    """
    Invoke the observer attached to ``module`` and write the resulting
    quantization parameters back onto the module's (possibly offloaded)
    parameters.

    :param module: torch.nn.Module holding the observer as a submodule
    :param base_name: substring used to fetch the observer, scales, and zp
    :param value: torch.Tensor to be passed to the observer for activations.
        If base_name is "weight", the module's weight tensor is used
    :param should_calculate_gparam: compute and store a global scale
        (``{base_name}_global_scale``)
    :param should_calculate_qparams: compute and store scale and, when the
        module has one, zero point
    """
    with align_module_device(module):
        if base_name == "weight" and value is None:
            value = module.weight

        observer: Observer = getattr(module, f"{base_name}_observer")

        if should_calculate_gparam:
            update_offload_parameter(
                module,
                f"{base_name}_global_scale",
                observer.get_global_scale(value),
            )

        if should_calculate_qparams:
            scale, zero_point = observer(value)
            update_offload_parameter(module, f"{base_name}_scale", scale)
            zp_name = f"{base_name}_zero_point"
            # some schemes (e.g. symmetric) carry no zero-point parameter
            if hasattr(module, zp_name):
                update_offload_parameter(module, zp_name, zero_point)
def update_weight_global_scale(module: Module):
    """
    Compute and store the weight global scale for modules quantized with the
    TENSOR_GROUP strategy. No-op for any other configuration.

    :param module: module whose weight global scale should be updated
    """
    weight_args = getattr_chain(module, "quantization_scheme.weights", None)
    if weight_args is None:
        return

    if getattr(weight_args, "strategy", None) != QuantizationStrategy.TENSOR_GROUP:
        return

    # only the global scale is computed here; per-group qparams come later
    call_observer(
        module,
        base_name="weight",
        should_calculate_gparam=True,
        should_calculate_qparams=False,
    )
def update_weight_zp_scale(module: Module):
    """
    Calibrate the module's weight quantization parameters by running the
    attached weight observer and writing the resulting scale/zero_point.

    apply to full model with `model.apply(update_weight_zp_scale)`

    :param module: module whose weight qparams should be calibrated; no-op
        when the module has no weight quantization scheme
    """
    if getattr_chain(module, "quantization_scheme.weights", None) is None:
        return

    # weight calibration is expected while the module is in calibration mode;
    # warn (but still proceed) otherwise
    if getattr(module, "quantization_status", None) != QuantizationStatus.CALIBRATION:
        logger.warning(
            "Attempting to calibrate weights of a module not in calibration mode"
        )

    call_observer(module=module, base_name="weight")
def calibrate_activations(
    module: Module, value: torch.Tensor, base_name: str, stats_only: bool = False
):
    """
    Calibrate input or output activations by calling the module's attached
    observer.

    :param module: torch.nn.Module
    :param base_name: substring used to fetch the observer, scales, and zp
    :param value: torch.Tensor to be passed to the observer
    :param stats_only: if True, only update running statistics in the observer
        (accumulate min/max) without computing or writing scale/zero_point.
        Used during deferred qparam calibration -- qparams are computed once
        at epoch end via write_activation_qparams instead of per batch.
    """
    # empty tensors (e.g. MoE experts that received no tokens) carry no
    # information for updating zp/scale
    if value.numel() == 0:
        return

    # q/k/v share the input_activations args; only "output" maps to outputs
    field_name = "output" if base_name == "output" else "input"
    quantization_args = getattr_chain(
        module, f"quantization_scheme.{field_name}_activations", None
    )

    calculate_gparam = False
    calculate_qparams = True
    if quantization_args is not None:
        calculate_qparams = quantization_args.dynamic not in (
            True,
            DynamicType.LOCAL,
        )
        calculate_gparam = (
            quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP
        )

    # deferred (stats_only) mode: accumulate running min/max in the observer
    # but do NOT write scale/zero_point yet; qparams are written once at
    # epoch end via write_activation_qparams
    if stats_only:
        observer = getattr(module, f"{base_name}_observer", None)
        if observer is not None:
            observer.update_deferred_stats(value)
        return

    call_observer(
        module=module,
        base_name=base_name,
        value=value,
        should_calculate_gparam=calculate_gparam,
        should_calculate_qparams=calculate_qparams,
    )
def calibrate_input_hook(module: Module, args: Any):
    """
    Hook to accumulate input activation statistics (min/max) in the observer.

    Scale and zero_point are not written here; they are computed once per
    subgraph at epoch end via write_activation_qparams.
    """
    # forward-pre-hooks receive positional args as a tuple; take the first
    value = args[0] if isinstance(args, tuple) else args
    calibrate_activations(module, value=value, base_name="input", stats_only=True)
def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
    """
    Hook to accumulate output activation statistics (min/max) in the observer.

    Scale and zero_point are not written here; they are computed once per
    subgraph at epoch end via write_activation_qparams.

    Note: forward_quantize is intentionally absent -- hooks only collect
    statistics.
    """
    calibrate_activations(module, value=output, base_name="output", stats_only=True)
    # forward hooks must return the (unmodified) output
    return output
def calibrate_query_hook(module: Module, query_states: torch.Tensor):
    """Accumulate attention query-state statistics in the "q" observer."""
    calibrate_activations(module, value=query_states, base_name="q", stats_only=True)
def calibrate_key_hook(module: Module, key_states: torch.Tensor):
    """Accumulate attention key-state statistics in the "k" observer."""
    calibrate_activations(module, value=key_states, base_name="k", stats_only=True)
def calibrate_value_hook(module: Module, value_states: torch.Tensor):
    """Accumulate attention value-state statistics in the "v" observer."""
    calibrate_activations(module, value=value_states, base_name="v", stats_only=True)
def apply_calibration_status(module: Module):
    """
    Mark a module as being in calibration mode.

    Modules without a quantization scheme are left untouched.
    """
    if not getattr(module, "quantization_scheme", None):
        # no quantization scheme -> nothing to do
        return
    module.quantization_status = QuantizationStatus.CALIBRATION
def freeze_module_quantization(module: Module):
    """
    Deletes observers when calibration is complete.

    apply to full model with `model.apply(freeze_module_quantization)`

    :param module: module to freeze quantization for
    """
    scheme = getattr(module, "quantization_scheme", None)
    if not scheme:
        # no quantization scheme nothing to do
        return

    # use getattr with a default: the status attribute may be absent if the
    # module never entered calibration (e.g. after reset_quantization_status)
    if getattr(module, "quantization_status", None) == QuantizationStatus.FROZEN:
        # nothing to do, already frozen
        return

    # remove all activation observers plus the weight observer
    for name in (*ACTIVATION_BASE_NAMES, "weight"):
        obs_name = f"{name}_observer"
        if hasattr(module, obs_name):
            delattr(module, obs_name)

    module.quantization_status = QuantizationStatus.FROZEN
def reset_quantization_status(model: Module):
    """Remove the ``quantization_status`` attribute from every submodule."""
    for submodule in model.modules():
        if hasattr(submodule, "quantization_status"):
            delattr(submodule, "quantization_status")
def write_activation_qparams(module: Module):
    """
    Compute and write final activation qparams from each observer's
    accumulated running statistics, then free those statistics to reduce
    memory.

    This is called once at SEQUENTIAL_EPOCH_END for each subgraph, replacing
    the per-batch qparam updates that were previously triggered by
    calibration hooks. It is a no-op for modules with no quantization scheme
    or no activation observers.

    Note: weight observers are not touched here -- weight qparams are always
    computed up-front in ``on_start`` via ``update_weight_zp_scale``.

    apply to targeted modules with:
        for _, module in match_named_modules(...):
            write_activation_qparams(module)

    :param module: module to flush activation qparams for
    """
    if getattr(module, "quantization_scheme", None) is None:
        return
    for base_name in ACTIVATION_BASE_NAMES:
        update_module_qparams_from_observer(module, base_name)