diff --git a/examples/transform/spinquant_example.py b/examples/transform/spinquant_example.py new file mode 100644 index 000000000..547d06041 --- /dev/null +++ b/examples/transform/spinquant_example.py @@ -0,0 +1,40 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.modifiers.transform import SpinQuantModifier +from llmcompressor.utils import dispatch_for_generation + +# Select model and load it. +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" + +model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# NOTE: currently only fused rotations (R1 & R2) are available +# Learned rotations and online rotations (R3 & R4) will be added +# in a future release. +# Configure the quantization algorithm to run. +# * apply spinquant transforms to model to reduce quantization loss +# * quantize the weights to 4 bit with group size 128 +recipe = [ + SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"), + QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), +] + +# Apply algorithms. +oneshot(model=model, recipe=recipe, pipeline="datafree") + +# Confirm generations of the quantized model look sane. +print("\n\n") +print("========== SAMPLE GENERATION ==============") +dispatch_for_generation(model) +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0])) +print("==========================================\n\n") + +# Save to disk compressed. +SAVE_DIR = MODEL_ID.split("/")[1] + "-spinquantR1R2-w4a16" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modeling/__init__.py b/src/llmcompressor/modeling/__init__.py index e2c22ed1f..76b6b0391 100644 --- a/src/llmcompressor/modeling/__init__.py +++ b/src/llmcompressor/modeling/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa +from .fuse import * from .prepare import * diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py new file mode 100644 index 000000000..9956d0340 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .spinquant import SpinQuantModifier diff --git a/src/llmcompressor/modifiers/transform/spinquant/__init__.py b/src/llmcompressor/modifiers/transform/spinquant/__init__.py new file mode 100644 index 000000000..8bdc93d14 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/spinquant/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .base import * diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py new file mode 100644 index 000000000..68095ab1b --- /dev/null +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -0,0 +1,246 @@ +from enum import Enum +from typing import Iterable, List, Literal, Optional + +import torch +from compressed_tensors import match_modules_set, match_named_modules +from compressed_tensors.transform import ( + TransformArgs, + TransformConfig, + TransformScheme, + apply_transform_config, +) +from compressed_tensors.utils import TorchDtype +from pydantic import Field, ValidationInfo, field_validator +from transformers import PreTrainedModel + +from llmcompressor.core import Event, 
EventType, State +from llmcompressor.modeling import center_embeddings, fuse_norm_linears +from llmcompressor.modifiers import Modifier + +from .mappings import SpinQuantMapping, infer_mapping_from_model +from .norm_mappings import NormMapping, infer_norm_mapping_from_model + + +class SpinquantRotation(str, Enum): + R1 = "R1" + R2 = "R2" + R3 = "R3" + R4 = "R4" + + +class SpinQuantModifier(Modifier, use_enum_values=True): + """ + Implements the transforms according to "SpinQuant: LLM quantization + with learned rotations" (https://arxiv.org/abs/2405.16406) + + Transforms (rotations) are extra layers added to a model which reduce the accuracy + loss induced by quantization. This is achieved through "rotating" weights and + activations into a space with a smaller dynamic range of values, thus decreasing + the range of scales required for quantization. + + The SpinQuant authors describe four different rotations which can be applied to a + model. R1 and R2 are "offline" rotations, meaning that they can be fused into + existing weights and therefore do not induce runtime cost. R3 and R4 are "online" + rotations, meaning that they require additional computation at runtime. + + Lifecycle: + - on_initialize + - infer SpinQuantMappings & NormMappings + - as needed, create transform schemes for R1, R2, R3, & R4 + - on_start + - normalize embeddings + - fuse norm layers into subsequent Linear layers + - apply TransformConfig + - fuse transforms into weights for mergeable transforms + - add hooks for online transforms + - on sequential epoch end + - on_end + - on_finalize + + :param rotations: A list containing the names of rotations to apply to the model. + Possible rotations include R1, R2, R3, and R4 + :param transform_type: The type of transform to apply to the model. + `"hadamard"` has the least performance cost but only supports sizes which are + powers of two. + `"random-hadamard"` has more performance cost, but supports a much larger set of + sizes. + `"random-matrix"` has the greatest performance cost, but supports any size + :param randomize: if True, create distinct transforms for each application + :param learnable: if True, attach gradients to transform weights for training + :param precision: Precision at which all transforms should be applied. This applies + to both weight fusing and online rotations + :param mappings: Specifies layers within a model to target for transforms. + A mapping will be inferred if None is provided + :param norm_mappings: Specifies layers within a model to target for norm fusing. 
+ A mapping will be inferred if None is provided + :param transform_config: Optional transform config for overriding provided arguments + """ + + rotations: List[SpinquantRotation] = Field(default_factory=lambda: ["R1", "R2"]) + transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( + default="hadamard" + ) + randomize: bool = Field(default=False) + learnable: bool = Field(default=False) + precision: TorchDtype = Field(default=torch.float64) + + # norm mappings separate from spinquant mappings to allow users to + # override spinquant mappings with transform_config without overriding norms + mappings: Optional[SpinQuantMapping] = Field( + default=None, + repr=False, + exclude=True, + ) + norm_mappings: Optional[List[NormMapping]] = Field( + default=None, + repr=False, + exclude=True, + ) + + # optional override for more fine-grained control + # also included in recipe serialization + transform_config: Optional[TransformConfig] = Field(default=None, repr=False) + + @field_validator("randomize", "learnable", mode="before") + def validate_not_implemented(cls, value, info: ValidationInfo): + if value: + raise NotImplementedError(f"{info.field_name} is not supported right now") + return value + + @field_validator("rotations", mode="before") + def validate_rotations(cls, value): + if isinstance(value, Iterable): + return tuple(v.upper() for v in value) + return value + + def on_initialize(self, state: State, **kwargs) -> bool: + if self.transform_config is not None: + return True + + self.mappings = infer_mapping_from_model(state.model) + self.norm_mappings = infer_norm_mapping_from_model(state.model) + + config_groups = {} + if SpinquantRotation.R1 in self.rotations: + config_groups["R1"] = self._create_r1_scheme() + + if SpinquantRotation.R2 in self.rotations: + config_groups["R2"] = self._create_r2_scheme(state.model) + + if SpinquantRotation.R3 in self.rotations: + config_groups["R3"] = self._create_r3_scheme() + + if SpinquantRotation.R4 in self.rotations: + config_groups["R4"] = self._create_r4_scheme() + + self.transform_config = TransformConfig(config_groups=config_groups) + + return True + + def on_start(self, state: State, event: Event, **kwargs): + self.started_ = True + + # needs to happen after the model has been hooked to execute on the GPU + # otherwise we're applying weight transforms on CPU + self._center_embeddings(state.model) + self._fuse_norms(state.model) + apply_transform_config(state.model, self.transform_config) + + def on_event(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.CALIBRATION_EPOCH_START: + if not self.started_: + self.on_start(state, None) + + elif event.type_ == EventType.SEQUENTIAL_EPOCH_END: + pass + + elif event.type_ == EventType.CALIBRATION_EPOCH_END: + if not self.ended_: + self.on_end(state, None) + + def on_end(self, state: State, event: Event, **kwargs): + self.ended_ = True + + def on_finalize(self, state: State, **kwargs) -> bool: + if not self.ended_: + self.on_end(state, None) + + return True + + def _center_embeddings(self, model: PreTrainedModel): + for _, embedding in match_named_modules( + model, [self.mappings.embedding], warn_on_fail=True + ): + center_embeddings(embedding) + + def _fuse_norms(self, model: PreTrainedModel): + for mapping in self.norm_mappings: + for norm, *linears in match_modules_set( + model, (mapping.norm, *mapping.linears) + ): + fuse_norm_linears(norm, linears) + + def _create_r1_scheme(self) -> TransformScheme: + return TransformScheme( + type=self.transform_type, 
+ randomize=self.randomize, + requires_grad=self.learnable, + precision=self.precision, + apply=[ + TransformArgs( + targets=[ + self.mappings.embedding, + self.mappings.attn_o, + *self.mappings.mlp_out, + ], + location="weight_output", + ), + TransformArgs( + targets=[ + self.mappings.attn_q, + self.mappings.attn_k, + self.mappings.attn_v, + *self.mappings.mlp_in, + self.mappings.lm_head, + ], + location="weight_input", + inverse=True, + ), + ], + ) + + def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme: + config = model.config + + if hasattr(config, "head_dim"): + head_dim = config.head_dim + elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"): + head_dim = config.hidden_size // config.num_attention_heads + else: + raise NotImplementedError() + + return TransformScheme( + type=self.transform_type, + randomize=self.randomize, + requires_grad=self.learnable, + precision=self.precision, + head_dim=head_dim, + apply=[ + TransformArgs(targets=[self.mappings.attn_v], location="weight_output"), + TransformArgs( + targets=[self.mappings.attn_o], + location="weight_input", + inverse=True, + ), + ], + ) + + def _create_r3_scheme(self) -> TransformScheme: + raise NotImplementedError( + "SpinQuant R3 and R4 rotations will be added in a future release" + ) + + def _create_r4_scheme(self) -> TransformScheme: + raise NotImplementedError( + "SpinQuant R3 and R4 rotations will be added in a future release" + ) diff --git a/src/llmcompressor/modifiers/transform/spinquant/mappings.py b/src/llmcompressor/modifiers/transform/spinquant/mappings.py new file mode 100644 index 000000000..514d1f109 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/spinquant/mappings.py @@ -0,0 +1,76 @@ +from typing import Dict, List, Optional + +from loguru import logger +from pydantic import BaseModel, Field, field_validator +from transformers import PreTrainedModel + +__all__ = ["SpinQuantMapping", "infer_mapping_from_model"] + + +class SpinQuantMapping(BaseModel): + """ + SpinQuant needs to know the entire architecture of the model, + as R1, R2, R3, and R4 rotations need to be applied to specific + layers (https://arxiv.org/pdf/2405.16406 Fig. 1). 
+ + :param embedding: name or regex of embedding layer + :param attn_q: name or regex of q_proj layer in attention block + :param attn_k: name or regex of k_proj layer in attention block + :param attn_v: name or regex of v_proj layer in attention block + :param attn_o: name or regex of o_proj layer in attention block + :param attn_head_dim: head_dim of the attention module, needed + because R2 needs to be applied head-wise to v_proj and + o_proj + :param mlp_in: list of names or regexes for the Linear layers that + receive the input to the MLP block, usually up_proj and gate_proj + :param mlp_out: list of names or regexes for the Linear layers that + constitute the output of the MLP block, usually down_proj + """ + + embedding: str + + attn_q: str + attn_k: str + attn_v: str + attn_o: str + attn_head_dim: Optional[int] = Field(default=None) + + mlp_in: List[str] # up_proj, gate_proj + mlp_out: List[str] # down_proj + + lm_head: str + + @field_validator("mlp_in", "mlp_out", mode="before") + def cast_to_list(cls, value): + if isinstance(value, str): + return [value] + + return value + + +_default_mappings = SpinQuantMapping( + embedding="re:.*embed_tokens$", + attn_q="re:.*q_proj$", + attn_k="re:.*k_proj$", + attn_v="re:.*v_proj$", + attn_o="re:.*o_proj$", + mlp_in=["re:.*up_proj$", "re:.*gate_proj$"], + mlp_out="re:.*down_proj$", + lm_head="lm_head", +) + + +SPINQUANT_MAPPING_REGISTRY: Dict[str, SpinQuantMapping] = { + "LlamaForCausalLM": _default_mappings, +} + + +def infer_mapping_from_model(model: PreTrainedModel) -> SpinQuantMapping: + architecture = model.__class__.__name__ + if architecture not in SPINQUANT_MAPPING_REGISTRY: + logger.info( + f"Unrecognized model architecture {architecture}. " + "Falling back to default mappings" + ) + + return SPINQUANT_MAPPING_REGISTRY.get(architecture, _default_mappings) diff --git a/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py new file mode 100644 index 000000000..e60ac0d1a --- /dev/null +++ b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py @@ -0,0 +1,61 @@ +from typing import Dict, List + +from loguru import logger +from pydantic import BaseModel, field_validator +from transformers import PreTrainedModel + +__all__ = ["infer_norm_mapping_from_model"] + + +class NormMapping(BaseModel): + """ + SpinQuant needs to know where every norm layer exists in the model, + as well as all the subsequent Linear layers the norm passes into. + This is because the norm layer weights need to be normalized before + transforms can be fused into Linear layers. + + :param norm: name or regex that matches a norm layer in the model + :param linears: list of names or regexes of Linear layers that + receive input from the norm. 
+ """ + + norm: str + linears: List[str] + + @field_validator("linears", mode="before") + def cast_to_list(cls, value): + if isinstance(value, str): + return [value] + + return value + + +_default_mappings = [ + NormMapping( + norm="re:.*input_layernorm$", + linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"], + ), + NormMapping( + norm="re:.*post_attention_layernorm$", + linears=["re:.*up_proj$", "re:.*gate_proj$"], + ), + NormMapping( + norm="model.norm", + linears=["lm_head"], + ), +] + +NORM_MAPPING_REGISTRY: Dict[str, NormMapping] = { + "LlamaForCausalLM": _default_mappings, +} + + +def infer_norm_mapping_from_model(model: PreTrainedModel) -> List[NormMapping]: + architecture = model.__class__.__name__ + if architecture not in NORM_MAPPING_REGISTRY: + logger.info( + f"Unrecognized model architecture {architecture}. " + "Falling back to default mappings" + ) + + return NORM_MAPPING_REGISTRY.get(architecture, _default_mappings) diff --git a/src/llmcompressor/pipelines/data_free/pipeline.py b/src/llmcompressor/pipelines/data_free/pipeline.py index 587f7ca69..7ad6d56dc 100644 --- a/src/llmcompressor/pipelines/data_free/pipeline.py +++ b/src/llmcompressor/pipelines/data_free/pipeline.py @@ -5,6 +5,7 @@ from llmcompressor.core.session_functions import LifecycleCallbacks from llmcompressor.pipelines.registry import CalibrationPipeline +from llmcompressor.utils.dev import dispatch_for_generation if TYPE_CHECKING: from llmcompressor.args.dataset_arguments import DatasetArguments @@ -27,5 +28,9 @@ def __call__( :param dataloader: loads data for calibration :param dataset_args: dataset arguments relevant to pipelines """ + # some ops are still performed on the model by modifiers + # we want those ops to occur on the GPU + dispatch_for_generation(model) + LifecycleCallbacks.calibration_epoch_start() LifecycleCallbacks.calibration_epoch_end() diff --git a/tests/llmcompressor/modifiers/transform/test_correctness.py b/tests/llmcompressor/modifiers/transform/test_correctness.py new file mode 100644 index 000000000..783338eed --- /dev/null +++ b/tests/llmcompressor/modifiers/transform/test_correctness.py @@ -0,0 +1,52 @@ +import os + +import pytest +import torch +from transformers import AutoModelForCausalLM + +from llmcompressor.core import State +from llmcompressor.modifiers.transform import SpinQuantModifier +from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( + untie_word_embeddings, +) +from tests.testing_utils import requires_gpu + + +@requires_gpu +@pytest.mark.skipif( + (not os.getenv("HF_TOKEN")), + reason="Skipping correctness tests requiring gated model access", +) +@pytest.mark.parametrize( + "modifier,model_dtype,precision,exp_mse", + [ + # (QuIPModifier, torch.bfloat16, torch.bfloat16, 5e-3), # 0.0019 + # (QuIPModifier, torch.bfloat16, torch.float32, 5e-3), # 0.0022 + # (QuIPModifier, torch.float32, torch.float32, 5e-10), # 1.0e-10 + # (QuIPModifier, torch.float32, torch.float64, 5e-11), # 2.7e-11 + (SpinQuantModifier, torch.bfloat16, torch.bfloat16, 5e-3), # 0.0030 + (SpinQuantModifier, torch.bfloat16, torch.float32, 5e-3), # 0.0029 + (SpinQuantModifier, torch.float32, torch.float32, 5e-4), # 4e-4 + (SpinQuantModifier, torch.float32, torch.float64, 5e-4), # 4e-4 + ], +) +def test_apply_correctness(modifier, model_dtype, precision, exp_mse): + model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-3.2-1B-Instruct", device_map="cuda", torch_dtype=model_dtype + ) + untie_word_embeddings(model) + + state = State(model=model) + modifier = 
modifier(transform_type="random-hadamard", precision=precision) + + input = {k: v.to("cuda") for k, v in model.dummy_inputs.items()} + with torch.no_grad(): + true_output = model(**input) + + modifier.on_initialize(state) + modifier.on_start(state, None) + + with torch.no_grad(): + output = model(**input) + + assert torch.nn.MSELoss()(output.logits, true_output.logits) <= exp_mse diff --git a/tests/llmcompressor/modifiers/transform/test_serialization.py b/tests/llmcompressor/modifiers/transform/test_serialization.py new file mode 100644 index 000000000..2a2e8602d --- /dev/null +++ b/tests/llmcompressor/modifiers/transform/test_serialization.py @@ -0,0 +1,10 @@ +import pytest + +from llmcompressor.modifiers.transform import SpinQuantModifier + + +@pytest.mark.parametrize("modifier", [SpinQuantModifier]) +def test_reload(modifier): + instance = modifier(transform_type="hadamard") + dump = instance.model_dump() + assert modifier.model_validate(dump) == instance
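Reviewer note (illustrative, not part of the diff): the SpinQuantModifier docstring above explains that R1/R2 are fused, offline rotations that reshape weight and activation distributions without changing model outputs. The short PyTorch sketch below demonstrates that property for a generic orthogonal rotation folded into two adjacent Linear weights. All names, sizes, and the QR-based rotation are made up for illustration and do not depend on llmcompressor or compressed-tensors APIs.

```python
# Illustrative sketch only: why fusing an orthogonal rotation Q into adjacent
# Linear weights is output-preserving, the property fused R1/R2 rely on:
# (W2 Q^T)(Q W1) x == W2 W1 x, since Q^T Q = I.
import torch

torch.manual_seed(0)
d = 64
x = torch.randn(8, d, dtype=torch.float64)

# two adjacent linear layers: y = W2 @ (W1 @ x)
W1 = torch.randn(d, d, dtype=torch.float64)
W2 = torch.randn(d, d, dtype=torch.float64)

# random orthogonal rotation (a stand-in for a Hadamard-based transform)
Q, _ = torch.linalg.qr(torch.randn(d, d, dtype=torch.float64))

W1_fused = Q @ W1    # rotate the hidden representation on the way out of W1
W2_fused = W2 @ Q.T  # undo the rotation on the way into W2 (Q.T == Q^-1)

y_ref = x @ W1.T @ W2.T
y_fused = x @ W1_fused.T @ W2_fused.T
assert torch.allclose(y_ref, y_fused)  # identical output, different weight statistics
```

As the docstring notes, the quantization benefit comes from the rotated weights and activations occupying a smaller dynamic range than the originals, so the scales required for quantization cover a narrower range.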