Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions src/llmcompressor/modeling/fused_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""
Central source of truth for vLLM-aligned fused module layouts.

Defines which submodules form a single "fused" group for TENSOR_GROUP (e.g. NVFP4)
global-scale sharing in vLLM. All callers that need to detect or iterate over
fused attention / MLP linears should use the functions in this module.

Fused attention (vLLM)
----------------------
- **Traditional:** Three linears `q_proj`, `k_proj`, `v_proj` that share one
global scale in vLLM.
- **Fused QKV:** Single linear `qkv_proj` (Phi, etc.). No multi-layer fusion
needed; already one tensor.
- **MLA (Multi-head Latent Attention):** Two linears that share one global
scale: query projection and fused KV projection. Common attribute names:
- `q_a_proj` (or `q_proj`) for query
- `kv_a_proj_with_mqa` for key/value (fused KV with MQA).
Used in DeepSeek V2/V3, Kimi K2, Mistral Large 3, and similar architectures.

Fused MLP (vLLM)
----------------
- **Gate/Up:** Two linears `gate_proj`, `up_proj` that share one global scale.
- **Fused Gate-Up:** Single linear `gate_up_proj` (Phi, etc.). No multi-layer
fusion needed.
"""

from __future__ import annotations

from typing import List, Optional

from torch.nn import Linear, Module

__all__ = [
"get_fused_attention_linears",
"get_fused_mlp_linears",
"is_fused_attention_module",
"is_fused_mlp_module",
]


def get_fused_attention_linears(module: Module) -> Optional[List[Linear]]:
    """
    Return the Linear submodules that share one global scale as a fused
    attention group in vLLM (TENSOR_GROUP), or None if ``module`` is not a
    recognized fused attention container.

    Layouts (vLLM-aligned):
    - **Traditional:** ``q_proj`` + ``k_proj`` + ``v_proj`` (three linears).
    - **Fused QKV:** single ``qkv_proj`` -> None (already one tensor).
    - **MLA:** ``q_a_proj`` (or ``q_proj``) + ``kv_a_proj_with_mqa`` (two linears).

    :param module: A candidate attention container (e.g. parent of q/k/v or MLA).
    :return: List of Linear modules that should share one global scale, or None.
    """
    # A single fused qkv layer needs no cross-layer global-scale sharing
    if hasattr(module, "qkv_proj"):
        return None

    has_mla_kv = hasattr(module, "kv_a_proj_with_mqa")

    # Traditional layout: q_proj, k_proj, v_proj
    if all(hasattr(module, name) for name in ("q_proj", "k_proj", "v_proj")):
        # Ambiguous case: q/k/v alongside an MLA kv projection. Do not treat
        # it as traditional attention.
        if has_mla_kv:
            return None
        projections = [module.q_proj, module.k_proj, module.v_proj]
        if all(isinstance(proj, Linear) for proj in projections):
            return projections

    # MLA layout: query projection + fused kv projection
    if has_mla_kv:
        kv_proj = module.kv_a_proj_with_mqa
        query = getattr(module, "q_a_proj", None) or getattr(module, "q_proj", None)
        if isinstance(query, Linear) and isinstance(kv_proj, Linear):
            return [query, kv_proj]

    return None


def get_fused_mlp_linears(module: Module) -> Optional[List[Linear]]:
    """
    Return the Linear submodules that share one global scale as a fused MLP
    group in vLLM (TENSOR_GROUP), or None if ``module`` is not a recognized
    fused MLP container.

    Layouts (vLLM-aligned):
    - **Gate/Up:** ``gate_proj`` + ``up_proj`` (two linears).
    - **Fused Gate-Up:** single ``gate_up_proj`` -> None (already one tensor).

    :param module: A candidate MLP container (e.g. parent of gate_proj/up_proj).
    :return: List of Linear modules that should share one global scale, or None.
    """
    # A single fused gate_up layer needs no cross-layer global-scale sharing
    if hasattr(module, "gate_up_proj"):
        return None

    # Only consider modules whose class name looks like an MLP block,
    # to avoid false positives on unrelated containers
    if "mlp" not in type(module).__name__.lower():
        return None

    gate = getattr(module, "gate_proj", None)
    up = getattr(module, "up_proj", None)
    if isinstance(gate, Linear) and isinstance(up, Linear):
        return [gate, up]

    return None


def is_fused_attention_module(module: Module) -> bool:
    """Whether ``module`` is a fused attention container (traditional or MLA)."""
    linears = get_fused_attention_linears(module)
    return linears is not None


def is_fused_mlp_module(module: Module) -> bool:
    """Whether ``module`` is a fused MLP container (gate/up)."""
    linears = get_fused_mlp_linears(module)
    return linears is not None
23 changes: 8 additions & 15 deletions src/llmcompressor/modifiers/awq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@
get_layer_mappings_from_architecture,
)
from llmcompressor.modifiers.quantization.calibration import (
calibrate_weights,
call_observer,
update_weight_global_scale,
update_weight_zp_scale,
)
from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
from llmcompressor.modifiers.utils import update_fused_layer_weight_global_scales
Expand Down Expand Up @@ -282,19 +281,13 @@ def on_end(self, state: State, event: Event, **kwargs):
named_modules = list(
match_named_modules(state.model, self.resolved_targets, self.ignore)
)

# For TENSOR_GROUP (nvfp4), calculate global scales after smoothing
for _, module in tqdm(named_modules, desc="Updating global scales"):
update_weight_global_scale(module)

# For TENSOR_GROUP (nvfp4), fuse global scales for attention and MLP layers
# This is a requirement for vLLM inference.
for module in tqdm(state.model.modules(), desc="Fusing global scales"):
update_fused_layer_weight_global_scales(module)

# Calculate scales and zero points using the fused global scales
for _, module in tqdm(named_modules, desc="Calibrating weights"):
update_weight_zp_scale(module)
calibrate_weights(
state.model,
named_modules=named_modules,
update_zp_scale=True,
desc="Calibrating weights",
show_progress=True,
)

QuantizationMixin.end_calibration(self, state.model)

Expand Down
169 changes: 168 additions & 1 deletion src/llmcompressor/modifiers/quantization/calibration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Any, Optional
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Iterable, Iterator, Optional, Tuple

import torch
import tqdm
from compressed_tensors.quantization import (
DynamicType,
QuantizationArgs,
Expand All @@ -11,14 +14,17 @@
from compressed_tensors.utils import (
align_module_device,
getattr_chain,
match_named_modules,
update_offload_parameter,
)
from loguru import logger
from torch.nn import Module

from llmcompressor.modifiers.utils import update_fused_layer_weight_global_scales
from llmcompressor.observers import Observer

__all__ = [
"calibrate_weights",
"initialize_observer",
"update_weight_zp_scale",
"calibrate_input_hook",
Expand Down Expand Up @@ -134,6 +140,52 @@ def update_weight_global_scale(module: Module):
)


def _post_order_modules(model: Module) -> Iterator[Module]:
"""Yield every module in the tree in DFS post-order."""
stack: list[Tuple[Module, bool]] = [(model, False)]
while stack:
module, children_done = stack.pop()
if not children_done:
stack.append((module, True))
for child in reversed(list(module.children())):
stack.append((child, False))
else:
yield module


def _update_weight_calibration_once(module: Module, update_zp_scale: bool) -> None:
    """
    Calibrate a module's weight with a single onload: compute the TENSOR_GROUP
    global scale (gparam) and/or the scale/zero-point (qparams) from one read
    of ``module.weight``. Used by the sequential DFS path to avoid onloading
    the weight twice for NVFP4.

    :param module: module whose weight observers should run
    :param update_zp_scale: whether to also compute scale/zero-point qparams
    """
    # Skip modules with no weight quantization scheme
    if getattr_chain(module, "quantization_scheme.weights", None) is None:
        return

    # Global scales are only defined for the TENSOR_GROUP strategy (e.g. NVFP4)
    calculate_gparam = (
        getattr_chain(module, "quantization_scheme.weights.strategy", None)
        == QuantizationStrategy.TENSOR_GROUP
    )
    calculate_qparams = update_zp_scale
    if not (calculate_gparam or calculate_qparams):
        return

    status = getattr(module, "quantization_status", None)
    if calculate_qparams and status != QuantizationStatus.CALIBRATION:
        logger.warning(
            "Attempting to calibrate weights of a module not in calibration mode"
        )

    # Single onload shared by both the gparam and qparam computation
    with align_module_device(module):
        call_observer(
            module,
            base_name="weight",
            value=module.weight,
            should_calculate_gparam=calculate_gparam,
            should_calculate_qparams=calculate_qparams,
        )


def update_weight_zp_scale(module: Module):
"""
marks a layer as ready for calibration which activates observers
Expand All @@ -156,6 +208,121 @@ def update_weight_zp_scale(module: Module):
call_observer(module=module, base_name="weight")


