Skip to content

Commit 49e7208

Browse files
committed
feat: defer activation qparam calculation to sequential epoch end
Fixes #2446
1 parent a2433a9 commit 49e7208

File tree

3 files changed

+153
-18
lines changed

3 files changed

+153
-18
lines changed

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from torch.nn import Module
1818

1919
from llmcompressor.observers import Observer
20+
from llmcompressor.observers.base import calibrate_module_from_observer
2021

2122
__all__ = [
2223
"initialize_observer",
@@ -30,6 +31,7 @@
3031
"calibrate_query_hook",
3132
"calibrate_key_hook",
3233
"calibrate_value_hook",
34+
"flush_activation_qparams",
3335
]
3436

3537

@@ -156,15 +158,20 @@ def update_weight_zp_scale(module: Module):
156158
call_observer(module=module, base_name="weight")
157159

158160

159-
def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
161+
def calibrate_activations(
162+
module: Module, value: torch.Tensor, base_name: str, stats_only: bool = False
163+
):
160164
"""
161165
Calibrate input or output activations by calling the a module's attached
162166
observer.
163167
164168
:param module: torch.nn.Module
165169
:param base_name: substring used to fetch the observer, scales, and zp
166170
:param value: torch.Tensor to be passed to the observer
167-
171+
:param stats_only: if True, only update running statistics in the observer
172+
(accumulate min/max) without computing or writing scale/zero_point.
173+
Used during deferred qparam calibration — qparams are computed once
174+
at epoch end via flush_activation_qparams instead of per batch.
168175
"""
169176
# If empty tensor, can't update zp/scale
170177
# Case for MoEs
@@ -184,6 +191,12 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
184191
if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
185192
calculate_gparam = True
186193

