1- from typing import Any , Iterable , Optional , Set , Tuple
1+ import threading
2+ from concurrent .futures import ThreadPoolExecutor
3+ from typing import Any , Iterable , Optional , Tuple
24
35import torch
6+ import tqdm
47from compressed_tensors .quantization import (
58 DynamicType ,
69 QuantizationArgs ,
1114from compressed_tensors .utils import (
1215 align_module_device ,
1316 getattr_chain ,
17+ match_named_modules ,
1418 update_offload_parameter ,
1519)
1620from loguru import logger
@@ -136,6 +140,39 @@ def update_weight_global_scale(module: Module):
136140 )
137141
138142
def _update_weight_calibration_once(module: Module, update_zp_scale: bool) -> None:
    """
    Onload weight once and run both global scale (gparam) and scale/zp (qparams).
    Used in sequential DFS to avoid double onload for NVFP4.
    """
    # No weight quantization configured for this module -> nothing to calibrate.
    weight_args = getattr_chain(module, "quantization_scheme.weights", None)
    if weight_args is None:
        return

    # TENSOR_GROUP (NVFP4) is the only strategy that needs a global scale.
    need_gparam = (
        getattr(weight_args, "strategy", None) == QuantizationStrategy.TENSOR_GROUP
    )
    need_qparams = update_zp_scale
    if not (need_gparam or need_qparams):
        return

    # Warn (but proceed) when qparams are requested outside calibration mode,
    # matching the file's existing best-effort behavior.
    if need_qparams and (
        getattr(module, "quantization_status", None) != QuantizationStatus.CALIBRATION
    ):
        logger.warning(
            "Attempting to calibrate weights of a module not in calibration mode"
        )

    # Single onload: observe the weight once and compute everything needed.
    with align_module_device(module):
        call_observer(
            module,
            base_name="weight",
            value=module.weight,
            should_calculate_gparam=need_gparam,
            should_calculate_qparams=need_qparams,
        )
