Skip to content

Commit 6e7a4f5

Browse files
author
Avishek Goswami
committed
Calibration: move fused_modules to modeling, rescale s when fusing g', add parallel path

- Move fused_modules.py to modeling/ and update imports
- In update_fused_layer_weight_global_scales, rescale weight_scale s' = s*g'/g when applying fused global scale so q unchanged
- Add calibrate_weights(..., parallel=True, max_workers=N) for two-phase parallel weight calibration

Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
1 parent 480294d commit 6e7a4f5

File tree

4 files changed

+139
-48
lines changed

4 files changed

+139
-48
lines changed
File renamed without changes.

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 112 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
from typing import Any, Iterable, Optional, Set, Tuple
1+
import threading
2+
from concurrent.futures import ThreadPoolExecutor
3+
from typing import Any, Iterable, Optional, Tuple
24

35
import torch
6+
import tqdm
47
from compressed_tensors.quantization import (
58
DynamicType,
69
QuantizationArgs,
@@ -11,6 +14,7 @@
1114
from compressed_tensors.utils import (
1215
align_module_device,
1316
getattr_chain,
17+
match_named_modules,
1418
update_offload_parameter,
1519
)
1620
from loguru import logger
@@ -136,6 +140,39 @@ def update_weight_global_scale(module: Module):
136140
)
137141

138142

