1- from typing import Any , Iterable , Optional , Set , Tuple
1+ import threading
2+ from concurrent .futures import ThreadPoolExecutor
3+ from typing import Any , Iterable , Optional , Tuple
24
35import torch
6+ import tqdm
47from compressed_tensors .quantization import (
58 DynamicType ,
69 QuantizationArgs ,
1114from compressed_tensors .utils import (
1215 align_module_device ,
1316 getattr_chain ,
17+ match_named_modules ,
1418 update_offload_parameter ,
1519)
1620from loguru import logger
def calibrate_weights(
    model: Module,
    *,
    named_modules: Optional[Iterable[Tuple[str, Module]]] = None,
    targets: Iterable[str] = (),
    ignore: Iterable[str] = (),
    update_zp_scale: bool = True,
    desc: Optional[str] = "Calibrating weights",
    show_progress: bool = True,
    parallel: bool = False,
    max_workers: Optional[int] = None,
) -> None:
    """
    Traverse the model once (DFS) and run weight calibration: global scales for
    target modules, fused-layer global scales for container modules, and
    (optionally) per-tensor scale/zero-point. A single DFS over model.modules()
    gives better cache locality and fewer CPU-GPU onloads when using offloading.

    Order of operations (default, parallel=False):
      1. Pre-order: update_weight_global_scale for target (quantizable) modules.
      2. Post-order: update_fused_layer_weight_global_scales for every module
         (no-op except for Attention/MLP containers); then update_weight_zp_scale
         for target modules if update_zp_scale is True.

    When parallel=True (parallel weight calibration):
      1. Phase 1: for each target module, run update_weight_global_scale then
         update_weight_zp_scale (if update_zp_scale). Per-module work is
         order-independent, so phase 1 may be parallelized via max_workers.
      2. Phase 2: traverse the model and run
         update_fused_layer_weight_global_scales on every module. The fused
         global scale g' is applied and the per-tensor scale is rescaled
         s' = s * (g' / g) so that q = x / (s' * g') = x / (s * g) is unchanged.

    :param model: Root module to traverse (e.g. state.model).
    :param named_modules: Optional list of (name, module) for target modules.
        If provided, only these modules get global_scale and zp_scale; enables
        DDP by passing this rank's subset. If None, targets and ignore are used
        via match_named_modules(model, targets, ignore) (default () for both
        means no name-based filtering).
    :param targets: Target module name patterns (used when named_modules is
        None). Default () means no name-based filtering.
    :param ignore: Ignore patterns (used when named_modules is None). Default ().
    :param update_zp_scale: If True, call update_weight_zp_scale on target
        modules. Set False for modifiers that do zp_scale in hooks (e.g. GPTQ).
    :param desc: Progress bar description; None to disable the progress bar.
    :param show_progress: If True and desc is not None, show a tqdm progress bar.
    :param parallel: If True, use two-phase parallel calibration (phase 1
        per-layer, phase 2 fused global scales with scale rescaling).
    :param max_workers: If parallel=True and a positive int, run phase 1 with
        this many threads. If None, phase 1 runs sequentially.
    """
    if named_modules is None:
        named_modules = list(match_named_modules(model, targets, ignore))
    else:
        named_modules = list(named_modules)

    # Deduplicate by module identity (nn.Module hashes by identity); a module
    # matched under several names must only be calibrated once.
    target_set = {m for _, m in named_modules}
    target_list = list(target_set)
    total_targets = len(target_list)

    if show_progress and desc is not None and total_targets > 0:
        pbar = tqdm.tqdm(total=total_targets, desc=desc)
    else:
        pbar = None

    if parallel:
        # Phase 1: per-module global scale + scale/zp. Order-independent, so it
        # may fan out across threads. tqdm.update is not guaranteed thread-safe;
        # serialize it behind a lock.
        pbar_lock = threading.Lock()

        def _phase1_one(module: Module) -> None:
            update_weight_global_scale(module)
            if update_zp_scale:
                update_weight_zp_scale(module)
            if pbar is not None:
                with pbar_lock:
                    pbar.update(1)

        if max_workers is not None and max_workers > 0:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                # list() drains the iterator so worker exceptions propagate here.
                list(executor.map(_phase1_one, target_list))
        else:
            for module in target_list:
                _phase1_one(module)

        # Phase 2: fused global scales in post-order (children before parents).
        # The per-tensor scale is rescaled s' = s * g' / g so quantized values
        # are unchanged.
        stack: list[Tuple[Module, bool]] = [(model, False)]
        while stack:
            module, children_done = stack.pop()
            if not children_done:
                stack.append((module, True))
                # reversed() so children are popped in natural (registration) order.
                for child in reversed(list(module.children())):
                    stack.append((child, False))
            else:
                update_fused_layer_weight_global_scales(module)
    else:
        # Sequential DFS: pre-order global scale, post-order fused + zp_scale.
        # seen_pre/seen_post guard against shared submodules reachable through
        # multiple parents being calibrated (and counted) more than once.
        seen_pre: set[Module] = set()
        seen_post: set[Module] = set()
        stack = [(model, False)]
        while stack:
            module, children_done = stack.pop()
            if not children_done:
                if module in target_set and module not in seen_pre:
                    seen_pre.add(module)
                    update_weight_global_scale(module)
                stack.append((module, True))
                for child in reversed(list(module.children())):
                    stack.append((child, False))
            else:
                update_fused_layer_weight_global_scales(module)
                if module in target_set and module not in seen_post:
                    seen_post.add(module)
                    if update_zp_scale:
                        update_weight_zp_scale(module)
                    # Advance the bar once per target regardless of
                    # update_zp_scale; otherwise (update_zp_scale=False) the
                    # bar would stall at 0 despite total=total_targets, and
                    # would disagree with the parallel branch's behavior.
                    if pbar is not None:
                        pbar.update(1)

    if pbar is not None:
        pbar.close()
0 commit comments