175+
139176def update_weight_zp_scale (module : Module ):
140177 """
141178 marks a layer as ready for calibration which activates observers
@@ -162,11 +199,13 @@ def calibrate_weights(
162199 model : Module ,
163200 * ,
164201 named_modules : Optional [Iterable [Tuple [str , Module ]]] = None ,
165- targets : Optional [ Set [ str ]] = None ,
166- ignore : Optional [ Iterable [str ]] = None ,
202+ targets : Iterable [ str ] = () ,
203+ ignore : Iterable [str ] = () ,
167204 update_zp_scale : bool = True ,
168205 desc : Optional [str ] = "Calibrating weights" ,
169206 show_progress : bool = True ,
207+ parallel : bool = False ,
208+ max_workers : Optional [int ] = None ,
170209) -> None :
171210 """
172211 Traverse the model once (DFS) and run weight calibration: global scales for
@@ -175,72 +214,103 @@ def calibrate_weights(
175214 model.modules() for better cache locality and fewer CPU–GPU onloads when
176215 using offloading.
177216
178- Order of operations per module:
179- 1. Pre-order: update_weight_global_scale for target (quantizable) modules.
217+ Order of operations (default, parallel=False):
218+ 1. Pre-order: one weight onload per target module; run both global scale
219+ (gparam) and scale/zp (qparams) via _update_weight_calibration_once.
180220 2. Post-order: update_fused_layer_weight_global_scales for every module
181- (no-op except for Attention/MLP containers); then update_weight_zp_scale
182- for target modules if update_zp_scale is True.
221+ (no-op except for Attention/MLP containers). No second onload for targets.
222+
223+ When parallel=True (parallel weight calibration):
224+ 1. Phase 1: For each target module, run update_weight_global_scale then
225+ update_weight_zp_scale (if update_zp_scale). Order is independent so
226+ phase 1 can be parallelized (e.g. with max_workers).
227+ 2. Phase 2: Traverse model and run update_fused_layer_weight_global_scales
228+ on every module. Fused global scale g' is applied and per-tensor scale
229+ is rescaled s' = s * (g' / g) so that q = x/(s'*g') = x/(s*g) is unchanged.
183230
184231 :param model: Root module to traverse (e.g. state.model).
185232 :param named_modules: Optional list of (name, module) for target modules.
186233 If provided, only these modules get global_scale and zp_scale; enables
187234 DDP by passing this rank's subset (see #2220). If None, targets and
188- ignore must be provided and match_named_modules(model, targets, ignore)
189- is used .
235+ ignore are used via match_named_modules(model, targets, ignore)
236+ (default () for both means no name-based filtering) .
190237 :param targets: Target module name patterns (used when named_modules is None).
191- :param ignore: Ignore patterns (used when named_modules is None).
238+ Default () means no name-based filtering when named_modules is None.
239+ :param ignore: Ignore patterns (used when named_modules is None). Default ().
192240 :param update_zp_scale: If True, call update_weight_zp_scale on target
193241 modules in post-order. Set False for modifiers that do zp_scale in
194242 hooks (e.g. GPTQ).
195243 :param desc: Progress bar description; None to disable progress bar.
196244 :param show_progress: If True and desc is not None, show a tqdm progress bar.
245+ :param parallel: If True, use two-phase parallel calibration (phase 1 per-layer,
246+ phase 2 fused global scales with scale rescaling).
247+ :param max_workers: If parallel=True and int, run phase 1 with this many
248+ workers. If None, phase 1 runs sequentially.
197249 """
198250 if named_modules is None :
199- if targets is None or ignore is None :
200- raise ValueError (
201- "calibrate_weights requires either named_modules or both "
202- "targets and ignore"
203- )
204- from compressed_tensors .utils import match_named_modules
205-
206251 named_modules = list (match_named_modules (model , targets , ignore ))
207252 else :
208253 named_modules = list (named_modules )
209254
210- target_set = {id (m ) for _ , m in named_modules }
211- total_targets = len (target_set )
255+ target_set = {m for _ , m in named_modules }
256+ target_list = list (target_set )
257+ total_targets = len (target_list )
212258
213- try :
214- import tqdm
215- except ImportError :
216- tqdm = None
217-
218- if show_progress and desc is not None and tqdm is not None and total_targets > 0 :
259+ if show_progress and desc is not None and total_targets > 0 :
219260 pbar = tqdm .tqdm (total = total_targets , desc = desc )
220261 else :
221262 pbar = None
222263
223- # Stack-based DFS: (module, children_visited)
224- stack : list [Tuple [Module , bool ]] = [(model , False )]
225-
226- while stack :
227- module , children_done = stack .pop ()
264+ if parallel :
265+ # Phase 1: per-module global scale + scale/zp (order-independent)
266+ pbar_lock = threading .Lock ()
228267
229- if not children_done :
230- # Pre-order: global scale for target modules (FP4 / TENSOR_GROUP)
231- if id (module ) in target_set :
232- update_weight_global_scale (module )
233- stack .append ((module , True ))
234- for child in reversed (list (module .children ())):
235- stack .append ((child , False ))
236- else :
237- # Post-order: fused global scales (Attention/MLP), then zp_scale for targets
238- update_fused_layer_weight_global_scales (module )
239- if update_zp_scale and id (module ) in target_set :
268+ def _phase1_one (module : Module ) -> None :
269+ update_weight_global_scale (module )
270+ if update_zp_scale :
240271 update_weight_zp_scale (module )
241- if pbar is not None :
272+ if pbar is not None :
273+ with pbar_lock :
242274 pbar .update (1 )
243275
276+ if max_workers is not None and max_workers > 0 :
277+ with ThreadPoolExecutor (max_workers = max_workers ) as executor :
278+ list (executor .map (_phase1_one , target_list ))
279+ else :
280+ for module in target_list :
281+ _phase1_one (module )
282+
283+ # Phase 2: fused global scales (rescale per-tensor scale s' = s * g' / g)
284+ stack : list [Tuple [Module , bool ]] = [(model , False )]
285+ while stack :
286+ module , children_done = stack .pop ()
287+ if not children_done :
288+ stack .append ((module , True ))
289+ for child in reversed (list (module .children ())):
290+ stack .append ((child , False ))
291+ else :
292+ update_fused_layer_weight_global_scales (module )
293+ else :
294+ # Sequential DFS: pre-order one onload for gparam + qparams, post-order fused
295+ seen_pre : set [Module ] = set ()
296+ seen_post : set [Module ] = set ()
297+ stack = [(model , False )]
298+ while stack :
299+ module , children_done = stack .pop ()
300+ if not children_done :
301+ if module in target_set and module not in seen_pre :
302+ seen_pre .add (module )
303+ _update_weight_calibration_once (module , update_zp_scale )
304+ stack .append ((module , True ))
305+ for child in reversed (list (module .children ())):
306+ stack .append ((child , False ))
307+ else :
308+ update_fused_layer_weight_global_scales (module )
309+ if update_zp_scale and module in target_set and module not in seen_post :
310+ seen_post .add (module )
311+ if pbar is not None :
312+ pbar .update (1 )
313+
244314 if pbar is not None :
245315 pbar .close ()
246316
0 commit comments