194+
# In deferred (stats_only) mode, only accumulate running min/max in the
195+
# observer — skip writing scale/zero_point until epoch end.
196+
if stats_only:
197+
calculate_qparams = False
198+
calculate_gparam = False
199+
187200
call_observer(
188201
module=module,
189202
base_name=base_name,
@@ -196,43 +209,40 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
196209
def calibrate_input_hook(module: Module, args: Any):
    """
    Forward pre-hook used while calibrating input activations.

    Only running min/max statistics are accumulated in the observer here;
    scale/zero_point values are written later, in a single pass at epoch end
    via ``flush_activation_qparams`` (deferred qparam mode).
    """
    value = args[0] if isinstance(args, tuple) else args
    calibrate_activations(module, value=value, base_name="input", stats_only=True)

205219

206220
def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
    """
    Forward hook used while calibrating output activations.

    Statistics-only mode: the observer accumulates running min/max and the
    output tensor is returned untouched. No output QDQ is applied during
    calibration batches because quantization is disabled in the sequential
    pipeline; qparams are produced at epoch end by ``flush_activation_qparams``.
    """
    calibrate_activations(module, value=output, base_name="output", stats_only=True)
    return output
224234

225235

226236
def calibrate_query_hook(module: Module, query_states: torch.Tensor):
    """Accumulate running min/max stats for attention query activations (deferred qparam mode)."""
    calibrate_activations(module, value=query_states, base_name="q", stats_only=True)
228238

229239

230240
def calibrate_key_hook(module: Module, key_states: torch.Tensor):
    """Accumulate running min/max stats for attention key activations (deferred qparam mode)."""
    calibrate_activations(module, value=key_states, base_name="k", stats_only=True)
232242

233243

234244
def calibrate_value_hook(module: Module, value_states: torch.Tensor):
    """Accumulate running min/max stats for attention value activations (deferred qparam mode)."""
    calibrate_activations(module, value=value_states, base_name="v", stats_only=True)
236246

237247

238248
def apply_calibration_status(module: Module):
@@ -273,3 +283,29 @@ def reset_quantization_status(model: Module):
273283
for module in model.modules():
274284
if hasattr(module, "quantization_status"):
275285
delattr(module, "quantization_status")
286+
287+
288+
def flush_activation_qparams(module: Module):
    """
    Finalize deferred activation quantization parameters for ``module``.

    Reads the running statistics each activation observer accumulated during
    calibration, writes the resulting scale/zero_point onto the module, and
    releases the statistics afterwards. Intended to run once per subgraph at
    SEQUENTIAL_EPOCH_END, replacing the per-batch qparam updates previously
    triggered by the calibration hooks. A no-op for modules that have no
    quantization scheme or no activation observers.

    Weight observers are untouched here — weight qparams are always computed
    up-front in ``on_start`` via ``update_weight_zp_scale``.

    Apply to targeted modules with::

        for _, module in match_named_modules(...):
            flush_activation_qparams(module)

    :param module: module to flush activation qparams for
    """
    if getattr(module, "quantization_scheme", None) is None:
        # Unquantized module: nothing to flush
        return

    for base_name in ("input", "output", "q", "k", "v"):
        calibrate_module_from_observer(module, base_name)

src/llmcompressor/modifiers/quantization/quantization/base.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from llmcompressor.core import Event, EventType, State
55
from llmcompressor.modifiers import Modifier
66
from llmcompressor.modifiers.quantization.calibration import (
7+
flush_activation_qparams,
78
update_weight_global_scale,
89
update_weight_zp_scale,
910
)
@@ -65,7 +66,10 @@ def on_initialize(self, state: State, **kwargs) -> bool:
6566

6667
def on_start(self, state: State, event: Event, **kwargs):
6768
"""
68-
Begin calibrating activations and weights. Calibrate weights only once on start
69+
Begin calibrating activations and weights. Calibrate weights only once on start.
70+
Quantization is kept DISABLED during calibration batches so that forward passes
71+
run in fp32. Activation qparams are computed once per subgraph at
72+
SEQUENTIAL_EPOCH_END via flush_activation_qparams (deferred mode).
6973
"""
7074
self.started_ = True
7175
QuantizationMixin.start_calibration(self, state.model)
@@ -90,11 +94,26 @@ def on_start(self, state: State, event: Event, **kwargs):
9094
for _, module in tqdm.tqdm(named_modules, desc="Calibrating weights"):
9195
update_weight_zp_scale(module)
9296

97+
# Disable quantization during calibration batches so that fp32 activations
98+
# flow through the model unmodified while hooks accumulate running stats.
99+
# Re-enable once after epoch end when qparams have been flushed.
100+
from compressed_tensors.quantization import disable_quantization
101+
102+
state.model.apply(disable_quantization)
103+
93104
def on_event(self, state: State, event: Event, **kwargs):
94105
if event.type_ == EventType.CALIBRATION_EPOCH_START:
95106
if not self.started_:
96107
self.on_start(state, None)
97108

109+
if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
110+
# Deferred qparam flush: compute scale/zero_point from accumulated
111+
# running statistics, then free those stats to reduce memory.
112+
for _, module in match_named_modules(
113+
state.model, self.resolved_targets, self.ignore
114+
):
115+
flush_activation_qparams(module)
116+
98117
if event.type_ == EventType.CALIBRATION_EPOCH_END:
99118
if not self.ended_:
100119
self.on_end(state, None)

src/llmcompressor/observers/base.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from llmcompressor.observers.helpers import flatten_for_calibration
1313

14-
__all__ = ["Observer", "MinMaxTuple", "ScaleZpTuple"]
14+
__all__ = ["Observer", "MinMaxTuple", "ScaleZpTuple", "calibrate_module_from_observer"]
1515

1616
MinMaxTuple = Tuple[torch.Tensor, torch.Tensor]
1717
ScaleZpTuple = Tuple[torch.Tensor, torch.Tensor]
@@ -74,6 +74,39 @@ def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
7474
"""
7575
raise NotImplementedError()
7676

77+
def get_accumulated_min_max(self) -> Optional[MinMaxTuple]:
    """
    Fetch the running min/max statistics this observer has accumulated so far,
    without performing any new observation.

    Observers that keep running state (static min-max, moving average) expose
    it via their ``past_min_vals`` / ``past_max_vals`` attributes. Memoryless
    observers never set those attributes, so this returns None for them — as
    it also does before any batch has been seen.

    :return: (min_vals, max_vals) tensors, or None if nothing is accumulated
    """
    lo = getattr(self, "past_min_vals", None)
    hi = getattr(self, "past_max_vals", None)
    if lo is None or hi is None:
        return None
    return lo, hi
94+
95+
def clear_accumulated_stats(self):
    """
    Drop any accumulated running statistics held by this observer.

    Called after the final qparams have been computed and written to the
    parent module, so the (potentially large) min/max tensors can be freed.
    Attributes that were never set (memoryless observers) are left untouched.
    """
    stale = (
        "past_min_vals",
        "past_max_vals",
        "past_global_min_vals",
        "past_global_max_vals",
    )
    for name in stale:
        if hasattr(self, name):
            delattr(self, name)
109+
77110
@torch.no_grad
78111
def forward(self, observed: torch.Tensor) -> ScaleZpTuple:
79112
"""
@@ -142,3 +175,50 @@ def _check_has_global_scale(self, global_scale: Optional[torch.nn.Parameter]):
142175
"Cannot compute scale and zero points "
143176
"without first computing global scale"
144177
)
178+
179+
180+
@torch.no_grad()
def calibrate_module_from_observer(
    module: torch.nn.Module,
    base_name: str,
) -> bool:
    """
    Flush an observer's accumulated running statistics into the parent module's
    quantization parameters (scale / zero_point), then free the running stats.

    This is the deferred counterpart to ``call_observer``. Instead of accepting a
    fresh activation tensor, it reads the min/max values that the observer has
    already accumulated across all calibration batches and computes qparams from
    those final statistics.

    :param module: module whose ``{base_name}_observer`` attribute holds the observer
    :param base_name: one of "input", "output", "q", "k", "v"
    :return: True if qparams were updated, False if there was no observer or it
        had no accumulated stats
    """
    observer: Optional[Observer] = getattr(module, f"{base_name}_observer", None)
    if observer is None:
        return False

    accumulated = observer.get_accumulated_min_max()
    if accumulated is None:
        return False

    # Deferred past the early exits so the common no-op case (no observer /
    # no stats for this base_name) never touches compressed_tensors at all.
    from compressed_tensors.utils import align_module_device, update_offload_parameter

    min_vals, max_vals = accumulated
    global_scale = getattr(module, f"{base_name}_global_scale", None)

    with align_module_device(module):
        scales, zero_points = calculate_qparams(
            min_vals=min_vals,
            max_vals=max_vals,
            quantization_args=observer.args,
            global_scale=global_scale,
        )
        update_offload_parameter(module, f"{base_name}_scale", scales)
        # Some symmetric schemes register no zero_point parameter — skip then
        if hasattr(module, f"{base_name}_zero_point"):
            update_offload_parameter(module, f"{base_name}_zero_point", zero_points)

    # Free memory — running stats no longer needed
    observer.clear_accumulated_stats()
    return True

0 commit comments

Comments
 (0)