1- from typing import Any , Iterable , Optional , Set , Tuple
1+ import threading
2+ from concurrent .futures import ThreadPoolExecutor
3+ from typing import Any , Iterable , Iterator , Optional , Tuple
24
35import torch
6+ import tqdm
47from compressed_tensors .quantization import (
58 DynamicType ,
69 QuantizationArgs ,
1114from compressed_tensors .utils import (
1215 align_module_device ,
1316 getattr_chain ,
17+ match_named_modules ,
1418 update_offload_parameter ,
1519)
1620from loguru import logger
@@ -136,6 +140,52 @@ def update_weight_global_scale(module: Module):
136140 )
137141
138142
143+ def _post_order_modules (model : Module ) -> Iterator [Module ]:
144+ """Yield every module in the tree in DFS post-order."""
145+ stack : list [Tuple [Module , bool ]] = [(model , False )]
146+ while stack :
147+ module , children_done = stack .pop ()
148+ if not children_done :
149+ stack .append ((module , True ))
150+ for child in reversed (list (module .children ())):
151+ stack .append ((child , False ))
152+ else :
153+ yield module
154+
155+
def _update_weight_calibration_once(module: Module, update_zp_scale: bool) -> None:
    """
    Onload a module's weight once and run both the global scale (gparam) and
    the scale/zero-point (qparams) observers in a single pass.

    Used by the sequential DFS path so NVFP4-style schemes do not pay two
    separate CPU-GPU onloads for the same weight.

    :param module: module whose weight may need calibration
    :param update_zp_scale: whether scale/zero-point (qparams) should be computed
    """
    # Modules without a weight quantization scheme have nothing to calibrate.
    if getattr_chain(module, "quantization_scheme.weights", None) is None:
        return

    strategy = getattr_chain(module, "quantization_scheme.weights.strategy", None)
    wants_gparam = strategy == QuantizationStrategy.TENSOR_GROUP
    wants_qparams = update_zp_scale
    if not (wants_gparam or wants_qparams):
        return

    # qparams calibration expects the module to be in calibration mode; warn
    # (but proceed) if it is not, mirroring update_weight_zp_scale's behavior.
    status = getattr(module, "quantization_status", None)
    if wants_qparams and status != QuantizationStatus.CALIBRATION:
        logger.warning(
            "Attempting to calibrate weights of a module not in calibration mode"
        )

    # Single onload: compute whatever subset of gparam/qparams is needed.
    with align_module_device(module):
        call_observer(
            module,
            base_name="weight",
            value=module.weight,
            should_calculate_gparam=wants_gparam,
            should_calculate_qparams=wants_qparams,
        )
188+
139189def update_weight_zp_scale (module : Module ):
140190 """
141191 marks a layer as ready for calibration which activates observers
def calibrate_weights(
    model: Module,
    *,
    named_modules: Optional[Iterable[Tuple[str, Module]]] = None,
    targets: Iterable[str] = (),
    ignore: Iterable[str] = (),
    update_zp_scale: bool = True,
    desc: Optional[str] = "Calibrating weights",
    show_progress: bool = True,
    parallel: bool = False,
    max_workers: Optional[int] = None,
) -> None:
    """
    Run weight calibration: per-tensor global scale (gparam), fused global
    scales for Attention/MLP containers, and scale/zero-point (qparams).
    Minimizes weight onloads when using offloading (one onload per target in
    the default sequential path).

    Two modes:
    - Sequential (parallel=False): single DFS over the model. Pre-order runs
      one onload per target via _update_weight_calibration_once (gparam +
      qparams together); post-order applies
      update_fused_layer_weight_global_scales.
    - Parallel (parallel=True): phase 1 runs gparam + qparams per target
      (order-independent, optionally threaded); phase 2 walks the tree in
      post-order and applies fused global scales.

    DDP note: pass ``named_modules`` as this rank's subset so each rank only
    calibrates its assigned modules. Fused groups (q/k/v, gate/up) must be
    assigned to the same rank so update_fused_layer_weight_global_scales sees
    the full group; weight calibration itself performs no cross-rank
    communication.

    :param model: root module to traverse (e.g. state.model)
    :param named_modules: if provided, only these (name, module) pairs are
        calibrated. If None, uses match_named_modules(model, targets, ignore).
    :param targets: name patterns used when named_modules is None
    :param ignore: ignore patterns used when named_modules is None
    :param update_zp_scale: if True, compute scale/zp for targets. Set False
        for modifiers that compute zp/scale in hooks (e.g. GPTQ).
    :param desc: progress bar description; None disables the bar
    :param show_progress: if True and desc is set, show a tqdm bar
    :param parallel: if True, use the two-phase parallel calibration
    :param max_workers: if parallel and a positive int, phase 1 uses this many
        threads; otherwise phase 1 runs serially
    """
    if named_modules is None:
        named_modules = list(match_named_modules(model, targets, ignore))
    else:
        named_modules = list(named_modules)

    # DDP: only these modules get gparam + qparams (this rank's subset).
    # Dedup while preserving caller order so phase-1 work order and progress
    # reporting are deterministic (a plain set() has arbitrary iteration order).
    target_set: set = set()
    target_list: list = []
    for _, m in named_modules:
        if m not in target_set:
            target_set.add(m)
            target_list.append(m)
    total_targets = len(target_list)

    if show_progress and desc is not None and total_targets > 0:
        pbar = tqdm.tqdm(total=total_targets, desc=desc)
    else:
        pbar = None

    if parallel:
        # Phase 1: per-module global scale + scale/zp (order-independent).
        # tqdm.update is not guaranteed thread-safe; guard it with a lock.
        pbar_lock = threading.Lock()

        def _phase1_one(module: Module) -> None:
            update_weight_global_scale(module)
            if update_zp_scale:
                update_weight_zp_scale(module)
            if pbar is not None:
                with pbar_lock:
                    pbar.update(1)

        if max_workers is not None and max_workers > 0:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                # list() drains the iterator so worker exceptions propagate here.
                list(executor.map(_phase1_one, target_list))
        else:
            for module in target_list:
                _phase1_one(module)

        # Phase 2: fused global scales (rescales per-tensor scale s' = s * g'/g).
        for module in _post_order_modules(model):
            update_fused_layer_weight_global_scales(module)
    else:
        # Sequential DFS: pre-order does one onload per target for
        # gparam + qparams; post-order applies fused global scales.
        seen_pre: set = set()
        counted: set = set()
        stack = [(model, False)]
        while stack:
            module, children_done = stack.pop()
            if not children_done:
                if module in target_set and module not in seen_pre:
                    seen_pre.add(module)
                    _update_weight_calibration_once(module, update_zp_scale)
                stack.append((module, True))
                for child in reversed(list(module.children())):
                    stack.append((child, False))
            else:
                update_fused_layer_weight_global_scales(module)
                # Advance the bar once per target regardless of update_zp_scale;
                # previously the bar stayed at 0/N when update_zp_scale=False
                # even though gparam calibration still ran.
                if module in target_set and module not in counted:
                    counted.add(module)
                    if pbar is not None:
                        pbar.update(1)

    if pbar is not None:
        pbar.close()
0 commit comments