1717from torch .nn import Module
1818
1919from llmcompressor .observers import Observer
20+ from llmcompressor .observers .base import calibrate_module_from_observer
2021
2122__all__ = [
2223 "initialize_observer" ,
3031 "calibrate_query_hook" ,
3132 "calibrate_key_hook" ,
3233 "calibrate_value_hook" ,
34+ "flush_activation_qparams" ,
3335]
3436
3537
@@ -156,15 +158,20 @@ def update_weight_zp_scale(module: Module):
156158 call_observer (module = module , base_name = "weight" )
157159
158160
159- def calibrate_activations (module : Module , value : torch .Tensor , base_name : str ):
161+ def calibrate_activations (
162+ module : Module , value : torch .Tensor , base_name : str , stats_only : bool = False
163+ ):
160164 """
161165 Calibrate input or output activations by calling the a module's attached
162166 observer.
163167
164168 :param module: torch.nn.Module
165169 :param base_name: substring used to fetch the observer, scales, and zp
166170 :param value: torch.Tensor to be passed to the observer
167-
171+ :param stats_only: if True, only update running statistics in the observer
172+ (accumulate min/max) without computing or writing scale/zero_point.
173+ Used during deferred qparam calibration — qparams are computed once
174+ at epoch end via flush_activation_qparams instead of per batch.
168175 """
169176 # If empty tensor, can't update zp/scale
170177 # Case for MoEs
@@ -184,6 +191,12 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
184191 if quantization_args .strategy == QuantizationStrategy .TENSOR_GROUP :
185192 calculate_gparam = True
186193
194+ # In deferred (stats_only) mode, only accumulate running min/max in the
195+ # observer — skip writing scale/zero_point until epoch end.
196+ if stats_only :
197+ calculate_qparams = False
198+ calculate_gparam = False
199+
187200 call_observer (
188201 module = module ,
189202 base_name = base_name ,
@@ -196,43 +209,40 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
def calibrate_input_hook(module: Module, args: Any):
    """
    Forward pre-hook that calibrates input activations.

    Accumulates running min/max statistics in the module's attached observer
    without computing scale/zero_point. Qparams are computed once at epoch
    end via flush_activation_qparams (deferred mode).
    """
    # Pre-hooks receive the positional args as a tuple; calibrate the first.
    value = args[0] if isinstance(args, tuple) else args
    calibrate_activations(module, value=value, base_name="input", stats_only=True)
204218
205219
def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
    """
    Forward hook that calibrates output activations.

    Accumulates running min/max statistics only (deferred qparam mode);
    scale/zero_point are computed at epoch end. forward_quantize is skipped
    during calibration batches since quantization is disabled in the
    sequential pipeline. Returns the output unchanged.
    """
    calibrate_activations(module, value=output, base_name="output", stats_only=True)
    return output
224234
225235
def calibrate_query_hook(module: Module, query_states: torch.Tensor):
    """Accumulate observer statistics for attention query activations (deferred mode)."""
    calibrate_activations(module, value=query_states, base_name="q", stats_only=True)
228238
229239
def calibrate_key_hook(module: Module, key_states: torch.Tensor):
    """Accumulate observer statistics for attention key activations (deferred mode)."""
    calibrate_activations(module, value=key_states, base_name="k", stats_only=True)
232242
233243
def calibrate_value_hook(module: Module, value_states: torch.Tensor):
    """Accumulate observer statistics for attention value activations (deferred mode)."""
    calibrate_activations(module, value=value_states, base_name="v", stats_only=True)
236246
237247
238248def apply_calibration_status (module : Module ):
@@ -273,3 +283,29 @@ def reset_quantization_status(model: Module):
273283 for module in model .modules ():
274284 if hasattr (module , "quantization_status" ):
275285 delattr (module , "quantization_status" )
286+
287+
def flush_activation_qparams(module: Module):
    """
    Compute and write final activation qparams from each observer's
    accumulated running statistics, then free those statistics to reduce
    memory.

    Called once at SEQUENTIAL_EPOCH_END for each subgraph, replacing the
    per-batch qparam updates previously triggered by calibration hooks.
    No-op for modules with no quantization scheme or no activation observers.

    Note: weight observers are not touched here — weight qparams are always
    computed up-front in ``on_start`` via ``update_weight_zp_scale``.

    apply to targeted modules with:
        for _, module in match_named_modules(...):
            flush_activation_qparams(module)

    :param module: module to flush activation qparams for
    """
    # Guard clause: nothing to do for unquantized modules.
    if getattr(module, "quantization_scheme", None) is None:
        return

    for base_name in ("input", "output", "q", "k", "v"):
        calibrate_module_from_observer(module, base_name)
0 commit comments