Skip to content

Commit fae990b

Browse files
author
Avishek Goswami
committed
Calibration: move fused_modules to modeling, rescale s when fusing g', add parallel path

- Move fused_modules.py to modeling/ and update imports.
- In update_fused_layer_weight_global_scales, rescale weight_scale as s' = s * g' / g when applying the fused global scale, so the quantized value q is unchanged.
- Add calibrate_weights(..., parallel=True, max_workers=N) for two-phase parallel weight calibration.

Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
1 parent 480294d commit fae990b

File tree

4 files changed

+104
-45
lines changed

4 files changed

+104
-45
lines changed
File renamed without changes.

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 77 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
from typing import Any, Iterable, Optional, Set, Tuple
1+
import threading
2+
from concurrent.futures import ThreadPoolExecutor
3+
from typing import Any, Iterable, Optional, Tuple
24

35
import torch
6+
import tqdm
47
from compressed_tensors.quantization import (
58
DynamicType,
69
QuantizationArgs,
@@ -11,6 +14,7 @@
1114
from compressed_tensors.utils import (
1215
align_module_device,
1316
getattr_chain,
17+
match_named_modules,
1418
update_offload_parameter,
1519
)
1620
from loguru import logger
@@ -162,11 +166,13 @@ def calibrate_weights(
162166
model: Module,
163167
*,
164168
named_modules: Optional[Iterable[Tuple[str, Module]]] = None,
165-
targets: Optional[Set[str]] = None,
166-
ignore: Optional[Iterable[str]] = None,
169+
targets: Iterable[str] = (),
170+
ignore: Iterable[str] = (),
167171
update_zp_scale: bool = True,
168172
desc: Optional[str] = "Calibrating weights",
169173
show_progress: bool = True,
174+
parallel: bool = False,
175+
max_workers: Optional[int] = None,
170176
) -> None:
171177
"""
172178
Traverse the model once (DFS) and run weight calibration: global scales for
@@ -175,72 +181,104 @@ def calibrate_weights(
175181
model.modules() for better cache locality and fewer CPU–GPU onloads when
176182
using offloading.
177183
178-
Order of operations per module:
184+
Order of operations (default, parallel=False):
179185
1. Pre-order: update_weight_global_scale for target (quantizable) modules.
180186
2. Post-order: update_fused_layer_weight_global_scales for every module
181187
(no-op except for Attention/MLP containers); then update_weight_zp_scale
182188
for target modules if update_zp_scale is True.
183189
190+
When parallel=True (parallel weight calibration):
191+
1. Phase 1: For each target module, run update_weight_global_scale then
192+
update_weight_zp_scale (if update_zp_scale). Order is independent so
193+
phase 1 can be parallelized (e.g. with max_workers).
194+
2. Phase 2: Traverse model and run update_fused_layer_weight_global_scales
195+
on every module. Fused global scale g' is applied and per-tensor scale
196+
is rescaled s' = s * (g' / g) so that q = x/(s'*g') = x/(s*g) is unchanged.
197+
184198
:param model: Root module to traverse (e.g. state.model).
185199
:param named_modules: Optional list of (name, module) for target modules.
186200
If provided, only these modules get global_scale and zp_scale; enables
187201
DDP by passing this rank's subset (see #2220). If None, targets and
188-
ignore must be provided and match_named_modules(model, targets, ignore)
189-
is used.
202+
ignore are used via match_named_modules(model, targets, ignore)
203+
(default () for both means no name-based filtering).
190204
:param targets: Target module name patterns (used when named_modules is None).
191-
:param ignore: Ignore patterns (used when named_modules is None).
205+
Default () means no name-based filtering when named_modules is None.
206+
:param ignore: Ignore patterns (used when named_modules is None). Default ().
192207
:param update_zp_scale: If True, call update_weight_zp_scale on target
193208
modules in post-order. Set False for modifiers that do zp_scale in
194209
hooks (e.g. GPTQ).
195210
:param desc: Progress bar description; None to disable progress bar.
196211
:param show_progress: If True and desc is not None, show a tqdm progress bar.
212+
:param parallel: If True, use two-phase parallel calibration (phase 1 per-layer,
213+
phase 2 fused global scales with scale rescaling).
214+
:param max_workers: If parallel=True and int, run phase 1 with this many
215+
workers. If None, phase 1 runs sequentially.
197216
"""
198217
if named_modules is None:
199-
if targets is None or ignore is None:
200-
raise ValueError(
201-
"calibrate_weights requires either named_modules or both "
202-
"targets and ignore"
203-
)
204-
from compressed_tensors.utils import match_named_modules
205-
206218
named_modules = list(match_named_modules(model, targets, ignore))
207219
else:
208220
named_modules = list(named_modules)
209221

210-
target_set = {id(m) for _, m in named_modules}
211-
total_targets = len(target_set)
222+
target_set = {m for _, m in named_modules}
223+
target_list = list(target_set)
224+
total_targets = len(target_list)
212225

213-
try:
214-
import tqdm
215-
except ImportError:
216-
tqdm = None
217-
218-
if show_progress and desc is not None and tqdm is not None and total_targets > 0:
226+
if show_progress and desc is not None and total_targets > 0:
219227
pbar = tqdm.tqdm(total=total_targets, desc=desc)
220228
else:
221229
pbar = None
222230

223-
# Stack-based DFS: (module, children_visited)
224-
stack: list[Tuple[Module, bool]] = [(model, False)]
225-
226-
while stack:
227-
module, children_done = stack.pop()
231+
if parallel:
232+
# Phase 1: per-module global scale + scale/zp (order-independent)
233+
pbar_lock = threading.Lock()
228234

229-
if not children_done:
230-
# Pre-order: global scale for target modules (FP4 / TENSOR_GROUP)
231-
if id(module) in target_set:
232-
update_weight_global_scale(module)
233-
stack.append((module, True))
234-
for child in reversed(list(module.children())):
235-
stack.append((child, False))
236-
else:
237-
# Post-order: fused global scales (Attention/MLP), then zp_scale for targets
238-
update_fused_layer_weight_global_scales(module)
239-
if update_zp_scale and id(module) in target_set:
235+
def _phase1_one(module: Module) -> None:
236+
update_weight_global_scale(module)
237+
if update_zp_scale:
240238
update_weight_zp_scale(module)
241-
if pbar is not None:
239+
if pbar is not None:
240+
with pbar_lock:
242241
pbar.update(1)
243242

243+
if max_workers is not None and max_workers > 0:
244+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
245+
list(executor.map(_phase1_one, target_list))
246+
else:
247+
for module in target_list:
248+
_phase1_one(module)
249+
250+
# Phase 2: fused global scales (rescale per-tensor scale s' = s * g' / g)
251+
stack: list[Tuple[Module, bool]] = [(model, False)]
252+
while stack:
253+
module, children_done = stack.pop()
254+
if not children_done:
255+
stack.append((module, True))
256+
for child in reversed(list(module.children())):
257+
stack.append((child, False))
258+
else:
259+
update_fused_layer_weight_global_scales(module)
260+
else:
261+
# Sequential DFS: pre-order global scale, post-order fused + zp_scale
262+
seen_pre: set[Module] = set()
263+
seen_post: set[Module] = set()
264+
stack = [(model, False)]
265+
while stack:
266+
module, children_done = stack.pop()
267+
if not children_done:
268+
if module in target_set and module not in seen_pre:
269+
seen_pre.add(module)
270+
update_weight_global_scale(module)
271+
stack.append((module, True))
272+
for child in reversed(list(module.children())):
273+
stack.append((child, False))
274+
else:
275+
update_fused_layer_weight_global_scales(module)
276+
if update_zp_scale and module in target_set and module not in seen_post:
277+
seen_post.add(module)
278+
update_weight_zp_scale(module)
279+
if pbar is not None:
280+
pbar.update(1)
281+
244282
if pbar is not None:
245283
pbar.close()
246284

src/llmcompressor/modifiers/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# ruff: noqa
22

33
from .constants import *
4-
from .fused_modules import (
4+
from llmcompressor.modeling.fused_modules import (
55
get_fused_attention_linears,
66
get_fused_mlp_linears,
77
is_fused_attention_module,

src/llmcompressor/modifiers/utils/helpers.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@
99

1010
import torch
1111
from compressed_tensors.quantization import QuantizationStrategy
12-
from compressed_tensors.utils import align_modules, update_parameter_data
12+
from compressed_tensors.utils import (
13+
align_modules,
14+
update_offload_parameter,
15+
update_parameter_data,
16+
)
1317
from torch.nn import Linear
1418

15-
from llmcompressor.modifiers.utils.fused_modules import (
19+
from llmcompressor.modeling.fused_modules import (
1620
get_fused_attention_linears,
1721
get_fused_mlp_linears,
1822
)
@@ -39,7 +43,12 @@ def update_fused_layer_weight_global_scales(submodule: torch.nn.Module):
3943
When running NVFP4 quantization, update the global scale so that vLLM
4044
fused groups share one global scale: attention (traditional q/k/v or
4145
MLA q_a + kv_a) and MLP (gate/up). Uses the centralized fused module
42-
definitions; see :mod:`llmcompressor.modifiers.utils.fused_modules`.
46+
definitions; see :mod:`llmcompressor.modeling.fused_modules`.
47+
48+
When a linear already has ``weight_scale`` (e.g. after parallel phase-1
49+
calibration), per-tensor scale is rescaled so that q = x/(s'*g') is
50+
unchanged: s' = s * (g' / g), where g' is the fused global scale and g
51+
was the previous per-tensor global scale.
4352
4453
This is a requirement currently set by vLLM and may be removed or
4554
made optional in the future.
@@ -55,7 +64,7 @@ def update_fused_layer_weight_global_scales(submodule: torch.nn.Module):
5564
torch.cat([lin.weight_global_scale.data for lin in linears])
5665
).reshape([1])
5766
for lin in linears:
58-
update_parameter_data(lin, global_scale, "weight_global_scale")
67+
_apply_fused_global_scale(lin, global_scale)
5968
del global_scale
6069

6170
# Fused MLP: gate_proj, up_proj
@@ -66,5 +75,17 @@ def update_fused_layer_weight_global_scales(submodule: torch.nn.Module):
6675
torch.cat([lin.weight_global_scale.data for lin in linears])
6776
).reshape([1])
6877
for lin in linears:
69-
update_parameter_data(lin, global_scale, "weight_global_scale")
78+
_apply_fused_global_scale(lin, global_scale)
7079
del global_scale
80+
81+
82+
def _apply_fused_global_scale(lin: Linear, g_prime: torch.Tensor) -> None:
83+
"""Set weight_global_scale to g'; rescale weight_scale so q = x/(s*g) unchanged."""
84+
old_g = lin.weight_global_scale.data
85+
update_parameter_data(lin, g_prime, "weight_global_scale")
86+
weight_scale = getattr(lin, "weight_scale", None)
87+
if weight_scale is not None:
88+
# s' = s * (g' / g) so that x / s' / g' = x / s / g
89+
ratio = (g_prime / old_g).to(weight_scale.dtype).to(weight_scale.device)
90+
new_scale = weight_scale.data * ratio
91+
update_offload_parameter(lin, "weight_scale", new_scale)

0 commit comments

Comments (0)