Skip to content

[Transform] Online Rotations #1651

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 39 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
ba617db
wip
kylesayrs Jun 6, 2025
2f5b1c8
use random-hadamard, add correctness tests
kylesayrs Jun 12, 2025
3aa35e7
add correctness test, note that precision makes a large difference
kylesayrs Jun 12, 2025
b6c088e
add on lifecycle methods
brian-dellabetta Jun 23, 2025
d1eb2a1
Merge branch 'main' into kylesayrs/transform-modifier
brian-dellabetta Jul 1, 2025
3207124
TransformModifier with SpinQuant R1&R2
brian-dellabetta Jul 2, 2025
a88ca3c
spinquant and quip_online, running but outputting gibberish
brian-dellabetta Jul 2, 2025
5bd51df
updated example
brian-dellabetta Jul 2, 2025
3c216dd
DummyModel script
brian-dellabetta Jul 8, 2025
bbcdc8c
implement fuse_norm_linears
kylesayrs Jul 10, 2025
bd7f4d5
Merge branch 'kylesayrs/fuse-helpers' into bdellabe/transform-modifier
kylesayrs Jul 10, 2025
f5c2150
R1 working
kylesayrs Jul 11, 2025
dc5c30c
add r2, increase precision
kylesayrs Jul 11, 2025
7172c26
spinquant modifier
kylesayrs Jul 11, 2025
9298e82
remove space
kylesayrs Jul 11, 2025
f77226d
use iterable
kylesayrs Jul 11, 2025
fdb64b5
add rotation validation
kylesayrs Jul 11, 2025
5daa2d5
embedding fusion
kylesayrs Jul 11, 2025
0e9af7b
add missing norm fusion
kylesayrs Jul 12, 2025
fce83be
use norm mappings
kylesayrs Jul 12, 2025
a979f8a
break into separate files
kylesayrs Jul 12, 2025
4cab29e
small cleanup
kylesayrs Jul 12, 2025
f1cc987
cleanup
kylesayrs Jul 14, 2025
a7bb2e2
more cleanup
kylesayrs Jul 14, 2025
0cf0188
make new weight on cpu
kylesayrs Jul 14, 2025
53ea307
standardize, make modifier serializable
kylesayrs Jul 14, 2025
4b4257f
add compress model script
kylesayrs Jul 14, 2025
dc7ac1a
use untie_word_embeddings
kylesayrs Jul 15, 2025
8542f8d
style
kylesayrs Jul 15, 2025
b1e637e
better registery logic
kylesayrs Jul 15, 2025
b44ac81
remove dummy model test (add later)
kylesayrs Jul 15, 2025
7a52b71
docstring
kylesayrs Jul 15, 2025
f4d7ec6
update docstring
kylesayrs Jul 15, 2025
f18d0e8
rename example file
kylesayrs Jul 15, 2025
cec2914
use match_modules_set
kylesayrs Jul 16, 2025
0a146f8
hook with CompressedAttentionImpl
kylesayrs Jul 16, 2025
0e4e002
use qkv hooks
kylesayrs Jul 16, 2025
5aa3586
use get_compressed_attention_impl
kylesayrs Jul 16, 2025
a9b2f51
r3 r4 works, but not with sdpa
kylesayrs Jul 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions compress_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# python3 compress_model.py --model_id meta-llama/Llama-3.2-1B-Instruct --transform_type random-hadamard
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import SpinQuantModifier
from llmcompressor.utils import dispatch_for_generation

def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, help="Model stub to compress")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The --model_id argument is essential for this script to run. Without it, args.model_id will be None, causing AutoModelForCausalLM.from_pretrained to fail. Please make this argument required to prevent runtime errors and add a more descriptive help message.

Suggested change
parser.add_argument("--model_id", type=str, help="Model stub to compress")
parser.add_argument("--model_id", type=str, required=True, help="Hugging Face model ID to compress")

parser.add_argument("--transform_type", type=str, default=None, help="Type of transform used in SpinQuantModifier")
parser.add_argument("--scheme", type=str, default=None, help="Quantization scheme (e.g. W4A16)")
return parser.parse_args()

if __name__ == "__main__":
args = parse_args()

# Select model and load it.
MODEL_ID = args.model_id
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Configure the quantization algorithm to run.
recipe = []
if args.transform_type:
recipe.append(SpinQuantModifier(rotations=["R1", "R2"], transform_type=args.transform_type))

