Skip to content

Commit c25f428

Browse files
committed
fix: add norm calibration context for unit-offset RMSNorm (Gemma/Qwen3Next)
Signed-off-by: Gilles Turpin <turpingilles15@gmail.com>
1 parent 026c917 commit c25f428

File tree

4 files changed

+360
-23
lines changed

4 files changed

+360
-23
lines changed

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from llmcompressor.datasets import get_calibration_dataloader
2424
from llmcompressor.entrypoints.utils import post_process, pre_process
2525
from llmcompressor.modeling.moe_context import moe_calibration_context
26+
from llmcompressor.modeling.offset_norm import norm_calibration_context
2627
from llmcompressor.pipelines import CalibrationPipeline
2728

2829
__all__ = ["Oneshot", "oneshot"]
@@ -217,30 +218,31 @@ def apply_recipe_modifiers(
217218
session.reset()
218219

219220
# (Helen INFERENG-661): validate recipe modifiers before initialization
220-
# Apply MoE calibration context for the entire calibration process
221-
with moe_calibration_context(
222-
self.model,
223-
calibrate_all_experts=self.dataset_args.moe_calibrate_all_experts,
224-
):
225-
session.initialize(
226-
model=self.model,
227-
start=-1,
228-
recipe=self.recipe,
229-
recipe_stage=recipe_stage,
230-
recipe_args=self.recipe_args.recipe_args,
231-
calib_data=calibration_dataloader,
232-
sequential_targets=self.dataset_args.sequential_targets,
233-
)
234-
user_pipeline = self.dataset_args.pipeline
235-
pipeline = CalibrationPipeline.from_modifiers(
236-
session.lifecycle.recipe.modifiers, user=user_pipeline
237-
)
238-
239-
pipeline(
221+
# Apply calibration contexts for the entire calibration process
222+
with norm_calibration_context(self.model):
223+
with moe_calibration_context(
240224
self.model,
241-
calibration_dataloader,
242-
self.dataset_args,
243-
)
225+
calibrate_all_experts=self.dataset_args.moe_calibrate_all_experts,
226+
):
227+
session.initialize(
228+
model=self.model,
229+
start=-1,
230+
recipe=self.recipe,
231+
recipe_stage=recipe_stage,
232+
recipe_args=self.recipe_args.recipe_args,
233+
calib_data=calibration_dataloader,
234+
sequential_targets=self.dataset_args.sequential_targets,
235+
)
236+
user_pipeline = self.dataset_args.pipeline
237+
pipeline = CalibrationPipeline.from_modifiers(
238+
session.lifecycle.recipe.modifiers, user=user_pipeline
239+
)
240+
241+
pipeline(
242+
self.model,
243+
calibration_dataloader,
244+
self.dataset_args,
245+
)
244246

245247
session.finalize()
246248

src/llmcompressor/modeling/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .qwen3_5_moe import CalibrationQwen3_5MoeSparseMoeBlock
1919
from .qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock # noqa: F401
2020
from .qwen3_next_moe import CalibrationQwen3NextSparseMoeBlock # noqa: F401
21+
from .offset_norm import CalibrationOffsetNorm # noqa: F401
2122
# TODO: add granite4
2223

2324
from .fuse import *
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
"""
2+
Calibration context for offset-norm layers.
3+
4+
Some architectures (Gemma, Qwen3Next) use an offset normalization pattern where
5+
the forward pass computes ``output * (1 + weight)`` instead of the standard
6+
``output * weight``. This breaks any modifier that smooths norm weights
7+
(AWQ, SmoothQuant, SpinQuant) because dividing a (1+weight) parameter by scales
8+
produces incorrect results.
9+
10+
This module provides the infrastructure to temporarily replace offset-norm
11+
modules with standard-norm equivalents during calibration, and restore the
12+
original convention after modifiers have run.
13+
14+
Key components:
15+
- NormCalibrationModule: Abstract base class for norm calibration modules
16+
- norm_calibration_context: Context manager that applies norm conversion
17+
"""
18+
19+
import contextlib
20+
from abc import ABC
21+
22+
import torch
23+
from compressed_tensors.registry import RegistryMixin, standardize_lookup_name
24+
from loguru import logger
25+
from transformers import PreTrainedModel
26+
27+
__all__ = [
28+
"NormCalibrationModule",
29+
"norm_calibration_context",
30+
]
31+
32+
33+
class NormCalibrationModule(ABC, torch.nn.Module, RegistryMixin):
    """
    Base class for modules that stand in for norm layers during calibration.

    A calibration module temporarily takes the place of an original norm
    module so that modifiers observe standard ``output * weight`` semantics
    instead of the model's native convention.

    Subclasses must:
    1. Implement ``__init__()`` with signature: ``(self, original, config)``
    2. Implement ``restore()`` to convert back to the original norm convention
    """

    # NOTE(review): not consulted by norm_calibration_context in this file —
    # presumably checked by other calibration machinery; confirm intended use.
    is_permanent: bool = False

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        """
        Restore the original module with updated weights.

        Returns:
            The original module with weights converted back to offset convention
        """
        message = f"{self.__class__.__name__} doesn't implement restore()"
        raise NotImplementedError(message)
57+
58+
59+
@NormCalibrationModule.register(
    "GemmaRMSNorm",
    alias=["Gemma2RMSNorm", "Gemma3RMSNorm", "Qwen3NextRMSNorm"],
)
class CalibrationOffsetNorm(NormCalibrationModule):
    """
    Stand-in for offset-norm modules (``output * (1 + weight)``) that exposes
    standard-norm semantics (``output * weight``) while calibration runs.

    On enter: ``self.weight = 1 + original.weight``
    On restore: ``original.weight = self.weight - 1``
    """

    is_permanent = False

    def __init__(self, original: torch.nn.Module, config):
        super().__init__()
        self.eps = original.eps
        # Remember the parameter dtype so restore() can write back losslessly.
        self._orig_dtype = original.weight.dtype
        # Fold the +1 offset into the parameter so modifiers can scale it
        # like an ordinary RMSNorm weight; kept in float32 for precision.
        self.weight = torch.nn.Parameter(original.weight.data.float() + 1.0)

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # Standard RMS normalization over the last dimension.
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize and scale in float32, then cast back to the input dtype
        # (same upcast/downcast pattern as the offset-norm originals).
        scaled = self._norm(x.float()) * self.weight.float()
        return scaled.type_as(x)

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        # Undo the +1 fold and write back in the original parameter dtype.
        original.weight.data = (self.weight.data - 1.0).to(self._orig_dtype)
        return original
91+
92+
93+
@contextlib.contextmanager
def norm_calibration_context(model: PreTrainedModel):
    """
    Context manager that converts offset-norm modules to standard-norm.

    Every offset-norm module in the model (``output * (1 + weight)``) that has
    a registered calibration stand-in is swapped for a standard-norm
    equivalent (``output * weight``) while the context is active. On exit the
    original modules are put back with their weights translated back to the
    offset convention.

    The model is modified in-place, so the same model object should be used
    within the context.

    Args:
        model: The model to apply norm conversion to (modified in-place)

    Example:
        with norm_calibration_context(model):
            # Modifiers see standard norm weights
            run_calibration(model)
        # Norms restored to offset convention with smoothed weights
    """
    logger.debug("Entering norm calibration context")

    # Step 1: collect every module whose class name has a registered stand-in.
    # Only convert norms operating on hidden_size (e.g. input_layernorm,
    # post_attention_layernorm); skip q_norm/k_norm which operate on
    # head_dim — they are not smoothed by any modifier.
    hidden_dim = getattr(model.config, "hidden_size", None)
    targets = [
        (name, module, module.__class__.__name__)
        for name, module in model.named_modules()
        if _is_registered(module.__class__.__name__, NormCalibrationModule)
        and not (hidden_dim is not None and module.weight.shape[0] != hidden_dim)
    ]

    # Step 2: swap each target for its calibration stand-in, remembering the
    # original so it can be reinstated afterwards.
    swapped = {}
    if targets:
        logger.info(f"Found {len(targets)} offset-norm modules to convert")
    for name, module, class_name in targets:
        stand_in = NormCalibrationModule.load_from_registry(
            class_name,
            original=module,
            config=model.config,
        )
        model.set_submodule(name, stand_in)
        swapped[name] = (module, stand_in)

    try:
        yield
    finally:
        # Step 3: reinstate the originals with weights converted back to the
        # offset convention (restoration happens even if calibration raised).
        if swapped:
            logger.info(f"Restoring {len(swapped)} norm modules to offset convention")
        for name, (original, stand_in) in swapped.items():
            model.set_submodule(name, stand_in.restore(original))
153+
154+
155+
def _is_registered(name: str, subclass: RegistryMixin):
    """Return True if *name* is a registered name or alias for *subclass*."""
    key = standardize_lookup_name(name)
    if key in subclass.registered_names():
        return True
    return key in subclass.registered_aliases()

0 commit comments

Comments
 (0)