Skip to content

Commit cc6a964

Browse files
fix: add norm calibration context for unit-offset RMSNorm (Gemma/Qwen3Next) (#2500)
## SUMMARY Some architectures (Gemma, Gemma2, Gemma3, Qwen3Next) use an offset normalization where the forward computes `output * (1 + weight)` instead of `output * weight`. This breaks any modifier that smooths norm weights (AWQ, SmoothQuant, SpinQuant, QuIP) because dividing a `(1+weight)` parameter by scales produces `1 + weight/scales` instead of `(1 + weight)/scales`. Following @brian-dellabetta's suggestion, this adds a `norm_calibration_context` that temporarily replaces offset-norm modules with standard-norm equivalents during calibration, following the same pattern as `moe_calibration_context`. On entry, offset norms are replaced with `CalibrationOffsetNorm` modules (`weight = 1 + original`). On exit, modules are restored with updated weights (`weight = smoothed - 1`). Only norms operating on `hidden_size` are converted. Norms operating on `head_dim` (e.g. `q_norm`/`k_norm` in Gemma3 attention) are skipped since no modifier smooths them. ## TEST PLAN Unit tests (8/8 passing): - Weight conversion and dtype preservation - Forward equivalence with original norm - Restore roundtrip (with and without smoothing) - Registry detection (positive and negative) - `hidden_size` filter: `q_norm`/`k_norm` correctly skipped E2E validation: | Model | Modifier | Norms converted | Output | |---|---|---|---| | `google/gemma-2-2b-it` | AWQ W4A16 | 105 | Coherent | | `google/medgemma-27b-text-it` | AWQ W4A16 | 249 (373 total, 124 q/k skipped) | Coherent | | upstream (no fix) on medgemma | AWQ W4A16 | 0 | Garbage | Qwen3-Next architecture verified structurally: `hidden_size=2048`, `head_dim=256`, `Qwen3NextRMSNorm` uses same `(1+weight)` pattern. No smaller Qwen3-Next model exists for e2e testing (80B MoE only). Fixes #2365 Fixes #2102 Related to #2202 Related to #2059 Signed-off-by: Gilles Turpin <turpingilles15@gmail.com> Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
1 parent 4e6aa76 commit cc6a964

File tree

4 files changed

+355
-2
lines changed

4 files changed

+355
-2
lines changed

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from llmcompressor.datasets import get_calibration_dataloader
2424
from llmcompressor.entrypoints.utils import post_process, pre_process
2525
from llmcompressor.modeling.moe_context import moe_calibration_context
26+
from llmcompressor.modeling.offset_norm import norm_calibration_context
2627
from llmcompressor.pipelines import CalibrationPipeline
2728

2829
__all__ = ["Oneshot", "oneshot"]
@@ -217,8 +218,8 @@ def apply_recipe_modifiers(
217218
session.reset()
218219

219220
# (Helen INFERENG-661): validate recipe modifiers before initialization
220-
# Apply MoE calibration context for the entire calibration process
221-
with moe_calibration_context(
221+
# Apply calibration contexts for the entire calibration process
222+
with norm_calibration_context(self.model), moe_calibration_context(
222223
self.model,
223224
calibrate_all_experts=self.dataset_args.moe_calibrate_all_experts,
224225
):

src/llmcompressor/modeling/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .qwen3_5_moe import CalibrationQwen3_5MoeSparseMoeBlock
1919
from .qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock # noqa: F401
2020
from .qwen3_next_moe import CalibrationQwen3NextSparseMoeBlock # noqa: F401
21+
from .offset_norm import CalibrationOffsetNorm # noqa: F401
2122
# TODO: add granite4
2223

2324
from .fuse import *
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
"""
2+
Calibration context for offset-norm layers.
3+
4+
Some architectures (Gemma, Qwen3Next) use an offset normalization pattern where
5+
the forward pass computes ``output * (1 + weight)`` instead of the standard
6+
``output * weight``. This breaks any modifier that smooths norm weights
7+
(AWQ, SmoothQuant, SpinQuant) because dividing a (1+weight) parameter by scales
8+
produces incorrect results.
9+
10+
This module provides the infrastructure to temporarily replace offset-norm
11+
modules with standard-norm equivalents during calibration, and restore the
12+
original convention after modifiers have run.
13+
14+
Key components:
15+
- NormCalibrationModule: Abstract base class for norm calibration modules
16+
- norm_calibration_context: Context manager that applies norm conversion
17+
"""
18+
19+
import contextlib
20+
from abc import ABC, abstractmethod
21+
22+
import torch
23+
from compressed_tensors.registry import RegistryMixin, standardize_lookup_name
24+
from loguru import logger
25+
from transformers import PreTrainedModel
26+
27+
__all__ = [
28+
"NormCalibrationModule",
29+
"norm_calibration_context",
30+
]
31+
32+
33+
class NormCalibrationModule(ABC, torch.nn.Module, RegistryMixin):
    """
    Abstract base for norm calibration replacement modules.

    Subclasses stand in for original norm modules while calibration runs,
    so that modifiers observe standard ``output * weight`` semantics
    instead of the architecture's native convention.
    """

    # When False, the replacement is swapped back out after calibration.
    is_permanent: bool = False

    @abstractmethod
    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        """
        Write this module's (possibly modified) weights back into
        ``original``, converted to the original module's convention.

        Returns:
            The original module with weights converted back to offset convention
        """
        ...
52+
53+
54+
@NormCalibrationModule.register(
    "GemmaRMSNorm",
    alias=["Gemma2RMSNorm", "Gemma3RMSNorm", "Qwen3NextRMSNorm"],
)
class CalibrationOffsetNorm(NormCalibrationModule):
    """
    Standard-norm stand-in for offset-norm modules.

    Offset norms compute ``output * (1 + weight)``; this replacement exposes
    the equivalent standard form ``output * weight`` so that weight-smoothing
    modifiers operate on the full effective scale.

    On enter: ``self.weight = 1 + original.weight``
    On restore: ``original.weight = self.weight - 1``
    """

    is_permanent = False

    def __init__(self, original: torch.nn.Module, config):
        """
        :param original: offset-norm module being replaced; must expose
            ``eps`` and ``weight``
        :param config: model config (unused here; kept so all calibration
            modules share a uniform constructor signature)
        """
        super().__init__()
        self.eps = original.eps
        self._orig_dtype = original.weight.dtype
        # Fold the +1 offset into the parameter; compute in fp32 for accuracy.
        folded = original.weight.data.float() + 1.0
        self.weight = torch.nn.Parameter(folded.to(self._orig_dtype))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # RMS normalization using the original module's epsilon.
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute in fp32 (matching the offset norms), then cast back.
        normed = self._norm(x.float()) * self.weight.float()
        return normed.type_as(x)

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        # Convert back to offset convention: offset weight = standard - 1.
        original.weight.data = (self.weight.data.float() - 1.0).to(self._orig_dtype)
        return original
88+
89+
90+
@contextlib.contextmanager
def norm_calibration_context(model: PreTrainedModel):
    """
    Context manager that converts offset-norm modules to standard-norm.

    This scans all modules in the model and replaces any offset-norm modules
    (``output * (1 + weight)``) with standard-norm equivalents
    (``output * weight``). After the context exits, modules are restored
    to their original convention with updated weights.

    Only norms operating on ``model.config.hidden_size`` are converted.
    Norms operating on other dimensions (e.g. ``q_norm``/``k_norm`` acting
    on head_dim in Gemma3 attention) are skipped, since no modifier smooths
    them and converting them would only add a lossy (1 + w) round-trip.

    The model is modified in-place, so the same model object should be used
    within the context.

    Args:
        model: The model to apply norm conversion to (modified in-place)

    Example:
        with norm_calibration_context(model):
            # Modifiers see standard norm weights
            run_calibration(model)
        # Norms restored to offset convention with smoothed weights
    """
    replaced = {}

    # Step 1: Collect all offset-norm modules that need replacement
    logger.debug("Entering norm calibration context")
    hidden_size = getattr(model.config, "hidden_size", None)
    modules_to_replace = []
    for name, module in model.named_modules():
        class_name = module.__class__.__name__
        if not _is_registered(class_name, NormCalibrationModule):
            continue
        # Skip norms whose weight does not span hidden_size (e.g. head_dim
        # q_norm/k_norm): modifiers never smooth them, so conversion is
        # unnecessary. When hidden_size is unavailable, convert everything.
        weight = getattr(module, "weight", None)
        if (
            hidden_size is not None
            and weight is not None
            and weight.numel() != hidden_size
        ):
            continue
        modules_to_replace.append((name, module, class_name))

    # Step 2: Replace modules
    if modules_to_replace:
        logger.info(f"Found {len(modules_to_replace)} offset-norm modules to convert")
    for name, module, class_name in modules_to_replace:
        replacement = NormCalibrationModule.load_from_registry(
            class_name,
            original=module,
            config=model.config,
        )
        model.set_submodule(name, replacement)
        replaced[name] = (module, replacement)

    try:
        yield
    finally:
        # Step 3: Restore original modules with updated weights
        if replaced:
            logger.info(f"Restoring {len(replaced)} norm modules to offset convention")
        for name, (original, replacement) in replaced.items():
            restored = replacement.restore(original)
            model.set_submodule(name, restored)
144+
145+
146+
def _is_registered(name: str, subclass: type[RegistryMixin]) -> bool:
    """
    Check whether a module class name is registered under ``subclass``.

    Args:
        name: the ``__class__.__name__`` of a candidate module
        subclass: the registry class (not an instance) to look up in

    Returns:
        True if ``name`` matches a registered name or alias
    """
    lookup = standardize_lookup_name(name)
    return (
        lookup in subclass.registered_names() or lookup in subclass.registered_aliases()
    )
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
from types import SimpleNamespace
2+
3+
import pytest
4+
import torch
5+
from torch import nn
6+
7+
from llmcompressor.modeling.offset_norm import (
8+
CalibrationOffsetNorm,
9+
NormCalibrationModule,
10+
norm_calibration_context,
11+
)
12+
13+
# ---------------------------------------------------------------------------
14+
# Mock offset-norm module matching Gemma's (1 + weight) convention
15+
# ---------------------------------------------------------------------------
16+
17+
18+
class FakeGemmaRMSNorm(nn.Module):
    """Minimal stand-in replicating GemmaRMSNorm: output * (1 + weight)."""

    def __init__(self, dim, eps=1e-6, dtype=torch.bfloat16):
        super().__init__()
        self.eps = eps
        # Gemma initializes the offset weight at zero (effective scale of 1).
        self.weight = nn.Parameter(torch.zeros(dim, dtype=dtype))

    def _norm(self, x):
        rms_scale = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms_scale

    def forward(self, x):
        normed = self._norm(x.float()) * (1.0 + self.weight.float())
        return normed.type_as(x)
33+
34+
35+
# Rename the mock so registry detection (which is keyed on
# module.__class__.__name__) treats it as the real GemmaRMSNorm.
FakeGemmaRMSNorm.__name__ = "GemmaRMSNorm"
FakeGemmaRMSNorm.__qualname__ = "GemmaRMSNorm"
38+
39+
40+
# ---------------------------------------------------------------------------
41+
# Tests
42+
# ---------------------------------------------------------------------------
43+
44+
45+
@pytest.mark.unit
class TestCalibrationOffsetNormInit:
    """Test that __init__ converts weights and stores dtype."""

    def test_weight_conversion(self):
        # The calibration module should hold weight = 1 + original.weight.
        norm = FakeGemmaRMSNorm(dim=4)
        norm.weight.data = torch.tensor([0.1, -0.05, 0.0, 0.2])

        converted = CalibrationOffsetNorm(norm, config=None)

        assert torch.allclose(
            converted.weight.data, torch.tensor([1.1, 0.95, 1.0, 1.2])
        )

    def test_dtype_stored(self):
        # The original dtype must be remembered for restore().
        norm = FakeGemmaRMSNorm(dim=4, dtype=torch.bfloat16)

        converted = CalibrationOffsetNorm(norm, config=None)

        assert converted._orig_dtype == torch.bfloat16
        assert converted.weight.dtype == torch.bfloat16
63+
64+
65+
@pytest.mark.unit
class TestCalibrationOffsetNormForward:
    """Test that forward produces the same result as the original."""

    def test_output_matches_original(self):
        norm = FakeGemmaRMSNorm(dim=8, dtype=torch.float32)
        norm.weight.data = torch.randn(8) * 0.1
        converted = CalibrationOffsetNorm(norm, config=None)

        hidden = torch.randn(2, 4, 8)

        # (1 + w) applied by the original equals the folded standard weight.
        assert torch.allclose(norm(hidden), converted(hidden), atol=1e-5)
79+
80+
81+
@pytest.mark.unit
class TestCalibrationOffsetNormRestore:
    """Test that restore reconverts weights correctly."""

    def test_restore_roundtrip(self):
        norm = FakeGemmaRMSNorm(dim=4, dtype=torch.bfloat16)
        norm.weight.data = torch.tensor([0.1, -0.05, 0.0, 0.2], dtype=torch.bfloat16)
        before = norm.weight.data.clone()

        converted = CalibrationOffsetNorm(norm, config=None)
        converted.restore(norm)

        # bf16 -> fp32 -> bf16 round-trip may lose a little precision.
        assert norm.weight.dtype == torch.bfloat16
        assert torch.allclose(norm.weight.data.float(), before.float(), atol=2e-2)

    def test_restore_after_smoothing(self):
        norm = FakeGemmaRMSNorm(dim=4, dtype=torch.float32)
        norm.weight.data = torch.tensor([0.1, -0.05, 0.0, 0.2])

        converted = CalibrationOffsetNorm(norm, config=None)
        # Simulate a modifier dividing weights by scales=2
        converted.weight.data.div_(2.0)
        converted.restore(norm)

        # Standard weight after smoothing: [1.1, 0.95, 1.0, 1.2] / 2
        # = [0.55, 0.475, 0.5, 0.6]
        # Restored offset weight: standard - 1
        # = [-0.45, -0.525, -0.5, -0.4]
        assert torch.allclose(
            norm.weight.data,
            torch.tensor([-0.45, -0.525, -0.5, -0.4]),
            atol=1e-5,
        )

        # Verify: 1 + restored_weight == smoothed standard weight
        assert torch.allclose(
            1.0 + norm.weight.data,
            torch.tensor([0.55, 0.475, 0.5, 0.6]),
            atol=1e-5,
        )
118+
119+
120+
@pytest.mark.unit
class TestNormRegistration:
    """Test that registered norms are detected and standard norms are not."""

    def test_gemma_detected(self):
        """GemmaRMSNorm (and aliases) should be in the registry."""
        known = (
            NormCalibrationModule.registered_names()
            + NormCalibrationModule.registered_aliases()
        )
        expected = [
            "gemmarmsnorm",
            "gemma2rmsnorm",
            "gemma3rmsnorm",
            "qwen3nextrmsnorm",
        ]
        for name in expected:
            assert name in known, f"{name} not in registry"

    def test_standard_norm_not_detected(self):
        """Standard LayerNorm should not be in the registry."""
        names = NormCalibrationModule.registered_names()
        assert "layernorm" not in names
        assert "rmsnorm" not in names
142+
143+
144+
@pytest.mark.unit
class TestNormCalibrationContext:
    """Test that norm_calibration_context replaces and restores modules."""

    @staticmethod
    def _build_model(dim, dtype, norm_names):
        # Assemble a minimal model: one layer holding the requested norms,
        # plus a config exposing hidden_size (read by the context manager).
        layer = nn.Module()
        for attr in norm_names:
            setattr(layer, attr, FakeGemmaRMSNorm(dim=dim, dtype=dtype))
        model = nn.Module()
        model.layer = layer
        model.config = SimpleNamespace(hidden_size=dim)
        return model, layer

    def test_modules_replaced_inside_context(self):
        """Offset norms should be replaced with CalibrationOffsetNorm inside."""
        model, layer = self._build_model(
            8, torch.float32, ["input_layernorm", "post_attention_layernorm"]
        )

        with norm_calibration_context(model):
            assert isinstance(layer.input_layernorm, CalibrationOffsetNorm)
            assert isinstance(layer.post_attention_layernorm, CalibrationOffsetNorm)

    def test_modules_restored_after_context(self):
        """Original modules should be restored with correct weights."""
        model, layer = self._build_model(4, torch.bfloat16, ["input_layernorm"])
        layer.input_layernorm.weight.data = torch.tensor(
            [0.1, -0.05, 0.0, 0.2], dtype=torch.bfloat16
        )
        before = layer.input_layernorm.weight.data.clone()

        with norm_calibration_context(model):
            pass

        restored = layer.input_layernorm
        assert isinstance(restored, FakeGemmaRMSNorm)
        assert restored.weight.dtype == torch.bfloat16
        assert torch.allclose(restored.weight.data.float(), before.float(), atol=2e-2)

    def test_weights_updated_after_smoothing(self):
        """Weights modified inside the context should be reflected after."""
        model, layer = self._build_model(4, torch.float32, ["norm"])
        layer.norm.weight.data = torch.tensor([0.1, -0.05, 0.0, 0.2])

        with norm_calibration_context(model):
            # Simulate modifier dividing weights by scales=2
            layer.norm.weight.data.div_(2.0)

        # Standard weight was [1.1, 0.95, 1.0, 1.2] / 2 = [0.55, 0.475, 0.5, 0.6]
        # Restored offset weight: standard - 1 = [-0.45, -0.525, -0.5, -0.4]
        expected = torch.tensor([-0.45, -0.525, -0.5, -0.4])
        assert torch.allclose(layer.norm.weight.data, expected, atol=1e-5)

0 commit comments

Comments
 (0)