Skip to content

Commit b3d2de9

Browse files
author
Avishek Goswami
committed
Calibration: move fused_modules to modeling, rescale s when fusing g', add parallel path
- Move fused_modules.py to modeling/ and update imports. - In update_fused_layer_weight_global_scales, rescale weight_scale s' = s * g' / g when applying the fused global scale, so the quantized value q is unchanged. - Add calibrate_weights(..., parallel=True, max_workers=N) for two-phase parallel weight calibration. Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
1 parent 480294d commit b3d2de9

File tree

4 files changed

+163
-64
lines changed

4 files changed

+163
-64
lines changed
File renamed without changes.

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 136 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
from typing import Any, Iterable, Optional, Set, Tuple
1+
import threading
2+
from concurrent.futures import ThreadPoolExecutor
3+
from typing import Any, Iterable, Iterator, Optional, Tuple
24

35
import torch
6+
import tqdm
47
from compressed_tensors.quantization import (
58
DynamicType,
69
QuantizationArgs,
@@ -11,6 +14,7 @@
1114
from compressed_tensors.utils import (
1215
align_module_device,
1316
getattr_chain,
17+
match_named_modules,
1418
update_offload_parameter,
1519
)
1620
from loguru import logger
@@ -136,6 +140,52 @@ def update_weight_global_scale(module: Module):
136140
)
137141

138142