143+
def _update_weight_calibration_once(module: Module, update_zp_scale: bool) -> None:
144+
"""
145+
Onload weight once and run both global scale (gparam) and scale/zp (qparams).
146+
Used in sequential DFS to avoid double onload for NVFP4.
147+
"""
148+
if getattr_chain(module, "quantization_scheme.weights", None) is None:
149+
return
150+
need_gparam = (
151+
getattr_chain(module, "quantization_scheme.weights.strategy", None)
152+
== QuantizationStrategy.TENSOR_GROUP
153+
)
154+
need_qparams = update_zp_scale
155+
if not need_gparam and not need_qparams:
156+
return
157+
if (
158+
need_qparams
159+
and getattr(module, "quantization_status", None)
160+
!= QuantizationStatus.CALIBRATION
161+
):
162+
logger.warning(
163+
"Attempting to calibrate weights of a module not in calibration mode"
164+
)
165+
with align_module_device(module):
166+
value = module.weight
167+
call_observer(
168+
module,
169+
base_name="weight",
170+
value=value,
171+
should_calculate_gparam=need_gparam,
172+
should_calculate_qparams=need_qparams,
173+
)
174+
175+
139176
def update_weight_zp_scale(module: Module):
140177
"""
141178
marks a layer as ready for calibration which activates observers
@@ -162,11 +199,13 @@ def calibrate_weights(
162199
model: Module,
163200
*,
164201
named_modules: Optional[Iterable[Tuple[str, Module]]] = None,
165-
targets: Optional[Set[str]] = None,
166-
ignore: Optional[Iterable[str]] = None,
202+
targets: Iterable[str] = (),
203+
ignore: Iterable[str] = (),
167204
update_zp_scale: bool = True,
168205
desc: Optional[str] = "Calibrating weights",
169206
show_progress: bool = True,
207+
parallel: bool = False,
208+
max_workers: Optional[int] = None,
170209
) -> None:
171210
"""
172211
Traverse the model once (DFS) and run weight calibration: global scales for
@@ -175,72 +214,103 @@ def calibrate_weights(
175214
model.modules() for better cache locality and fewer CPU–GPU onloads when
176215
using offloading.
177216
178-
Order of operations per module:
179-
1. Pre-order: update_weight_global_scale for target (quantizable) modules.
217+
Order of operations (default, parallel=False):
218+
1. Pre-order: one weight onload per target module; run both global scale
219+
(gparam) and scale/zp (qparams) via _update_weight_calibration_once.
180220
2. Post-order: update_fused_layer_weight_global_scales for every module
181-
(no-op except for Attention/MLP containers); then update_weight_zp_scale
182-
for target modules if update_zp_scale is True.
221+
(no-op except for Attention/MLP containers). No second onload for targets.
222+
223+
When parallel=True (parallel weight calibration):
224+
1. Phase 1: For each target module, run update_weight_global_scale then
225+
update_weight_zp_scale (if update_zp_scale). Order is independent so
226+
phase 1 can be parallelized (e.g. with max_workers).
227+
2. Phase 2: Traverse model and run update_fused_layer_weight_global_scales
228+
on every module. Fused global scale g' is applied and per-tensor scale
229+
is rescaled s' = s * (g' / g) so that q = x/(s'*g') = x/(s*g) is unchanged.
183230
184231
:param model: Root module to traverse (e.g. state.model).
185232
:param named_modules: Optional list of (name, module) for target modules.
186233
If provided, only these modules get global_scale and zp_scale; enables
187234
DDP by passing this rank's subset (see #2220). If None, targets and
188-
ignore must be provided and match_named_modules(model, targets, ignore)
189-
is used.
235+
ignore are used via match_named_modules(model, targets, ignore)
236+
(default () for both means no name-based filtering).
190237
:param targets: Target module name patterns (used when named_modules is None).
191-
:param ignore: Ignore patterns (used when named_modules is None).
238+
Default () means no name-based filtering when named_modules is None.
239+
:param ignore: Ignore patterns (used when named_modules is None). Default ().
192240
:param update_zp_scale: If True, call update_weight_zp_scale on target
193241
modules in post-order. Set False for modifiers that do zp_scale in
194242
hooks (e.g. GPTQ).
195243
:param desc: Progress bar description; None to disable progress bar.
196244
:param show_progress: If True and desc is not None, show a tqdm progress bar.
245+
:param parallel: If True, use two-phase parallel calibration (phase 1 per-layer,
246+
phase 2 fused global scales with scale rescaling).
247+
:param max_workers: If parallel=True and int, run phase 1 with this many
248+
workers. If None, phase 1 runs sequentially.
197249
"""
198250
if named_modules is None:
199-
if targets is None or ignore is None:
200-
raise ValueError(
201-
"calibrate_weights requires either named_modules or both "
202-
"targets and ignore"
203-
)
204-
from compressed_tensors.utils import match_named_modules
205-
206251
named_modules = list(match_named_modules(model, targets, ignore))
207252
else:
208253
named_modules = list(named_modules)
209254

210-
target_set = {id(m) for _, m in named_modules}
211-
total_targets = len(target_set)
255+
target_set = {m for _, m in named_modules}
256+
target_list = list(target_set)
257+
total_targets = len(target_list)
212258

213-
try:
214-
import tqdm
215-
except ImportError:
216-
tqdm = None
217-
218-
if show_progress and desc is not None and tqdm is not None and total_targets > 0:
259+
if show_progress and desc is not None and total_targets > 0:
219260
pbar = tqdm.tqdm(total=total_targets, desc=desc)
220261
else:
221262
pbar = None
222263

223-
# Stack-based DFS: (module, children_visited)
224-
stack: list[Tuple[Module, bool]] = [(model, False)]
225-
226-
while stack:
227-
module, children_done = stack.pop()
264+
if parallel:
265+
# Phase 1: per-module global scale + scale/zp (order-independent)
266+
pbar_lock = threading.Lock()
228267

229-
if not children_done:
230-
# Pre-order: global scale for target modules (FP4 / TENSOR_GROUP)
231-
if id(module) in target_set:
232-
update_weight_global_scale(module)
233-
stack.append((module, True))
234-
for child in reversed(list(module.children())):
235-
stack.append((child, False))
236-
else:
237-
# Post-order: fused global scales (Attention/MLP), then zp_scale for targets
238-
update_fused_layer_weight_global_scales(module)
239-
if update_zp_scale and id(module) in target_set:
268+
def _phase1_one(module: Module) -> None:
269+
update_weight_global_scale(module)
270+
if update_zp_scale:
240271
update_weight_zp_scale(module)
241-
if pbar is not None:
272+
if pbar is not None:
273+
with pbar_lock:
242274
pbar.update(1)
243275

276+
if max_workers is not None and max_workers > 0:
277+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
278+
list(executor.map(_phase1_one, target_list))
279+
else:
280+
for module in target_list:
281+
_phase1_one(module)
282+
283+
# Phase 2: fused global scales (rescale per-tensor scale s' = s * g' / g)
284+
stack: list[Tuple[Module, bool]] = [(model, False)]
285+
while stack:
286+
module, children_done = stack.pop()
287+
if not children_done:
288+
stack.append((module, True))
289+
for child in reversed(list(module.children())):
290+
stack.append((child, False))
291+
else:
292+
update_fused_layer_weight_global_scales(module)
293+
else:
294+
# Sequential DFS: pre-order one onload for gparam + qparams, post-order fused
295+
seen_pre: set[Module] = set()
296+
seen_post: set[Module] = set()
297+
stack = [(model, False)]
298+
while stack:
299+
module, children_done = stack.pop()
300+
if not children_done:
301+
if module in target_set and module not in seen_pre:
302+
seen_pre.add(module)
303+
_update_weight_calibration_once(module, update_zp_scale)
304+
stack.append((module, True))
305+
for child in reversed(list(module.children())):
306+
stack.append((child, False))
307+
else:
308+
update_fused_layer_weight_global_scales(module)
309+
if update_zp_scale and module in target_set and module not in seen_post:
310+
seen_post.add(module)
311+
if pbar is not None:
312+
pbar.update(1)
313+
244314
if pbar is not None:
245315
pbar.close()
246316

src/llmcompressor/modifiers/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# ruff: noqa
22

33
from .constants import *
4-
from .fused_modules import (
4+
from llmcompressor.modeling.fused_modules import (
55
get_fused_attention_linears,
66
get_fused_mlp_linears,
77
is_fused_attention_module,

src/llmcompressor/modifiers/utils/helpers.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@
99

1010
import torch
1111
from compressed_tensors.quantization import QuantizationStrategy
12-
from compressed_tensors.utils import align_modules, update_parameter_data
12+
from compressed_tensors.utils import (
13+
align_modules,
14+
update_offload_parameter,
15+
update_parameter_data,
16+
)
1317
from torch.nn import Linear
1418

15-
from llmcompressor.modifiers.utils.fused_modules import (
19+
from llmcompressor.modeling.fused_modules import (
1620
get_fused_attention_linears,
1721
get_fused_mlp_linears,
1822
)
@@ -39,7 +43,12 @@ def update_fused_layer_weight_global_scales(submodule: torch.nn.Module):
3943
When running NVFP4 quantization, update the global scale so that vLLM
4044
fused groups share one global scale: attention (traditional q/k/v or
4145
MLA q_a + kv_a) and MLP (gate/up). Uses the centralized fused module
42-
definitions; see :mod:`llmcompressor.modifiers.utils.fused_modules`.
46+
definitions; see :mod:`llmcompressor.modeling.fused_modules`.
47+
48+
When a linear already has ``weight_scale`` (e.g. after parallel phase-1
49+
calibration), per-tensor scale is rescaled so that q = x/(s'*g') is
50+
unchanged: s' = s * (g' / g), where g' is the fused global scale and g
51+
was the previous per-tensor global scale.
4352
4453
This is a requirement currently set by vLLM and may be removed or
4554
made optional in the future.
@@ -55,7 +64,7 @@ def update_fused_layer_weight_global_scales(submodule: torch.nn.Module):
5564
torch.cat([lin.weight_global_scale.data for lin in linears])
5665
).reshape([1])
5766
for lin in linears:
58-
update_parameter_data(lin, global_scale, "weight_global_scale")
67+
_apply_fused_global_scale(lin, global_scale)
5968
del global_scale
6069

6170
# Fused MLP: gate_proj, up_proj
@@ -66,5 +75,17 @@ def update_fused_layer_weight_global_scales(submodule: torch.nn.Module):
6675
torch.cat([lin.weight_global_scale.data for lin in linears])
6776
).reshape([1])
6877
for lin in linears:
69-
update_parameter_data(lin, global_scale, "weight_global_scale")
78+
_apply_fused_global_scale(lin, global_scale)
7079
del global_scale
80+
81+
82+
def _apply_fused_global_scale(lin: Linear, g_prime: torch.Tensor) -> None:
83+
"""Set weight_global_scale to g'; rescale weight_scale so q = x/(s*g) unchanged."""
84+
old_g = lin.weight_global_scale.data
85+
update_parameter_data(lin, g_prime, "weight_global_scale")
86+
weight_scale = getattr(lin, "weight_scale", None)
87+
if weight_scale is not None:
88+
# s' = s * (g' / g) so that x / s' / g' = x / s / g
89+
ratio = (g_prime / old_g).to(weight_scale.dtype).to(weight_scale.device)
90+
new_scale = weight_scale.data * ratio
91+
update_offload_parameter(lin, "weight_scale", new_scale)

0 commit comments

Comments (0)