[Transform] Spinquant with R1 and R2 #1615


Merged
53 commits merged on Aug 13, 2025

Commits (53)
ba617db
wip
kylesayrs Jun 6, 2025
2f5b1c8
use random-hadamard, add correctness tests
kylesayrs Jun 12, 2025
3aa35e7
add correctness test, note that precision makes a large difference
kylesayrs Jun 12, 2025
b6c088e
add on lifecycle methods
brian-dellabetta Jun 23, 2025
d1eb2a1
Merge branch 'main' into kylesayrs/transform-modifier
brian-dellabetta Jul 1, 2025
3207124
TransformModifier with SpinQuant R1&R2
brian-dellabetta Jul 2, 2025
a88ca3c
spinquant and quip_online, running but outputting gibberish
brian-dellabetta Jul 2, 2025
5bd51df
updated example
brian-dellabetta Jul 2, 2025
3c216dd
DummyModel script
brian-dellabetta Jul 8, 2025
bbcdc8c
implement fuse_norm_linears
kylesayrs Jul 10, 2025
bd7f4d5
Merge branch 'kylesayrs/fuse-helpers' into bdellabe/transform-modifier
kylesayrs Jul 10, 2025
f5c2150
R1 working
kylesayrs Jul 11, 2025
dc5c30c
add r2, increase precision
kylesayrs Jul 11, 2025
7172c26
spinquant modifier
kylesayrs Jul 11, 2025
9298e82
remove space
kylesayrs Jul 11, 2025
f77226d
use iterable
kylesayrs Jul 11, 2025
fdb64b5
add rotation validation
kylesayrs Jul 11, 2025
5daa2d5
embedding fusion
kylesayrs Jul 11, 2025
0e9af7b
add missing norm fusion
kylesayrs Jul 12, 2025
fce83be
use norm mappings
kylesayrs Jul 12, 2025
a979f8a
break into separate files
kylesayrs Jul 12, 2025
4cab29e
small cleanup
kylesayrs Jul 12, 2025
f1cc987
cleanup
kylesayrs Jul 14, 2025
a7bb2e2
more cleanup
kylesayrs Jul 14, 2025
0cf0188
make new weight on cpu
kylesayrs Jul 14, 2025
53ea307
standardize, make modifier serializable
kylesayrs Jul 14, 2025
4b4257f
add compress model script
kylesayrs Jul 14, 2025
dc7ac1a
use untie_word_embeddings
kylesayrs Jul 15, 2025
8542f8d
style
kylesayrs Jul 15, 2025
b1e637e
better registery logic
kylesayrs Jul 15, 2025
b44ac81
remove dummy model test (add later)
kylesayrs Jul 15, 2025
7a52b71
docstring
kylesayrs Jul 15, 2025
f4d7ec6
update docstring
kylesayrs Jul 15, 2025
f18d0e8
rename example file
kylesayrs Jul 15, 2025
cec2914
use match_modules_set
kylesayrs Jul 16, 2025
f6c797e
Merge branch 'main' into bdellabe/transform-modifier
brian-dellabetta Jul 16, 2025
0c5c514
unit test fixes
brian-dellabetta Jul 17, 2025
f2ef7cf
style fixes
brian-dellabetta Jul 17, 2025
d0e5bc5
remove hardcoded pipeline logic
brian-dellabetta Jul 24, 2025
31ac8e9
docstrings
brian-dellabetta Jul 24, 2025
a4abb3d
stylefixes
brian-dellabetta Jul 24, 2025
63018c6
Merge branch 'main' into bdellabe/transform-modifier
brian-dellabetta Aug 7, 2025
4dcaeaa
add precision, add tests
kylesayrs Aug 8, 2025
88acb8d
remove print statement
kylesayrs Aug 8, 2025
98da74e
reduce diff
kylesayrs Aug 8, 2025
4705dbd
remove unused file
kylesayrs Aug 8, 2025
79fb6a8
fix typo
kylesayrs Aug 8, 2025
d95b1cc
use datafree pipeline in example
kylesayrs Aug 8, 2025
227f8e5
add messages to NotImplementedErrors
brian-dellabetta Aug 11, 2025
c32ca37
Merge branch 'main' into bdellabe/transform-modifier
brian-dellabetta Aug 12, 2025
b857da3
Merge branch 'main' into bdellabe/transform-modifier
brian-dellabetta Aug 12, 2025
d861bbb
minor touchups
brian-dellabetta Aug 12, 2025
062d908
more touchups
brian-dellabetta Aug 12, 2025
40 changes: 40 additions & 0 deletions examples/transform/spinquant_example.py
@@ -0,0 +1,40 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import SpinQuantModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# NOTE: currently only fused rotations (R1 & R2) are available
# Learned rotations and online rotations (R3 & R4) will be added
# in a future release.
# Configure the quantization algorithm to run.
# * apply spinquant transforms to model to reduce quantization loss
# * quantize the weights to 4 bit with group size 128
recipe = [
SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"),
QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

# Apply algorithms.
oneshot(model=model, recipe=recipe, pipeline="datafree")

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquantR1R2-w4a16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
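
A hedged follow-up sketch, not part of this PR: a checkpoint saved with save_compressed=True should typically reload like any other Hugging Face checkpoint, provided compressed-tensors is installed so the saved quantization config is recognized at load time.

from transformers import AutoModelForCausalLM

# Reload the compressed checkpoint written by the example above
# ("Meta-Llama-3-8B-Instruct-spinquantR1R2-w4a16" is the SAVE_DIR it constructs).
reloaded = AutoModelForCausalLM.from_pretrained(
    "Meta-Llama-3-8B-Instruct-spinquantR1R2-w4a16",
    torch_dtype="auto",
    device_map="auto",
)
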
1 change: 1 addition & 0 deletions src/llmcompressor/modeling/__init__.py
@@ -1,3 +1,4 @@
# flake8: noqa

from .fuse import *
from .prepare import *
3 changes: 3 additions & 0 deletions src/llmcompressor/modifiers/transform/__init__.py
@@ -0,0 +1,3 @@
# flake8: noqa

from .spinquant import SpinQuantModifier
3 changes: 3 additions & 0 deletions src/llmcompressor/modifiers/transform/spinquant/__init__.py
@@ -0,0 +1,3 @@
# flake8: noqa

from .base import *
246 changes: 246 additions & 0 deletions src/llmcompressor/modifiers/transform/spinquant/base.py
@@ -0,0 +1,246 @@
from enum import Enum
from typing import Iterable, List, Literal, Optional

import torch
from compressed_tensors import match_modules_set, match_named_modules
from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)
from compressed_tensors.utils import TorchDtype
from pydantic import Field, ValidationInfo, field_validator
from transformers import PreTrainedModel

from llmcompressor.core import Event, EventType, State
from llmcompressor.modeling import center_embeddings, fuse_norm_linears
from llmcompressor.modifiers import Modifier

from .mappings import SpinQuantMapping, infer_mapping_from_model
from .norm_mappings import NormMapping, infer_norm_mapping_from_model


class SpinquantRotation(str, Enum):
    R1 = "R1"
    R2 = "R2"
    R3 = "R3"
    R4 = "R4"


class SpinQuantModifier(Modifier, use_enum_values=True):
    """
    Implements the transforms according to "SpinQuant: LLM quantization
    with learned rotations" (https://arxiv.org/abs/2405.16406)

    Transforms (rotations) are extra layers added to a model which reduce the
    accuracy loss induced by quantization. This is achieved through "rotating"
    weights and activations into a space with a smaller dynamic range of values,
    thus decreasing the range of scales required for quantization.

    The SpinQuant authors describe four different rotations which can be applied
    to a model. R1 and R2 are "offline" rotations, meaning that they can be fused
    into existing weights and therefore do not induce runtime cost. R3 and R4 are
    "online" rotations, meaning that they require additional computation at runtime.

    Lifecycle:
        - on_initialize
            - infer SpinQuantMappings & NormMappings
            - as needed, create transform schemes for R1, R2, R3, & R4
        - on_start
            - normalize embeddings
            - fuse norm layers into subsequent Linear layers
            - apply TransformConfig
                - fuse transforms into weights for mergeable transforms
                - add hooks for online transforms
        - on sequential epoch end
        - on_end
        - on_finalize

    :param rotations: A list containing the names of rotations to apply to the model.
        Possible rotations include R1, R2, R3, and R4
    :param transform_type: The type of transform to apply to the model.
        `"hadamard"` has the least performance cost but only supports sizes which are
        powers of two.
        `"random-hadamard"` has more performance cost, but supports a much larger set
        of sizes.
        `"random-matrix"` has the greatest performance cost, but supports any size
    :param randomize: if True, create distinct transforms for each application
    :param learnable: if True, attach gradients to transform weights for training
    :param precision: Precision at which all transforms should be applied. This applies
        to both weight fusing and online rotations
    :param mappings: Specifies layers within a model to target for transforms.
        A mapping will be inferred if None is provided
    :param norm_mappings: Specifies layers within a model to target for norm fusing.
        A mapping will be inferred if None is provided
    :param transform_config: Optional transform config for overriding provided arguments
    """

    rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"])
    transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
        default="hadamard"
    )
    randomize: bool = Field(default=False)
    learnable: bool = Field(default=False)
    precision: TorchDtype = Field(default=torch.float64)

    # norm mappings separate from spinquant mappings to allow users to
    # override spinquant mappings with transform_config without overriding norms
    mappings: Optional[SpinQuantMapping] = Field(
        default=None,
        repr=False,
        exclude=True,
    )
    norm_mappings: Optional[List[NormMapping]] = Field(
        default=None,
        repr=False,
        exclude=True,
    )

    # optional override for more fine-grained control
    # also included in recipe serialization
    transform_config: Optional[TransformConfig] = Field(default=None, repr=False)

@field_validator("randomize", "learnable", mode="before")
def validate_not_implemented(cls, value, info: ValidationInfo):
if value:
raise NotImplementedError(f"{info.field_name} is not supported right now")
return value

@field_validator("rotations", mode="before")
def validate_rotations(cls, value):
if isinstance(value, Iterable):
return tuple(v.upper() for v in value)
return value

    def on_initialize(self, state: State, **kwargs) -> bool:
        if self.transform_config is not None:
            return True

        self.mappings = infer_mapping_from_model(state.model)
        self.norm_mappings = infer_norm_mapping_from_model(state.model)

        config_groups = {}
        if SpinquantRotation.R1 in self.rotations:
            config_groups["R1"] = self._create_r1_scheme()

        if SpinquantRotation.R2 in self.rotations:
            config_groups["R2"] = self._create_r2_scheme(state.model)

        if SpinquantRotation.R3 in self.rotations:
            config_groups["R3"] = self._create_r3_scheme()

        if SpinquantRotation.R4 in self.rotations:
            config_groups["R4"] = self._create_r4_scheme()

        self.transform_config = TransformConfig(config_groups=config_groups)

        return True

    def on_start(self, state: State, event: Event, **kwargs):
        self.started_ = True

        # needs to happen after the model has been hooked to execute on the GPU
        # otherwise we're applying weight transforms on CPU
        self._center_embeddings(state.model)
        self._fuse_norms(state.model)
        apply_transform_config(state.model, self.transform_config)

    def on_event(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.CALIBRATION_EPOCH_START:
            if not self.started_:
                self.on_start(state, None)

        elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
            pass

        elif event.type_ == EventType.CALIBRATION_EPOCH_END:
            if not self.ended_:
                self.on_end(state, None)

    def on_end(self, state: State, event: Event, **kwargs):
        self.ended_ = True

    def on_finalize(self, state: State, **kwargs) -> bool:
        if not self.ended_:
            self.on_end(state, None)

        return True

    def _center_embeddings(self, model: PreTrainedModel):
        for _, embedding in match_named_modules(
            model, [self.mappings.embedding], warn_on_fail=True
        ):
            center_embeddings(embedding)

    def _fuse_norms(self, model: PreTrainedModel):
        for mapping in self.norm_mappings:
            for norm, *linears in match_modules_set(
                model, (mapping.norm, *mapping.linears)
            ):
                fuse_norm_linears(norm, linears)

    def _create_r1_scheme(self) -> TransformScheme:
        return TransformScheme(
            type=self.transform_type,
            randomize=self.randomize,
            requires_grad=self.learnable,
            precision=self.precision,
            apply=[
                TransformArgs(
                    targets=[
                        self.mappings.embedding,
                        self.mappings.attn_o,
                        *self.mappings.mlp_out,
                    ],
                    location="weight_output",
                ),
                TransformArgs(
                    targets=[
                        self.mappings.attn_q,
                        self.mappings.attn_k,
                        self.mappings.attn_v,
                        *self.mappings.mlp_in,
                        self.mappings.lm_head,
                    ],
                    location="weight_input",
                    inverse=True,
                ),
            ],
        )

    def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
        config = model.config

        if hasattr(config, "head_dim"):
            head_dim = config.head_dim
        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
            head_dim = config.hidden_size // config.num_attention_heads
        else:
            raise NotImplementedError()

        return TransformScheme(
            type=self.transform_type,
            randomize=self.randomize,
            requires_grad=self.learnable,
            precision=self.precision,
            head_dim=head_dim,
            apply=[
                TransformArgs(targets=[self.mappings.attn_v], location="weight_output"),
                TransformArgs(
                    targets=[self.mappings.attn_o],
                    location="weight_input",
                    inverse=True,
                ),
            ],
        )

    def _create_r3_scheme(self) -> TransformScheme:
        raise NotImplementedError(
            "SpinQuant R3 and R4 rotations will be added in a future release"
        )

    def _create_r4_scheme(self) -> TransformScheme:
        raise NotImplementedError(
            "SpinQuant R3 and R4 rotations will be added in a future release"
        )
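
The class docstring above notes that R1 and R2 are "offline" rotations that can be fused into existing weights at no runtime cost. A minimal standalone sketch of that fusion identity, not code from this PR, using a generic random orthogonal matrix where the modifier would use a Hadamard:

import torch

# For any orthogonal R, inserting R after one linear weight and R^T before the
# next leaves the composed function unchanged, while reshaping the weight
# distributions that the quantizer later sees.
torch.manual_seed(0)
hidden = 64
x = torch.randn(8, hidden, dtype=torch.float64)
w_a = torch.randn(hidden, hidden, dtype=torch.float64)  # stands in for an output-side weight
w_b = torch.randn(hidden, hidden, dtype=torch.float64)  # stands in for the following input-side weight

# random orthogonal matrix via QR decomposition
r, _ = torch.linalg.qr(torch.randn(hidden, hidden, dtype=torch.float64))

y_ref = x @ w_a @ w_b                 # original computation
y_rot = x @ (w_a @ r) @ (r.T @ w_b)   # rotation fused into both weights offline

assert torch.allclose(y_ref, y_rot, atol=1e-8)
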
76 changes: 76 additions & 0 deletions src/llmcompressor/modifiers/transform/spinquant/mappings.py
@@ -0,0 +1,76 @@
from typing import Dict, List, Optional

from loguru import logger
from pydantic import BaseModel, Field, field_validator
from transformers import PreTrainedModel

__all__ = ["SpinQuantMapping", "infer_mapping_from_model"]


class SpinQuantMapping(BaseModel):
    """
    SpinQuant needs to know the entire architecture of the model,
    as R1, R2, R3, and R4 rotations need to be applied to specific
    layers (https://arxiv.org/pdf/2405.16406 Fig. 1).

    :param embedding: name or regex of embedding layer
    :param attn_q: name or regex of q_proj layer in attention block
    :param attn_k: name or regex of k_proj layer in attention block
    :param attn_v: name or regex of v_proj layer in attention block
    :param attn_o: name or regex of o_proj layer in attention block
    :param attn_head_dim: head_dim of the attention module, needed
        because R2 needs to be applied head-wise to v_proj and
        o_proj
    :param mlp_in: list of names or regexes for the layers that
        receive the input to the MLP block, usually up_proj and gate_proj
    :param mlp_out: list of names or regexes for the layers that
        constitute the output of the MLP block, usually down_proj
    :param lm_head: name or regex of the lm_head layer
    """

    embedding: str

    attn_q: str
    attn_k: str
    attn_v: str
    attn_o: str
    attn_head_dim: Optional[int] = Field(default=None)

    mlp_in: List[str]  # up_proj, gate_proj
    mlp_out: List[str]  # down_proj

    lm_head: str

    @field_validator("mlp_in", "mlp_out", mode="before")
    def cast_to_list(cls, value):
        if isinstance(value, str):
            return [value]

        return value


_default_mappings = SpinQuantMapping(
    embedding="re:.*embed_tokens$",
    attn_q="re:.*q_proj$",
    attn_k="re:.*k_proj$",
    attn_v="re:.*v_proj$",
    attn_o="re:.*o_proj$",
    mlp_in=["re:.*up_proj$", "re:.*gate_proj$"],
    mlp_out="re:.*down_proj$",
    lm_head="lm_head",
)


SPINQUANT_MAPPING_REGISTRY: Dict[str, SpinQuantMapping] = {
    "LlamaForCausalLM": _default_mappings,
}


def infer_mapping_from_model(model: PreTrainedModel) -> SpinQuantMapping:
    architecture = model.__class__.__name__
    if architecture not in SPINQUANT_MAPPING_REGISTRY:
        logger.info(
            f"Unrecognized model architecture {architecture}. "
            "Falling back to default mappings"
        )

    return SPINQUANT_MAPPING_REGISTRY.get(architecture, _default_mappings)
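
A hypothetical extension sketch, not part of the PR: since SPINQUANT_MAPPING_REGISTRY is a plain dict keyed by architecture class name, a mapping for an unlisted architecture could be registered before running the modifier. The Qwen2 class name and regexes below are illustrative assumptions, not tested mappings.

from llmcompressor.modifiers.transform.spinquant.mappings import (
    SPINQUANT_MAPPING_REGISTRY,
    SpinQuantMapping,
)

# Register a custom mapping so infer_mapping_from_model() finds it instead of
# falling back to the Llama defaults.
SPINQUANT_MAPPING_REGISTRY["Qwen2ForCausalLM"] = SpinQuantMapping(
    embedding="re:.*embed_tokens$",
    attn_q="re:.*q_proj$",
    attn_k="re:.*k_proj$",
    attn_v="re:.*v_proj$",
    attn_o="re:.*o_proj$",
    mlp_in=["re:.*up_proj$", "re:.*gate_proj$"],
    mlp_out="re:.*down_proj$",
    lm_head="lm_head",
)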