def calibrate_weights(
    model: Module,
    *,
    named_modules: Optional[Iterable[Tuple[str, Module]]] = None,
    targets: Iterable[str] = (),
    ignore: Iterable[str] = (),
    update_zp_scale: bool = True,
    desc: Optional[str] = "Calibrating weights",
    show_progress: bool = True,
    parallel: bool = False,
    max_workers: Optional[int] = None,
) -> None:
    """
    Run weight calibration: per-tensor global scale (gparam) for TENSOR_GROUP
    schemes, fused global scales for attention/MLP groups, and scale/zero-point
    (qparams). Minimizes weight onloads when offloading is used (one onload per
    target module in the default sequential path).

    Two modes:
    - Sequential (parallel=False): DFS over the model. Pre-order: one onload
      per target via _update_weight_calibration_once (gparam + qparams).
      Post-order: update_fused_layer_weight_global_scales (no extra onload
      for targets).
    - Parallel (parallel=True): Phase 1 computes gparam + qparams per target
      (order-independent, so parallelizable). Phase 2 walks the whole model
      and applies fused global scales.

    NOTE(review): callers may pass ``named_modules`` as a subset (e.g. one
    rank's assignment in a distributed setup). Fused groups (q/k/v, gate/up)
    must then stay within a single subset so
    update_fused_layer_weight_global_scales sees the whole group; this
    function performs no cross-process synchronization itself.

    :param model: Root module to traverse (e.g. state.model).
    :param named_modules: If provided, only these (name, module) pairs are
        calibrated. If None, uses match_named_modules(model, targets, ignore).
    :param targets: Name patterns used when named_modules is None. Default ().
    :param ignore: Ignore patterns used when named_modules is None. Default ().
    :param update_zp_scale: If True, compute scale/zp for targets. False for
        modifiers that compute zp/scale in hooks (e.g. GPTQ).
    :param desc: Progress bar description; None disables the bar.
    :param show_progress: If True and desc is set, show a tqdm bar.
    :param parallel: If True, use the two-phase parallel calibration path.
    :param max_workers: If parallel and a positive int, phase 1 uses this many
        worker threads.
    """
    if named_modules is None:
        named_modules = list(match_named_modules(model, targets, ignore))
    else:
        named_modules = list(named_modules)
    # Only modules in this set receive gparam/qparam calibration
    target_set = {m for _, m in named_modules}
    target_list = list(target_set)
    total_targets = len(target_list)

    pbar = None
    if show_progress and desc is not None and total_targets > 0:
        pbar = tqdm.tqdm(total=total_targets, desc=desc)

    if parallel:
        # Phase 1: per-module global scale + scale/zp (order-independent)
        pbar_lock = threading.Lock()

        def _phase1_one(module: Module) -> None:
            update_weight_global_scale(module)
            if update_zp_scale:
                update_weight_zp_scale(module)
            if pbar is not None:
                # tqdm.update is not guaranteed thread-safe; serialize updates
                with pbar_lock:
                    pbar.update(1)

        if max_workers is not None and max_workers > 0:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                # list() drains the iterator so worker exceptions surface here
                list(executor.map(_phase1_one, target_list))
        else:
            for module in target_list:
                _phase1_one(module)

        # Phase 2: fused global scales (rescale per-tensor scale s' = s * g' / g)
        for module in _post_order_modules(model):
            update_fused_layer_weight_global_scales(module)
    else:
        # Sequential DFS: pre-order does one onload per target (gparam +
        # qparams together); post-order applies fused global scales once all
        # children have been calibrated.
        seen: set = set()
        stack = [(model, False)]
        while stack:
            module, children_done = stack.pop()
            if not children_done:
                if module in target_set and module not in seen:
                    seen.add(module)
                    _update_weight_calibration_once(module, update_zp_scale)
                    # Advance the bar whenever a target is calibrated. The
                    # previous version only updated inside the
                    # update_zp_scale branch, so the bar never moved when
                    # update_zp_scale=False.
                    if pbar is not None:
                        pbar.update(1)
                stack.append((module, True))
                for child in reversed(list(module.children())):
                    stack.append((child, False))
            else:
                update_fused_layer_weight_global_scales(module)

    if pbar is not None:
        pbar.close()


def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
"""
Calibrate input or output activations by calling the a module's attached
Expand Down
Loading
Loading