|
| 1 | +from enum import Enum |
| 2 | +from typing import Iterable, List, Literal, Optional |
| 3 | + |
| 4 | +import torch |
| 5 | +from compressed_tensors import match_modules_set, match_named_modules |
| 6 | +from compressed_tensors.transform import ( |
| 7 | + TransformArgs, |
| 8 | + TransformConfig, |
| 9 | + TransformScheme, |
| 10 | + apply_transform_config, |
| 11 | +) |
| 12 | +from compressed_tensors.utils import TorchDtype |
| 13 | +from pydantic import Field, ValidationInfo, field_validator |
| 14 | +from transformers import PreTrainedModel |
| 15 | + |
| 16 | +from llmcompressor.core import Event, EventType, State |
| 17 | +from llmcompressor.modeling import center_embeddings, fuse_norm_linears |
| 18 | +from llmcompressor.modifiers import Modifier |
| 19 | + |
| 20 | +from .mappings import SpinQuantMapping, infer_mapping_from_model |
| 21 | +from .norm_mappings import NormMapping, infer_norm_mapping_from_model |
| 22 | + |
| 23 | + |
| 24 | +class SpinquantRotation(str, Enum): |
| 25 | + R1 = "R1" |
| 26 | + R2 = "R2" |
| 27 | + R3 = "R3" |
| 28 | + R4 = "R4" |
| 29 | + |
| 30 | + |
| 31 | +class SpinQuantModifier(Modifier, use_enum_values=True): |
| 32 | + """ |
| 33 | + Implements the transforms according to "SpinQuant: LLM quantization |
| 34 | + with learned rotations" (https://arxiv.org/abs/2405.16406) |
| 35 | +
|
| 36 | + Transforms (rotations) are extra layers added to a model which reduce the accuracy |
| 37 | + loss induced by quantization. This is achived through "rotating" weights and |
| 38 | + activations into a space with a smaller dynamic range of values, thus decreasing |
| 39 | + the range of scales required for quantization. |
| 40 | +
|
| 41 | + The SpinQuant authors describe four different rotations which can be applied to a |
| 42 | + model. R1 and R2 are "offline" rotations, meaning that they can be fused into |
| 43 | + existing weights and therefore do not induce runtime cost. R3 and R4 are "online" |
| 44 | + rotations, meaning that they require additional computation at runtime. |
| 45 | +
|
| 46 | + Lifecycle: |
| 47 | + - on_initialize |
| 48 | + - infer SpinQuantMappings & NormMappings |
| 49 | + - as needed, create transform schemes for R1, R2, R3, & R4 |
| 50 | + - on_start |
| 51 | + - normalize embeddings |
| 52 | + - fuse norm layers into subsequent Linear layers |
| 53 | + - apply TransformConfig |
| 54 | + - fuse transforms into weights for mergeable transforms |
| 55 | + - add hooks for online transforms |
| 56 | + - on sequential epoch end |
| 57 | + - on_end |
| 58 | + - on_finalize |
| 59 | +
|
| 60 | + :param rotations: A list containing the names of rotations to apply to the model. |
| 61 | + Possible rotations include R1, R2, R3, and R4 |
| 62 | + :param transform_type: The type of transform to apply to the model. |
| 63 | + `"hadamard"` has the least performance cost but only supports sizes which are |
| 64 | + powers of power of two. |
| 65 | + `"random-matrix"` has more performance cost, but supports a much larger set of |
| 66 | + sizes. |
| 67 | + `"random-matrix"` has the greatest performance cost, but supports any size |
| 68 | + :param randomize: if True, create distinct transforms for each application |
| 69 | + :param learnable: if True, attach gradients to transform weights for training |
| 70 | + :param precision: Precision at which all transforms should be applied. This applies |
| 71 | + to both weight fusing and online rotations |
| 72 | + :param mappings: Specifies layers within a model to target for transforms. |
| 73 | + A mapping will be inferred if None is provided |
| 74 | + :param norm_mappings: Specifies layers within a model to target for norm fusing. |
| 75 | + A mapping will be inferred if None is provided |
| 76 | + :param transform_config: Optional transform config for overriding provided arguments |
| 77 | + """ |
| 78 | + |
| 79 | + rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"]) |
| 80 | + transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( |
| 81 | + default="hadamard" |
| 82 | + ) |
| 83 | + randomize: bool = Field(default=False) |
| 84 | + learnable: bool = Field(default=False) |
| 85 | + precision: TorchDtype = Field(default=torch.float64) |
| 86 | + |
| 87 | + # norm mappings separate from spinquant mappings to allow users to |
| 88 | + # override spinquant mappings with transform_config without overriding norms |
| 89 | + mappings: Optional[SpinQuantMapping] = Field( |
| 90 | + default=None, |
| 91 | + repr=False, |
| 92 | + exclude=True, |
| 93 | + ) |
| 94 | + norm_mappings: Optional[List[NormMapping]] = Field( |
| 95 | + default=None, |
| 96 | + repr=False, |
| 97 | + exclude=True, |
| 98 | + ) |
| 99 | + |
| 100 | + # optional override for more fine-grained control |
| 101 | + # also included in recipe serialization |
| 102 | + transform_config: Optional[TransformConfig] = Field(default=None, repr=False) |
| 103 | + |
| 104 | + @field_validator("randomize", "learnable", mode="before") |
| 105 | + def validate_not_implemented(cls, value, info: ValidationInfo): |
| 106 | + if value: |
| 107 | + raise NotImplementedError(f"{info.field_name} is not supported right now") |
| 108 | + return value |
| 109 | + |
| 110 | + @field_validator("rotations", mode="before") |
| 111 | + def validate_rotations(cls, value): |
| 112 | + if isinstance(value, Iterable): |
| 113 | + return tuple(v.upper() for v in value) |
| 114 | + return value |
| 115 | + |
| 116 | + def on_initialize(self, state: State, **kwargs) -> bool: |
| 117 | + if self.transform_config is not None: |
| 118 | + return True |
| 119 | + |
| 120 | + self.mappings = infer_mapping_from_model(state.model) |
| 121 | + self.norm_mappings = infer_norm_mapping_from_model(state.model) |
| 122 | + |
| 123 | + config_groups = {} |
| 124 | + if SpinquantRotation.R1 in self.rotations: |
| 125 | + config_groups["R1"] = self._create_r1_scheme() |
| 126 | + |
| 127 | + if SpinquantRotation.R2 in self.rotations: |
| 128 | + config_groups["R2"] = self._create_r2_scheme(state.model) |
| 129 | + |
| 130 | + if SpinquantRotation.R3 in self.rotations: |
| 131 | + config_groups["R3"] = self._create_r3_scheme() |
| 132 | + |
| 133 | + if SpinquantRotation.R4 in self.rotations: |
| 134 | + config_groups["R4"] = self._create_r4_scheme() |
| 135 | + |
| 136 | + self.transform_config = TransformConfig(config_groups=config_groups) |
| 137 | + |
| 138 | + return True |
| 139 | + |
| 140 | + def on_start(self, state: State, event: Event, **kwargs): |
| 141 | + self.started_ = True |
| 142 | + |
| 143 | + # needs to happen after the model has been hooked to execute on the GPU |
| 144 | + # otherwise we're applying weight transforms on CPU |
| 145 | + self._center_embeddings(state.model) |
| 146 | + self._fuse_norms(state.model) |
| 147 | + apply_transform_config(state.model, self.transform_config) |
| 148 | + |
| 149 | + def on_event(self, state: State, event: Event, **kwargs): |
| 150 | + if event.type_ == EventType.CALIBRATION_EPOCH_START: |
| 151 | + if not self.started_: |
| 152 | + self.on_start(state, None) |
| 153 | + |
| 154 | + elif event.type_ == EventType.SEQUENTIAL_EPOCH_END: |
| 155 | + pass |
| 156 | + |
| 157 | + elif event.type_ == EventType.CALIBRATION_EPOCH_END: |
| 158 | + if not self.ended_: |
| 159 | + self.on_end(state, None) |
| 160 | + |
| 161 | + def on_end(self, state: State, event: Event, **kwargs): |
| 162 | + self.ended_ = True |
| 163 | + |
| 164 | + def on_finalize(self, state: State, **kwargs) -> bool: |
| 165 | + if not self.ended_: |
| 166 | + self.on_end(state, None) |
| 167 | + |
| 168 | + return True |
| 169 | + |
| 170 | + def _center_embeddings(self, model: PreTrainedModel): |
| 171 | + for _, embedding in match_named_modules( |
| 172 | + model, [self.mappings.embedding], warn_on_fail=True |
| 173 | + ): |
| 174 | + center_embeddings(embedding) |
| 175 | + |
| 176 | + def _fuse_norms(self, model: PreTrainedModel): |
| 177 | + for mapping in self.norm_mappings: |
| 178 | + for norm, *linears in match_modules_set( |
| 179 | + model, (mapping.norm, *mapping.linears) |
| 180 | + ): |
| 181 | + fuse_norm_linears(norm, linears) |
| 182 | + |
| 183 | + def _create_r1_scheme(self) -> TransformScheme: |
| 184 | + return TransformScheme( |
| 185 | + type=self.transform_type, |
| 186 | + randomize=self.randomize, |
| 187 | + requires_grad=self.learnable, |
| 188 | + precision=self.precision, |
| 189 | + apply=[ |
| 190 | + TransformArgs( |
| 191 | + targets=[ |
| 192 | + self.mappings.embedding, |
| 193 | + self.mappings.attn_o, |
| 194 | + *self.mappings.mlp_out, |
| 195 | + ], |
| 196 | + location="weight_output", |
| 197 | + ), |
| 198 | + TransformArgs( |
| 199 | + targets=[ |
| 200 | + self.mappings.attn_q, |
| 201 | + self.mappings.attn_k, |
| 202 | + self.mappings.attn_v, |
| 203 | + *self.mappings.mlp_in, |
| 204 | + self.mappings.lm_head, |
| 205 | + ], |
| 206 | + location="weight_input", |
| 207 | + inverse=True, |
| 208 | + ), |
| 209 | + ], |
| 210 | + ) |
| 211 | + |
| 212 | + def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme: |
| 213 | + config = model.config |
| 214 | + |
| 215 | + if hasattr(config, "head_dim"): |
| 216 | + head_dim = config.head_dim |
| 217 | + elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"): |
| 218 | + head_dim = config.hidden_size // config.num_attention_heads |
| 219 | + else: |
| 220 | + raise NotImplementedError() |
| 221 | + |
| 222 | + return TransformScheme( |
| 223 | + type=self.transform_type, |
| 224 | + randomize=self.randomize, |
| 225 | + requires_grad=self.learnable, |
| 226 | + precision=self.precision, |
| 227 | + head_dim=head_dim, |
| 228 | + apply=[ |
| 229 | + TransformArgs(targets=[self.mappings.attn_v], location="weight_output"), |
| 230 | + TransformArgs( |
| 231 | + targets=[self.mappings.attn_o], |
| 232 | + location="weight_input", |
| 233 | + inverse=True, |
| 234 | + ), |
| 235 | + ], |
| 236 | + ) |
| 237 | + |
| 238 | + def _create_r3_scheme(self) -> TransformScheme: |
| 239 | + raise NotImplementedError( |
| 240 | + "SpinQuant R3 and R4 rotations will be added in a future release" |
| 241 | + ) |
| 242 | + |
| 243 | + def _create_r4_scheme(self) -> TransformScheme: |
| 244 | + raise NotImplementedError( |
| 245 | + "SpinQuant R3 and R4 rotations will be added in a future release" |
| 246 | + ) |
0 commit comments