143+
def _post_order_modules(model: Module) -> Iterator[Module]:
144+
"""Yield every module in the tree in DFS post-order."""
145+
stack: list[Tuple[Module, bool]] = [(model, False)]
146+
while stack:
147+
module, children_done = stack.pop()
148+
if not children_done:
149+
stack.append((module, True))
150+
for child in reversed(list(module.children())):
151+
stack.append((child, False))
152+
else:
153+
yield module
154+
155+
156+
def _update_weight_calibration_once(module: Module, update_zp_scale: bool) -> None:
157+
"""
158+
Onload weight once and run both global scale (gparam) and scale/zp (qparams).
159+
Used in sequential DFS to avoid double onload for NVFP4.
160+
"""
161+
if getattr_chain(module, "quantization_scheme.weights", None) is None:
162+
return
163+
need_gparam = (
164+
getattr_chain(module, "quantization_scheme.weights.strategy", None)
165+
== QuantizationStrategy.TENSOR_GROUP
166+
)
167+
need_qparams = update_zp_scale
168+
if not need_gparam and not need_qparams:
169+
return
170+
if (
171+
need_qparams
172+
and getattr(module, "quantization_status", None)
173+
!= QuantizationStatus.CALIBRATION
174+
):
175+
logger.warning(
176+
"Attempting to calibrate weights of a module not in calibration mode"
177+
)
178+
with align_module_device(module):
179+
value = module.weight
180+
call_observer(
181+
module,
182+
base_name="weight",
183+
value=value,
184+
should_calculate_gparam=need_gparam,
185+
should_calculate_qparams=need_qparams,
186+
)
187+
188+
139189
def update_weight_zp_scale(module: Module):
140190
"""
141191
marks a layer as ready for calibration which activates observers
@@ -162,84 +212,112 @@ def calibrate_weights(
162212
model: Module,
163213
*,
164214
named_modules: Optional[Iterable[Tuple[str, Module]]] = None,
165-
targets: Optional[Set[str]] = None,
166-
ignore: Optional[Iterable[str]] = None,
215+
targets: Iterable[str] = (),
216+
ignore: Iterable[str] = (),
167217
update_zp_scale: bool = True,
168218
desc: Optional[str] = "Calibrating weights",
169219
show_progress: bool = True,
220+
parallel: bool = False,
221+
max_workers: Optional[int] = None,
170222
) -> None:
171223
"""
172-
Traverse the model once (DFS) and run weight calibration: global scales for
173-
FP4/TENSOR_GROUP, fused layer global scales for Attention/MLP, and weight
174-
scale/zero-point. Replaces separate loops over named_modules and
175-
model.modules() for better cache locality and fewer CPU–GPU onloads when
176-
using offloading.
177-
178-
Order of operations per module:
179-
1. Pre-order: update_weight_global_scale for target (quantizable) modules.
180-
2. Post-order: update_fused_layer_weight_global_scales for every module
181-
(no-op except for Attention/MLP containers); then update_weight_zp_scale
182-
for target modules if update_zp_scale is True.
224+
Run weight calibration: per-tensor global scale (gparam), fused global scales
225+
for Attention/MLP, and scale/zero-point (qparams). Minimizes weight onloads
226+
when using offloading (one onload per target in the default path).
227+
228+
Two modes:
229+
- Sequential (parallel=False): DFS over the model. Pre-order: one onload per
230+
target via _update_weight_calibration_once (gparam + qparams). Post-order:
231+
update_fused_layer_weight_global_scales (no extra onload for targets).
232+
- Parallel (parallel=True): Phase 1 runs gparam + qparams per target
233+
(order-independent, parallelizable). Phase 2 applies fused global scales
234+
and rescales per-tensor scale s' = s * (g' / g).
235+
236+
DDP: Works with distributed setups. Pass named_modules as this rank's
237+
subset so each rank only calibrates its assigned modules (see e.g. #2220).
238+
Activation observer sync across ranks is handled by
239+
QuantizationMixin.sync_activation_observers at layer
240+
boundaries (PR #2391); weight calibration does not all-reduce weight
241+
observer state—each rank calibrates its subset and can broadcast
242+
quantized params afterward (e.g. GPTQ-style) if needed. Fused groups
243+
(q/k/v, gate/up) must be assigned to the same rank so
244+
update_fused_layer_weight_global_scales sees the full group. For
245+
balanced wall time, assign by weight size (e.g. greedy_bin_packing with
246+
item_weight_fn=lambda m: m.weight.numel(); see GPTQ DDP #2333 which uses
247+
hessian shape for the same idea).
248+
249+
Benchmark: See tests/benchmark_calibrate_weights.py for onload count and
250+
single-vs-double-onload timing.
183251
184252
:param model: Root module to traverse (e.g. state.model).
185-
:param named_modules: Optional list of (name, module) for target modules.
186-
If provided, only these modules get global_scale and zp_scale; enables
187-
DDP by passing this rank's subset (see #2220). If None, targets and
188-
ignore must be provided and match_named_modules(model, targets, ignore)
189-
is used.
190-
:param targets: Target module name patterns (used when named_modules is None).
191-
:param ignore: Ignore patterns (used when named_modules is None).
192-
:param update_zp_scale: If True, call update_weight_zp_scale on target
193-
modules in post-order. Set False for modifiers that do zp_scale in
194-
hooks (e.g. GPTQ).
195-
:param desc: Progress bar description; None to disable progress bar.
196-
:param show_progress: If True and desc is not None, show a tqdm progress bar.
253+
:param named_modules: If provided, only these (name, module) pairs are
254+
calibrated; enables DDP by passing this rank's subset. If None, uses
255+
match_named_modules(model, targets, ignore).
256+
:param targets: Name patterns when named_modules is None. Default ().
257+
:param ignore: Ignore patterns when named_modules is None. Default ().
258+
:param update_zp_scale: If True, compute scale/zp for targets. False for
259+
modifiers that do zp in hooks (e.g. GPTQ).
260+
:param desc: Progress bar description; None disables bar.
261+
:param show_progress: If True and desc set, show tqdm bar.
262+
:param parallel: If True, use two-phase parallel calibration.
263+
:param max_workers: If parallel and int, phase 1 uses this many workers.
197264
"""
198265
if named_modules is None:
199-
if targets is None or ignore is None:
200-
raise ValueError(
201-
"calibrate_weights requires either named_modules or both "
202-
"targets and ignore"
203-
)
204-
from compressed_tensors.utils import match_named_modules
205-
206266
named_modules = list(match_named_modules(model, targets, ignore))
207267
else:
208268
named_modules = list(named_modules)
269+
# DDP: target_set = only these get gparam + qparams (this rank's subset).
270+
target_set = {m for _, m in named_modules}
271+
target_list = list(target_set)
272+
total_targets = len(target_list)
209273

210-
target_set = {id(m) for _, m in named_modules}
211-
total_targets = len(target_set)
212-
213-
try:
214-
import tqdm
215-
except ImportError:
216-
tqdm = None
217-
218-
if show_progress and desc is not None and tqdm is not None and total_targets > 0:
274+
if show_progress and desc is not None and total_targets > 0:
219275
pbar = tqdm.tqdm(total=total_targets, desc=desc)
220276
else:
221277
pbar = None
222278

223-
# Stack-based DFS: (module, children_visited)
224-
stack: list[Tuple[Module, bool]] = [(model, False)]
279+
if parallel:
280+
# Phase 1: per-module global scale + scale/zp (order-independent)
281+
pbar_lock = threading.Lock()
225282

226-
while stack:
227-
module, children_done = stack.pop()
283+
def _phase1_one(module: Module) -> None:
284+
update_weight_global_scale(module)
285+
if update_zp_scale:
286+
update_weight_zp_scale(module)
287+
if pbar is not None:
288+
with pbar_lock:
289+
pbar.update(1)
228290

229-
if not children_done:
230-
# Pre-order: global scale for target modules (FP4 / TENSOR_GROUP)
231-
if id(module) in target_set:
232-
update_weight_global_scale(module)
233-
stack.append((module, True))
234-
for child in reversed(list(module.children())):
235-
stack.append((child, False))
291+
if max_workers is not None and max_workers > 0:
292+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
293+
list(executor.map(_phase1_one, target_list))
236294
else:
237-
# Post-order: fused global scales (Attention/MLP), then zp_scale for targets
295+
for module in target_list:
296+
_phase1_one(module)
297+
298+
# Phase 2: fused global scales (rescale per-tensor scale s' = s * g' / g)
299+
for module in _post_order_modules(model):
238300
update_fused_layer_weight_global_scales(module)
239-
if update_zp_scale and id(module) in target_set:
240-
update_weight_zp_scale(module)
241-
if pbar is not None:
242-
pbar.update(1)
301+
else:
302+
# Sequential DFS: pre-order one onload for gparam + qparams, post-order fused
303+
seen_pre: set[Module] = set()
304+
seen_post: set[Module] = set()
305+
stack = [(model, False)]
306+
while stack:
307+
module, children_done = stack.pop()
308+
if not children_done:
309+
if module in target_set and module not in seen_pre:
310+
seen_pre.add(module)
311+
_update_weight_calibration_once(module, update_zp_scale)
312+
stack.append((module, True))
313+
for child in reversed(list(module.children())):
314+
stack.append((child, False))
315+
else:
316+
update_fused_layer_weight_global_scales(module)
317+
if update_zp_scale and module in target_set and module not in seen_post:
318+
seen_post.add(module)
319+
if pbar is not None:
320+
pbar.update(1)
243321

244322
if pbar is not None:
245323
pbar.close()

src/llmcompressor/modifiers/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# ruff: noqa
22

33
from .constants import *
4-
from .fused_modules import (
4+
from llmcompressor.modeling.fused_modules import (
55
get_fused_attention_linears,
66
get_fused_mlp_linears,
77
is_fused_attention_module,

src/llmcompressor/modifiers/utils/helpers.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@
99

1010
import torch
1111
from compressed_tensors.quantization import QuantizationStrategy
12-
from compressed_tensors.utils import align_modules, update_parameter_data
12+
from compressed_tensors.utils import (
13+
align_modules,
14+
update_offload_parameter,
15+
update_parameter_data,
16+
)
1317
from torch.nn import Linear
1418

15-
from llmcompressor.modifiers.utils.fused_modules import (
19+
from llmcompressor.modeling.fused_modules import (
1620
get_fused_attention_linears,
1721
get_fused_mlp_linears,
1822
)
@@ -39,7 +43,12 @@ def update_fused_layer_weight_global_scales(submodule: torch.nn.Module):
3943
When running NVFP4 quantization, update the global scale so that vLLM
4044
fused groups share one global scale: attention (traditional q/k/v or
4145
MLA q_a + kv_a) and MLP (gate/up). Uses the centralized fused module
42-
definitions; see :mod:`llmcompressor.modifiers.utils.fused_modules`.
46+
definitions; see :mod:`llmcompressor.modeling.fused_modules`.
47+
48+
When a linear already has ``weight_scale`` (e.g. after parallel phase-1
49+
calibration), per-tensor scale is rescaled so that q = x/(s'*g') is
50+
unchanged: s' = s * (g' / g), where g' is the fused global scale and g
51+
was the previous per-tensor global scale.
4352
4453
This is a requirement currently set by vLLM and may be removed or
4554
made optional in the future.
@@ -55,7 +64,7 @@ def update_fused_layer_weight_global_scales(submodule: torch.nn.Module):
5564
torch.cat([lin.weight_global_scale.data for lin in linears])
5665
).reshape([1])
5766
for lin in linears:
58-
update_parameter_data(lin, global_scale, "weight_global_scale")
67+
_apply_fused_global_scale(lin, global_scale)
5968
del global_scale
6069

6170
# Fused MLP: gate_proj, up_proj
@@ -66,5 +75,17 @@ def update_fused_layer_weight_global_scales(submodule: torch.nn.Module):
6675
torch.cat([lin.weight_global_scale.data for lin in linears])
6776
).reshape([1])
6877
for lin in linears:
69-
update_parameter_data(lin, global_scale, "weight_global_scale")
78+
_apply_fused_global_scale(lin, global_scale)
7079
del global_scale
80+
81+
82+
def _apply_fused_global_scale(lin: Linear, g_prime: torch.Tensor) -> None:
83+
"""Set weight_global_scale to g'; rescale weight_scale so q = x/(s*g) unchanged."""
84+
old_g = lin.weight_global_scale.data
85+
update_parameter_data(lin, g_prime, "weight_global_scale")
86+
weight_scale = getattr(lin, "weight_scale", None)
87+
if weight_scale is not None:
88+
# s' = s * (g' / g) so that x / s' / g' = x / s / g
89+
ratio = (g_prime / old_g).to(weight_scale.dtype).to(weight_scale.device)
90+
new_scale = weight_scale.data * ratio
91+
update_offload_parameter(lin, "weight_scale", new_scale)

0 commit comments

Comments
 (0)