if args.scheme:
recipe.append(QuantizationModifier(targets="Linear", scheme=args.scheme, ignore=["lm_head"]))

# Apply algorithms.
oneshot(
model=model,
recipe=recipe,
dataset="ultrachat_200k",
splits={"calibration": f"train_sft[:{NUM_CALIBRATION_SAMPLES}]"},
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + f"-{args.transform_type}-{args.scheme}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current method of constructing SAVE_DIR will include "None" in the directory name if --transform_type or --scheme are not provided. This can lead to confusing directory names. It's better to build the save directory name conditionally, only including the parts that are actually provided.

Suggested change
SAVE_DIR = MODEL_ID.split("/")[1] + f"-{args.transform_type}-{args.scheme}"
SAVE_DIR = "-".join([p for p in (MODEL_ID.split("/")[1], args.transform_type, args.scheme) if p])

model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
91 changes: 91 additions & 0 deletions examples/transform/spinquant_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import SpinQuantModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, attn_implementation="eager")

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
return {
"text": tokenizer.apply_chat_template(
example["messages"],
tokenize=False,
)
}


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
return tokenizer(
sample["text"],
padding=False,
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
add_special_tokens=False,
)


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
# * apply spinquant transforms to model in order to make quantization easier
# * quantize the weights to 4 bit with GPTQ with a group size 128
recipe = [
SpinQuantModifier(
rotations=["R1", "R2", "R3", "R4"], transform_type="random-hadamard"
),
# QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

# Apply algorithms.
oneshot(
model=model,
recipe=recipe,
dataset=ds,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
from llmcompressor.utils import calibration_forward_context

with calibration_forward_context(model):
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-transformed-w4a16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
7 changes: 3 additions & 4 deletions src/llmcompressor/entrypoints/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from llmcompressor.pytorch.model_load.helpers import parse_dtype
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
modify_save_pretrained,
patch_tied_tensors_bug,
untie_word_embeddings,
)
from llmcompressor.transformers.utils.helpers import (
detect_last_checkpoint,
Expand Down Expand Up @@ -61,7 +61,8 @@ def pre_process(model_args: "ModelArguments"):
)

# untie tie_word_embeddings weights
patch_tied_tensors_bug(model_args.model)
if not model_args.tie_word_embeddings:
untie_word_embeddings(model_args.model)

# wrap model.save_pretrained
modify_save_pretrained(model_args.model)
Expand Down Expand Up @@ -143,7 +144,6 @@ def initialize_model_from_path(
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
tie_word_embeddings=model_args.tie_word_embeddings,
trust_remote_code=model_args.trust_remote_code_model,
)

Expand All @@ -156,7 +156,6 @@ def initialize_model_from_path(
AutoConfig.from_pretrained(
model_args.distill_teacher,
use_auth_token=True if model_args.use_auth_token else None,
tie_word_embeddings=model_args.tie_word_embeddings,
trust_remote_code=model_args.trust_remote_code_model,
)
if model_args.distill_teacher
Expand Down
1 change: 1 addition & 0 deletions src/llmcompressor/modeling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# flake8: noqa

from .fuse import *
from .prepare import *
58 changes: 58 additions & 0 deletions src/llmcompressor/modeling/fuse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Iterable

import torch
from compressed_tensors import (
align_module_device,
get_execution_device,
update_offload_parameter,
)
from transformers.models.llama.modeling_llama import LlamaRMSNorm

__all__ = ["normalize_embedding", "fuse_norm_linears"]


PRECISION = torch.float64


def normalize_embedding(embedding: torch.nn.Module):
if isinstance(embedding, (torch.nn.Embedding)):
with align_module_device(embedding):
weight_dtype = embedding.weight.dtype
weight = embedding.weight.to(PRECISION)
new_weight = weight - weight.mean(dim=-1, keepdim=True)
new_weight = new_weight.to(weight_dtype)

update_offload_parameter(embedding, "weight", new_weight)

else:
raise ValueError(f"Cannot normalize embedding of type {type(embedding)}")


def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]):
"""
Fuse a norm layer into subsequent linear layers. This useful for ensuring transform
invariance between norm and linear layers.

Note that a model cannot be properly trained after its norms have been fused

:param norm: norm layer whose weight will be fused into subsequent linears
:param linears: linear layers which directly follow the norm layer
"""
if isinstance(norm, (torch.nn.RMSNorm, LlamaRMSNorm)):
for linear in linears:
# NOTE: spinquant does this op in float64
exec_device = get_execution_device(norm)
with align_module_device(norm, exec_device), align_module_device(
linear, exec_device
):
weight_dtype = linear.weight.dtype
new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION)
new_weight = new_weight.to(weight_dtype)

