Commit 8747bae
[Transform] Spinquant with R1 and R2 (#1615)
## Purpose ##
* Enable offline SpinQuant-style transforms

## Prerequisites ##
* neuralmagic/compressed-tensors#370
* neuralmagic/compressed-tensors#412
* neuralmagic/compressed-tensors#414

## Changes ##
* Added `spinquant_example.py` to the examples folder
* Added `SpinQuantModifier`, which handles the construction of a SpinQuant-style transform config

## Testing ##
* Added modifier serialization and correctness tests

## Evaluation ##
Using this branch and [the original SpinQuant code](https://github.com/facebookresearch/SpinQuant), we see very similar results for `meta-llama/Llama-3.2-1B-Instruct` with W4A16 quantization. Results are equivalent in HF (in-memory vs. serialized and re-loaded) and very similar in vLLM.

The symmetric scales calculation in `llm-compressor` differs slightly from the original SpinQuant paper, which uses the original GPTQ implementation. When that calculation is swapped in, results are consistent, with the hadamard rotations improving results on `gsm8k_llama` and `arc_challenge_llama`:

Scheme | Impl | gsm8k | gsm8k_llama | arc_challenge_llama
-- | -- | -- | -- | --
Hadamard+W4A16 | LC | 0.2403 | 0.2835 | 0.5262
W4A16 | LC | 0.1964 | 0.1933 | 0.4781
Hadamard+W4A16 | LC+SQscales | 0.1721 | 0.2183 | 0.485
W4A16 | LC+SQscales | 0.207 | 0.1706 | 0.4498
Hadamard+W4A16 | SQ | 0.1736 | 0.2282 | 0.4807
W4A16 | SQ | 0.1986 | 0.1774 | 0.4489

To run LC+SQScales, change [this line in CT](https://github.com/neuralmagic/compressed-tensors/blob/b2df366797b00330ec765f5891dde14e4cc74c9d/src/compressed_tensors/quantization/utils/helpers.py#L111) from

```python
scales = max_val_pos / (float(bit_range) / 2)
```

to

```python
scales = max_val_pos / (float(bit_max))
```

(A short numeric sketch of this difference follows the commit description below.)

<details>
<summary>The following python script was used to generate these results</summary>

Clone the SpinQuant repo and paste this into the top-level directory:

```python
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from typing import Literal
import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from torch import nn
import lm_eval
from transformers import LlamaForCausalLM, AutoTokenizer
import transformers

from train_utils.main import prepare_model
from train_utils.modeling_llama_quant import LlamaForCausalLM as LlamaForCausalLMQuant
from utils.hadamard_utils import random_hadamard_matrix, hadamard_matrix
from utils.process_args import process_args_ptq

# model_id = "meta-llama/Llama-3.1-8B-Instruct"
# model_id = "meta-llama/Llama-3.2-3B-Instruct"
model_id = "meta-llama/Llama-3.2-1B-Instruct"
dtype = torch.bfloat16


class RotateModule(nn.Module):
    def __init__(self, R_init):
        super(RotateModule, self).__init__()
        self.weight = nn.Parameter(R_init.to(torch.float32).to(torch.device("cuda")))

    def forward(self, x, transpose=False):
        if transpose:
            return x @ self.weight
        else:
            return self.weight @ x


def get_sq_model(
    r1r2=Literal["eye", "random-hadamard", "hadamard"],
    w_bits=Literal[4, 16],
    w_clip: bool = False,
) -> LlamaForCausalLMQuant:
    model_args, training_args, ptq_args = process_args_ptq()
    model_args.input_model = model_id

    if w_bits == 4:
        ptq_args.w_bits = 4
        ptq_args.w_groupsize = 128
        ptq_args.w_rtn = True  # if False, GPTQ is used
        ptq_args.w_clip = w_clip
    ptq_args.a_bits = 16
    ptq_args.k_bits = 16
    ptq_args.v_bits = 16
    print("=======ARGS=======", ptq_args)

    config = transformers.AutoConfig.from_pretrained(model_args.input_model)
    # Llama v3.2 specific: SpinQuant is not compatible with tie_word_embeddings,
    # clone lm_head from embed_tokens
    process_word_embeddings = False
    if config.tie_word_embeddings:
        config.tie_word_embeddings = False
        process_word_embeddings = True

    model = LlamaForCausalLMQuant.from_pretrained(
        pretrained_model_name_or_path=model_args.input_model,
        config=config,
        torch_dtype=dtype,
        device_map="cuda",
    )
    if process_word_embeddings:
        model.lm_head.weight.data = model.model.embed_tokens.weight.data.clone()

    model = prepare_model(ptq_args, model)
    for param in model.parameters():
        param.requires_grad = False

    match r1r2:
        case "eye":
            R1 = torch.eye(model.config.hidden_size, device="cuda")
        case "random-hadamard":
            R1 = random_hadamard_matrix(model.config.hidden_size, "cuda")
        case _:
            R1 = hadamard_matrix(model.config.hidden_size, "cuda")
    model.R1 = RotateModule(R1)

    for i in range(model.config.num_hidden_layers):
        # Each head dim = 128 for Llama model
        match r1r2:
            case "eye":
                R2 = torch.eye(
                    model.config.hidden_size // model.config.num_attention_heads,
                    device="cuda",
                )
            case "random-hadamard":
                R2 = random_hadamard_matrix(
                    model.config.hidden_size // model.config.num_attention_heads,
                    "cuda",
                )
            case _:
                R2 = hadamard_matrix(
                    model.config.hidden_size // model.config.num_attention_heads,
                    "cuda",
                )
        model.model.layers[i].self_attn.R2 = RotateModule(R2)

    model.config.use_cache = False

    return model


def get_lc_model(
    r1r2=Literal["eye", "random-hadamard", "hadamard"], w_bits=Literal[4, 16]
) -> LlamaForCausalLM:
    from llmcompressor import oneshot
    from llmcompressor.modifiers.quantization import QuantizationModifier
    from llmcompressor.modifiers.transform import SpinQuantModifier

    model = LlamaForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_id,
        torch_dtype=dtype,
        device_map="cuda",
    )

    recipe = [
        SpinQuantModifier(
            rotations=[] if r1r2 == "eye" else ["R1", "R2"],
            transform_type="hadamard",
        )
    ]
    if w_bits == 4:
        recipe.append(
            QuantizationModifier(
                targets="Linear",
                scheme="W4A16",
                ignore=["lm_head"],
            )
        )

    oneshot(
        model=model,
        recipe=recipe,
        pipeline="datafree",
        log_dir=None,
    )

    return model

__name__ == "__main__": for scales_impl in ["sq_min_hack", "lc_min_hack"]: for r1r2 in ["eye", "hadamard"]: for sq_lc in ["sq", "lc"]: w_bits = 4 os.environ["SCALES_IMPL"] = scales_impl model = ( get_sq_model(r1r2=r1r2, w_bits=w_bits) if sq_lc == "sq" else get_lc_model(r1r2=r1r2, w_bits=w_bits) ).to("cuda") SAVE_DIR = model_id.split("/")[1] + f"-{scales_impl}-{r1r2}-w4a16" model.save_pretrained(SAVE_DIR, save_compressed=True) tokenizer = AutoTokenizer.from_pretrained( model_id, trust_remote_code=True ) tokenizer.save_pretrained(SAVE_DIR) del model del tokenizer torch.cuda.empty_cache() results = lm_eval.simple_evaluate( # 1) hf in-memory # model=lm_eval.models.huggingface.HFLM( # pretrained=model, # batch_size=32, # add_bos_token=False, # ), # 1/) # 2) vllm serialized model="vllm", model_args={ "pretrained": SAVE_DIR, "add_bos_token": False, "dtype": "auto", "max_model_len": 4096, "gpu_memory_utilization": 0.5, "enable_chunked_prefill": True, }, # 2/) # 3) hf serialized # model="hf", # model_args={ # "pretrained": SAVE_DIR, # "add_bos_token": False, # "dtype": "auto", # }, # device="cuda", # 3/) tasks=["gsm8k_llama", "gsm8k", "arc_challenge_llama"], num_fewshot=8, batch_size=32, apply_chat_template=True, fewshot_as_multiturn=True, ) print( f"RESULTS, {model_id} {sq_lc} R1R2 {r1r2} W_BITS {w_bits} SCALEIMPL {scales_impl}" ) print(lm_eval.utils.make_table(results)) ``` </details> ## Follow Ups ## * Infer data free pipeline, even if a transform modifier is included * Rotations R3 and R4 * Modify example to use GPTQ once basic evaluation has been performed --------- Signed-off-by: Kyle Sayers <[email protected]> Signed-off-by: Brian Dellabetta <[email protected]> Co-authored-by: Kyle Sayers <[email protected]>
1 parent d5a6a4b commit 8747bae

File tree

10 files changed: +497 -0 lines

Lines changed: 40 additions & 0 deletions (`spinquant_example.py`)

from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import SpinQuantModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# NOTE: currently only fused rotations (R1 & R2) are available
# Learned rotations and online rotations (R3 & R4) will be added
# in a future release.
# Configure the quantization algorithm to run.
# * apply spinquant transforms to model to reduce quantization loss
# * quantize the weights to 4 bit with group size 128
recipe = [
    SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"),
    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]

# Apply algorithms.
oneshot(model=model, recipe=recipe, pipeline="datafree")

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquantR1R2-w4a16"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
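The example above ends by saving a compressed checkpoint to `SAVE_DIR`. As a hedged sketch of the next step (not part of this commit), the saved model can be served or evaluated with vLLM, assuming the installed vLLM version supports the compressed-tensors W4A16 format produced here:

```python
# Sketch: load the checkpoint written by the example above in vLLM.
# The directory name mirrors SAVE_DIR from the example.
from vllm import LLM, SamplingParams

llm = LLM(model="Meta-Llama-3-8B-Instruct-spinquantR1R2-w4a16", max_model_len=4096)
outputs = llm.generate(["Hello my name is"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```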
Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
 # flake8: noqa
 
+from .fuse import *
 from .prepare import *
Lines changed: 3 additions & 0 deletions

# flake8: noqa

from .spinquant import SpinQuantModifier
Lines changed: 3 additions & 0 deletions

# flake8: noqa

from .base import *
Lines changed: 246 additions & 0 deletions (`SpinQuantModifier`)

from enum import Enum
from typing import Iterable, List, Literal, Optional

import torch
from compressed_tensors import match_modules_set, match_named_modules
from compressed_tensors.transform import (
    TransformArgs,
    TransformConfig,
    TransformScheme,
    apply_transform_config,
)
from compressed_tensors.utils import TorchDtype
from pydantic import Field, ValidationInfo, field_validator
from transformers import PreTrainedModel

from llmcompressor.core import Event, EventType, State
from llmcompressor.modeling import center_embeddings, fuse_norm_linears
from llmcompressor.modifiers import Modifier

from .mappings import SpinQuantMapping, infer_mapping_from_model
from .norm_mappings import NormMapping, infer_norm_mapping_from_model


class SpinquantRotation(str, Enum):
    R1 = "R1"
    R2 = "R2"
    R3 = "R3"
    R4 = "R4"


class SpinQuantModifier(Modifier, use_enum_values=True):
    """
    Implements the transforms according to "SpinQuant: LLM quantization
    with learned rotations" (https://arxiv.org/abs/2405.16406)

    Transforms (rotations) are extra layers added to a model which reduce the accuracy
    loss induced by quantization. This is achieved through "rotating" weights and
    activations into a space with a smaller dynamic range of values, thus decreasing
    the range of scales required for quantization.

    The SpinQuant authors describe four different rotations which can be applied to a
    model. R1 and R2 are "offline" rotations, meaning that they can be fused into
    existing weights and therefore do not induce runtime cost. R3 and R4 are "online"
    rotations, meaning that they require additional computation at runtime.

    Lifecycle:
        - on_initialize
            - infer SpinQuantMappings & NormMappings
            - as needed, create transform schemes for R1, R2, R3, & R4
        - on_start
            - normalize embeddings
            - fuse norm layers into subsequent Linear layers
            - apply TransformConfig
                - fuse transforms into weights for mergeable transforms
                - add hooks for online transforms
        - on sequential epoch end
        - on_end
        - on_finalize

    :param rotations: A list containing the names of rotations to apply to the model.
        Possible rotations include R1, R2, R3, and R4
    :param transform_type: The type of transform to apply to the model.
        `"hadamard"` has the least performance cost but only supports sizes which are
        powers of two.
        `"random-hadamard"` has more performance cost, but supports a much larger set
        of sizes.
        `"random-matrix"` has the greatest performance cost, but supports any size
    :param randomize: if True, create distinct transforms for each application
    :param learnable: if True, attach gradients to transform weights for training
    :param precision: Precision at which all transforms should be applied. This applies
        to both weight fusing and online rotations
    :param mappings: Specifies layers within a model to target for transforms.
        A mapping will be inferred if None is provided
    :param norm_mappings: Specifies layers within a model to target for norm fusing.
        A mapping will be inferred if None is provided
    :param transform_config: Optional transform config for overriding provided arguments
    """

    rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"])
    transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
        default="hadamard"
    )
    randomize: bool = Field(default=False)
    learnable: bool = Field(default=False)
    precision: TorchDtype = Field(default=torch.float64)

    # norm mappings separate from spinquant mappings to allow users to
    # override spinquant mappings with transform_config without overriding norms
    mappings: Optional[SpinQuantMapping] = Field(
        default=None,
        repr=False,
        exclude=True,
    )
    norm_mappings: Optional[List[NormMapping]] = Field(
        default=None,
        repr=False,
        exclude=True,
    )

    # optional override for more fine-grained control
    # also included in recipe serialization
    transform_config: Optional[TransformConfig] = Field(default=None, repr=False)

    @field_validator("randomize", "learnable", mode="before")
    def validate_not_implemented(cls, value, info: ValidationInfo):
        if value:
            raise NotImplementedError(f"{info.field_name} is not supported right now")
        return value

    @field_validator("rotations", mode="before")
    def validate_rotations(cls, value):
        if isinstance(value, Iterable):
            return tuple(v.upper() for v in value)
        return value

    def on_initialize(self, state: State, **kwargs) -> bool:
        if self.transform_config is not None:
            return True

        self.mappings = infer_mapping_from_model(state.model)
        self.norm_mappings = infer_norm_mapping_from_model(state.model)

        config_groups = {}
        if SpinquantRotation.R1 in self.rotations:
            config_groups["R1"] = self._create_r1_scheme()

        if SpinquantRotation.R2 in self.rotations:
            config_groups["R2"] = self._create_r2_scheme(state.model)

        if SpinquantRotation.R3 in self.rotations:
            config_groups["R3"] = self._create_r3_scheme()

        if SpinquantRotation.R4 in self.rotations:
            config_groups["R4"] = self._create_r4_scheme()

        self.transform_config = TransformConfig(config_groups=config_groups)

        return True

    def on_start(self, state: State, event: Event, **kwargs):
        self.started_ = True

        # needs to happen after the model has been hooked to execute on the GPU
        # otherwise we're applying weight transforms on CPU
        self._center_embeddings(state.model)
        self._fuse_norms(state.model)
        apply_transform_config(state.model, self.transform_config)

    def on_event(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.CALIBRATION_EPOCH_START:
            if not self.started_:
                self.on_start(state, None)

        elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
            pass

        elif event.type_ == EventType.CALIBRATION_EPOCH_END:
            if not self.ended_:
                self.on_end(state, None)

    def on_end(self, state: State, event: Event, **kwargs):
        self.ended_ = True

    def on_finalize(self, state: State, **kwargs) -> bool:
        if not self.ended_:
            self.on_end(state, None)

        return True

    def _center_embeddings(self, model: PreTrainedModel):
        for _, embedding in match_named_modules(
            model, [self.mappings.embedding], warn_on_fail=True
        ):
            center_embeddings(embedding)

    def _fuse_norms(self, model: PreTrainedModel):
        for mapping in self.norm_mappings:
            for norm, *linears in match_modules_set(
                model, (mapping.norm, *mapping.linears)
            ):
                fuse_norm_linears(norm, linears)

    def _create_r1_scheme(self) -> TransformScheme:
        return TransformScheme(
            type=self.transform_type,
            randomize=self.randomize,
            requires_grad=self.learnable,
            precision=self.precision,
            apply=[
                TransformArgs(
                    targets=[
                        self.mappings.embedding,
                        self.mappings.attn_o,
                        *self.mappings.mlp_out,
                    ],
                    location="weight_output",
                ),
                TransformArgs(
                    targets=[
                        self.mappings.attn_q,
                        self.mappings.attn_k,
                        self.mappings.attn_v,
                        *self.mappings.mlp_in,
                        self.mappings.lm_head,
                    ],
                    location="weight_input",
                    inverse=True,
                ),
            ],
        )

    def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
        config = model.config

        if hasattr(config, "head_dim"):
            head_dim = config.head_dim
        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
            head_dim = config.hidden_size // config.num_attention_heads
        else:
            raise NotImplementedError()

        return TransformScheme(
            type=self.transform_type,
            randomize=self.randomize,
            requires_grad=self.learnable,
            precision=self.precision,
            head_dim=head_dim,
            apply=[
                TransformArgs(targets=[self.mappings.attn_v], location="weight_output"),
                TransformArgs(
                    targets=[self.mappings.attn_o],
                    location="weight_input",
                    inverse=True,
                ),
            ],
        )

    def _create_r3_scheme(self) -> TransformScheme:
        raise NotImplementedError(
            "SpinQuant R3 and R4 rotations will be added in a future release"
        )

    def _create_r4_scheme(self) -> TransformScheme:
        raise NotImplementedError(
            "SpinQuant R3 and R4 rotations will be added in a future release"
        )
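To make the R1 scheme built in `on_initialize`/`_create_r1_scheme` concrete, here is a hand-written sketch of roughly the `TransformConfig` it produces once the default Llama mapping (defined in the mappings file below) has been inferred. It is illustrative only; in practice the recipe simply specifies `SpinQuantModifier(rotations=["R1", "R2"])` and this config is built automatically:

```python
from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme

# Roughly what _create_r1_scheme() returns with the default Llama mapping,
# written out by hand purely for illustration.
r1_scheme = TransformScheme(
    type="hadamard",
    apply=[
        # R1 is fused into the outputs of the embedding, attention o_proj,
        # and MLP down_proj weights...
        TransformArgs(
            targets=["re:.*embed_tokens$", "re:.*o_proj$", "re:.*down_proj$"],
            location="weight_output",
        ),
        # ...and its inverse into the inputs of q/k/v, up/gate projections, and
        # the LM head, so the network's end-to-end function is preserved.
        TransformArgs(
            targets=[
                "re:.*q_proj$",
                "re:.*k_proj$",
                "re:.*v_proj$",
                "re:.*up_proj$",
                "re:.*gate_proj$",
                "lm_head",
            ],
            location="weight_input",
            inverse=True,
        ),
    ],
)

config = TransformConfig(config_groups={"R1": r1_scheme})
```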
Lines changed: 76 additions & 0 deletions (SpinQuant mappings)

from typing import Dict, List, Optional

from loguru import logger
from pydantic import BaseModel, Field, field_validator
from transformers import PreTrainedModel

__all__ = ["SpinQuantMapping", "infer_mapping_from_model"]


class SpinQuantMapping(BaseModel):
    """
    SpinQuant needs to know the entire architecture of the model,
    as R1, R2, R3, and R4 rotations need to be applied to specific
    layers (https://arxiv.org/pdf/2405.16406 Fig. 1).

    :param embedding: name or regex of embedding layer
    :param attn_q: name or regex of q_proj layer in attention block
    :param attn_k: name or regex of k_proj layer in attention block
    :param attn_v: name or regex of v_proj layer in attention block
    :param attn_o: name or regex of o_proj layer in attention block
    :param attn_head_dim: head_dim of the attention module, needed
        because R2 needs to be applied head-wise to v_proj and
        o_proj
    :param mlp_in: list of names or regexes for the mlp blocks that
        receive the input to the MLP block, usually up_proj and gate_proj
    :param mlp_out: list of names or regexes for the mlp blocks that
        constitute the output of the MLP block, usually down_proj
    """

    embedding: str

    attn_q: str
    attn_k: str
    attn_v: str
    attn_o: str
    attn_head_dim: Optional[int] = Field(default=None)

    mlp_in: List[str]  # up_proj, gate_proj
    mlp_out: List[str]  # down_proj

    lm_head: str

    @field_validator("mlp_in", "mlp_out", mode="before")
    def cast_to_list(cls, value):
        if isinstance(value, str):
            return [value]

        return value


_default_mappings = SpinQuantMapping(
    embedding="re:.*embed_tokens$",
    attn_q="re:.*q_proj$",
    attn_k="re:.*k_proj$",
    attn_v="re:.*v_proj$",
    attn_o="re:.*o_proj$",
    mlp_in=["re:.*up_proj$", "re:.*gate_proj$"],
    mlp_out="re:.*down_proj$",
    lm_head="lm_head",
)


SPINQUANT_MAPPING_REGISTRY: Dict[str, SpinQuantMapping] = {
    "LlamaForCausalLM": _default_mappings,
}


def infer_mapping_from_model(model: PreTrainedModel) -> SpinQuantMapping:
    architecture = model.__class__.__name__
    if architecture not in SPINQUANT_MAPPING_REGISTRY:
        logger.info(
            f"Unrecognized model architecture {architecture}. "
            "Falling back to default mappings"
        )

    return SPINQUANT_MAPPING_REGISTRY.get(architecture, _default_mappings)
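The registry above ships only a `LlamaForCausalLM` entry and falls back to the default mapping (with a logged notice) for unrecognized architectures. A hedged sketch of how a mapping for another architecture could be registered follows; the import path and the `MyModelForCausalLM` module names are assumptions for illustration, not part of this commit:

```python
# Hypothetical: register a SpinQuant mapping for a non-Llama architecture.
# The import path and the target regexes below are assumptions.
from llmcompressor.modifiers.transform.spinquant.mappings import (
    SPINQUANT_MAPPING_REGISTRY,
    SpinQuantMapping,
)

SPINQUANT_MAPPING_REGISTRY["MyModelForCausalLM"] = SpinQuantMapping(
    embedding="re:.*word_embeddings$",
    attn_q="re:.*q_proj$",
    attn_k="re:.*k_proj$",
    attn_v="re:.*v_proj$",
    attn_o="re:.*o_proj$",
    mlp_in=["re:.*fc_in$"],
    mlp_out="re:.*fc_out$",  # the cast_to_list validator accepts a single string
    lm_head="lm_head",
)
```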
