|
| 1 | +""" |
| 2 | +Calibration context for offset-norm layers. |
| 3 | +
|
| 4 | +Some architectures (Gemma, Qwen3Next) use an offset normalization pattern where |
| 5 | +the forward pass computes ``output * (1 + weight)`` instead of the standard |
| 6 | +``output * weight``. This breaks any modifier that smooths norm weights |
| 7 | +(AWQ, SmoothQuant, SpinQuant) because dividing a (1+weight) parameter by scales |
| 8 | +produces incorrect results. |
| 9 | +
|
| 10 | +This module provides the infrastructure to temporarily replace offset-norm |
| 11 | +modules with standard-norm equivalents during calibration, and restore the |
| 12 | +original convention after modifiers have run. |
| 13 | +
|
| 14 | +Key components: |
| 15 | +- NormCalibrationModule: Abstract base class for norm calibration modules |
| 16 | +- norm_calibration_context: Context manager that applies norm conversion |
| 17 | +""" |
| 18 | + |
| 19 | +import contextlib |
| 20 | +from abc import ABC |
| 21 | + |
| 22 | +import torch |
| 23 | +from compressed_tensors.registry import RegistryMixin, standardize_lookup_name |
| 24 | +from loguru import logger |
| 25 | +from transformers import PreTrainedModel |
| 26 | + |
| 27 | +__all__ = [ |
| 28 | + "NormCalibrationModule", |
| 29 | + "norm_calibration_context", |
| 30 | +] |
| 31 | + |
| 32 | + |
class NormCalibrationModule(ABC, torch.nn.Module, RegistryMixin):
    """
    Abstract base class for norm calibration modules.

    A calibration module stands in for an original norm module while
    calibration runs, so that modifiers observe standard
    ``output * weight`` semantics instead of an offset convention.

    Subclasses must:
    1. Implement ``__init__()`` with signature: ``(self, original, config)``
    2. Implement ``restore()`` to convert back to the original norm convention
    """

    # NOTE(review): not consulted by norm_calibration_context in this file —
    # presumably checked by other machinery; confirm before relying on it.
    is_permanent: bool = False

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        """
        Restore the original module with updated weights.

        Returns:
            The original module with weights converted back to offset convention
        """
        # Default implementation: signal that the subclass forgot restore().
        message = f"{self.__class__.__name__} doesn't implement restore()"
        raise NotImplementedError(message)
| 57 | + |
| 58 | + |
@NormCalibrationModule.register(
    "GemmaRMSNorm",
    alias=["Gemma2RMSNorm", "Gemma3RMSNorm", "Qwen3NextRMSNorm"],
)
class CalibrationOffsetNorm(NormCalibrationModule):
    """
    Standard-norm stand-in for offset-norm modules during calibration.

    Offset norms compute ``output * (1 + weight)``; this module folds the
    implicit ``+1`` into its own parameter so its forward pass is a plain
    ``output * weight``.

    On enter: ``self.weight = 1 + original.weight``
    On restore: ``original.weight = self.weight - 1``
    """

    is_permanent = False

    def __init__(self, original: torch.nn.Module, config):
        super().__init__()
        # Carry over epsilon and remember the source dtype so restore()
        # can write the weight back exactly as the model expects.
        self.eps = original.eps
        self._orig_dtype = original.weight.dtype
        # Fold the offset into an explicit float32 parameter.
        folded = original.weight.data.float() + 1.0
        self.weight = torch.nn.Parameter(folded)

    def _norm(self, hidden: torch.Tensor) -> torch.Tensor:
        # RMS normalization: x * rsqrt(mean(x^2) + eps)
        variance = hidden.pow(2).mean(-1, keepdim=True)
        return hidden * torch.rsqrt(variance + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute in float32, then cast back to the input dtype.
        normalized = self._norm(x.float()) * self.weight.float()
        return normalized.type_as(x)

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        # Undo the folded offset and return to the original dtype.
        original.weight.data = (self.weight.data - 1.0).to(self._orig_dtype)
        return original
| 91 | + |
| 92 | + |
@contextlib.contextmanager
def norm_calibration_context(model: PreTrainedModel):
    """
    Context manager that converts offset-norm modules to standard-norm.

    This scans all modules in the model and replaces any offset-norm modules
    (``output * (1 + weight)``) with standard-norm equivalents
    (``output * weight``). After the context exits, modules are restored
    to their original convention with updated weights.

    The model is modified in-place, so the same model object should be used
    within the context.

    Args:
        model: The model to apply norm conversion to (modified in-place)

    Example:
        with norm_calibration_context(model):
            # Modifiers see standard norm weights
            run_calibration(model)
        # Norms restored to offset convention with smoothed weights
    """
    replaced = {}

    # Collection and replacement both happen inside the try so that a failure
    # partway through replacement (load_from_registry / set_submodule raising)
    # still rolls back the modules that were already swapped in.
    try:
        # Step 1: Collect all offset-norm modules that need replacement
        logger.debug("Entering norm calibration context")
        modules_to_replace = []
        hidden_size = getattr(model.config, "hidden_size", None)
        for name, module in model.named_modules():
            class_name = module.__class__.__name__
            if not _is_registered(class_name, NormCalibrationModule):
                continue
            # Only convert norms operating on hidden_size (e.g. input_layernorm,
            # post_attention_layernorm). Skip q_norm/k_norm which operate on
            # head_dim — they are not smoothed by any modifier.
            if hidden_size is not None and module.weight.shape[0] != hidden_size:
                continue
            modules_to_replace.append((name, module, class_name))

        # Step 2: Replace modules
        if modules_to_replace:
            logger.info(f"Found {len(modules_to_replace)} offset-norm modules to convert")
        for name, module, class_name in modules_to_replace:
            replacement = NormCalibrationModule.load_from_registry(
                class_name,
                original=module,
                config=model.config,
            )
            model.set_submodule(name, replacement)
            replaced[name] = (module, replacement)

        yield
    finally:
        # Step 3: Restore original modules with updated weights
        if replaced:
            logger.info(f"Restoring {len(replaced)} norm modules to offset convention")
        for name, (original, replacement) in replaced.items():
            model.set_submodule(name, replacement.restore(original))
| 153 | + |
| 154 | + |
def _is_registered(name: str, subclass: type[RegistryMixin]) -> bool:
    """
    Check whether ``name`` (or an alias of it) is registered under ``subclass``.

    Args:
        name: Class name to look up; normalized before comparison.
        subclass: Registry class whose registered names and aliases are searched.

    Returns:
        True if the normalized name is a registered name or alias.
    """
    lookup = standardize_lookup_name(name)
    return (
        lookup in subclass.registered_names() or lookup in subclass.registered_aliases()
    )