update_offload_parameter(linear, "weight", new_weight)

new_norm_weight = torch.ones_like(norm.weight, device="cpu")
update_offload_parameter(norm, "weight", new_norm_weight)

else:
raise ValueError(f"Cannot fuse norm of type {type(norm)}")
39 changes: 36 additions & 3 deletions src/llmcompressor/modifiers/quantization/calibration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict, Optional, Tuple
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Tuple

import torch
from compressed_tensors.quantization import (
Expand All @@ -13,18 +14,26 @@
from compressed_tensors.utils import align_module_device, update_parameter_data
from loguru import logger
from torch.nn import Module
from torch.utils.hooks import RemovableHandle

from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache
from llmcompressor.observers import Observer
from llmcompressor.utils.helpers import getattr_chain

if TYPE_CHECKING:
from compressed_tensors.modeling.attention import CompressedAttentionImpl

from llmcompressor.modifiers.utils.hooks import HooksMixin


DEFAULT_MAXSHRINK = 0.20
DEFAULT_PATIENCE = 5
DEFAULT_AVERAGING_CONSTANT = 0.01
DEFAULT_GRID = 100.0
DEFAULT_NORM = 2.4

__all__ = [
"register_calibrate_attn_hooks",
"initialize_observer",
"update_weight_zp_scale",
"calibrate_input_hook",
Expand Down Expand Up @@ -205,14 +214,30 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
)


def calibrate_input_hook(module: Module, args: Any):
def register_calibrate_attn_hooks(
modifier: "HooksMixin", attention_impl: "CompressedAttentionImpl"
) -> Set[RemovableHandle]:
return {
modifier.register_hook(
attention_impl, partial(calibrate_input_hook, basename="q"), "query"
),
modifier.register_hook(
attention_impl, partial(calibrate_input_hook, basename="k"), "key"
),
modifier.register_hook(
attention_impl, partial(calibrate_input_hook, basename="v"), "value"
),
}


def calibrate_input_hook(module: Module, args: Any, base_name: str = "input"):
"""
Hook to calibrate input activations.
Will call the observers to update the scales/zp before applying
input QDQ in the module's forward pass.
"""
args = args[0] if isinstance(args, tuple) else args
calibrate_activations(module, value=args, base_name="input")
calibrate_activations(module, value=args, base_name=base_name)


def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
Expand Down Expand Up @@ -282,6 +307,14 @@ def initialize_quantized_kv_cache(module: Module):
setattr(module, "kv_cache", quantized_kv_cache)


def initialize_attention_observers(module: Module):
input_args = getattr_chain(module, "quantization_scheme.input_activations", None)
if input_args is not None:
initialize_observer(module, "q", input_args)
initialize_observer(module, "k", input_args)
initialize_observer(module, "v", input_args)


def apply_calibration_status(module: Module):
scheme = getattr(module, "quantization_scheme", None)
if not scheme:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def _initialize_observers(self, module: torch.nn.Module):
# kv_cache activations. Within `apply_quantization_config`, the config is
# modified to use attention output quantization if a kv_cache_scheme exists
if is_attention and output:
# initialize_attention_observers(module) # TODO: attnq
initialize_quantized_kv_cache(module)

# output activations
Expand All @@ -240,6 +241,7 @@ def _initialize_observers(self, module: torch.nn.Module):

def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
hooks = set()

for module in model.modules():
if not hasattr(module, "quantization_scheme"):
continue
Expand All @@ -258,6 +260,11 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
self.register_hook(module, calibrate_input_hook, "forward_pre")
)

# TODO: attnq
# if is_attention:
# attention_impl = CompressedAttentionImpl.from_module(module)
# hooks |= register_calibrate_attn_hooks(self, attention_impl)

# kv_cache activations. Within `apply_quantization_config`, the config is
# modified to use attention output quantization if a kv_cache_scheme exists
if is_attention and output:
Expand Down
3 changes: 3 additions & 0 deletions src/llmcompressor/modifiers/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa

from .spinquant import SpinQuantModifier
3 changes: 3 additions & 0 deletions src/llmcompressor/modifiers/transform/spinquant/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa

from .base import *
Loading
Loading