diff --git a/examples/quantization_w8a8_fp8/fp8_block_example.py b/examples/quantization_w8a8_fp8/fp8_block_example.py index e977110ad..03a4c0bd6 100644 --- a/examples/quantization_w8a8_fp8/fp8_block_example.py +++ b/examples/quantization_w8a8_fp8/fp8_block_example.py @@ -16,7 +16,9 @@ # * quantize the weights to fp8 with per channel via ptq # * quantize the activations to fp8 with dynamic per token recipe = QuantizationModifier( - targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head", "re:.*mlp.gate$"], + targets="Linear", + scheme="FP8_BLOCK", + ignore=["lm_head", "re:.*mlp.gate$"], ) # Apply quantization. diff --git a/examples/transform/quip_example.py b/examples/transform/quip_example.py new file mode 100644 index 000000000..61a46866b --- /dev/null +++ b/examples/transform/quip_example.py @@ -0,0 +1,43 @@ +""" +NOTE: Models produced by this example will not be runnable in vLLM without +the following changes: https://github.com/vllm-project/vllm/pull/22486 +""" + +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.modifiers.transform import QuIPModifier +from llmcompressor.utils import dispatch_for_generation + +# Select model and load it. +# NOTE: because the datafree pipeline is being used in this +# example, you can use additional GPUs to support larger models +MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct" +model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# Configure the quantization algorithm to run. +# * apply spinquant transforms to model in order to make quantization easier +# * quantize the weights to 4 bit with GPTQ with a group size 128 +recipe = [ + QuIPModifier(transform_type="random-hadamard"), + QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), +] + +# Apply algorithms. +oneshot(model=model, recipe=recipe, pipeline="datafree") + +# Confirm generations of the quantized model look sane. +print("\n\n") +print("========== SAMPLE GENERATION ==============") +dispatch_for_generation(model) +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(tokenizer.decode(output[0])) +print("==========================================\n\n") + +# Save to disk compressed. +SAVE_DIR = MODEL_ID.split("/")[1] + "-quip-w4a16" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py index 9956d0340..eaa714183 100644 --- a/src/llmcompressor/modifiers/transform/__init__.py +++ b/src/llmcompressor/modifiers/transform/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa +from .quip import QuIPModifier from .spinquant import SpinQuantModifier diff --git a/src/llmcompressor/modifiers/transform/quip/__init__.py b/src/llmcompressor/modifiers/transform/quip/__init__.py new file mode 100644 index 000000000..8bdc93d14 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/quip/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .base import * diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py new file mode 100644 index 000000000..320ab6df0 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/quip/base.py @@ -0,0 +1,141 @@ +from typing import List, Literal, Optional, Union + +import torch +from compressed_tensors.transform import ( + TransformArgs, + TransformConfig, + TransformScheme, + apply_transform_config, +) +from compressed_tensors.utils import TorchDtype +from pydantic import Field, ValidationInfo, field_validator + +from llmcompressor.core import Event, EventType, State +from llmcompressor.modifiers import Modifier + +__all__ = ["QuIPModifier"] + + +class QuIPModifier(Modifier): + """ + Implements the transforms according to + [QuIP#: Even Better LLM Quantization with Hadamard Incoherence and Lattice Codebooks](https://arxiv.org/pdf/2402.04396) + [QuIP: 2-Bit Quantization of Large Language Models With Guarantees](https://arxiv.org/abs/2307.13304) + + Transforms (rotations) are extra layers added to a model which reduce the accuracy + loss induced by quantization. This is achieved through "rotating" weights and + activations into a space with a smaller dynamic range of values, thus decreasing + the range of scales required for quantization. + + QuIP and QuIP# apply transforms to every linear layer, two of which are fused into + the model weights and two of which remain as online rotations computed at runtime. + + :param transform_type: The type of transform to apply to the model. + `"hadamard"` has the least performance cost but only supports sizes which are + powers of power of two. + `"random-hadamard"` has more performance cost, but supports a much larger set of + sizes. + `"random-matrix"` has the greatest performance cost, but supports any size + :param randomize: If true, create distinct transforms for each application + :param learnable: If true, attach gradients to transform weights for training + :param precision: Precision at which all transforms should be applied. This applies + to both weight fusing and online rotations + :param ignore: Modules to ignore when attaching transforms + :param transform_config: Optional transform config for overriding provided arguments + """ # noqa: E501 + + transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( + default="random-hadamard" + ) + targets: Union[List[str], str] = Field(default="str") + randomize: bool = Field(default=False) + learnable: bool = Field(default=False) + precision: TorchDtype = Field(default=torch.float64) + ignore: Union[str, List[str]] = Field(default="lm_head") + + # optional override for more fine-grained control + # also included in recipe serialization + transform_config: Optional[TransformConfig] = Field(default=None, repr=False) + + @field_validator("randomize", "learnable", mode="before") + def validate_not_implemented(cls, value, info: ValidationInfo): + if value: + raise NotImplementedError(f"{info.field_name} is not supported right now") + return value + + def on_initialize(self, state: State, **kwargs) -> bool: + if self.transform_config is not None: + return True + + self.transform_config = self._create_config() + return True + + def on_start(self, state: State, event: Event, **kwargs): + self.started_ = True + + apply_transform_config(state.model, self.transform_config) + + def on_event(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.CALIBRATION_EPOCH_START: + if not self.started_: + self.on_start(state, None) + + elif event.type_ == EventType.SEQUENTIAL_EPOCH_END: + pass + + elif event.type_ == EventType.CALIBRATION_EPOCH_END: + if not self.ended_: + self.on_end(state, None) + + def on_end(self, state: State, event: Event, **kwargs): + self.ended_ = True + + def on_finalize(self, state: State, **kwargs) -> bool: + if not self.ended_: + self.on_end(state, None) + + return True + + def _create_config(self) -> TransformConfig: + return TransformConfig( + config_groups={ + "v": TransformScheme( + type=self.transform_type, + apply=[ + TransformArgs( + targets=self.targets, + location="input", # non-mergable + ignore=self.ignore, + ), + TransformArgs( + targets=self.targets, + location="weight_input", + inverse=True, + ignore=self.ignore, + ), + ], + randomize=self.randomize, + requires_grad=self.learnable, + precision=self.precision, + ), + "u": TransformScheme( + type=self.transform_type, + apply=[ + TransformArgs( + targets=self.targets, + location="weight_output", + ignore=self.ignore, + ), + TransformArgs( + targets=self.targets, + location="output", # non-mergable + inverse=True, + ignore=self.ignore, + ), + ], + randomize=self.randomize, + requires_grad=self.learnable, + precision=self.precision, + ), + } + ) diff --git a/tests/llmcompressor/modifiers/transform/test_correctness.py b/tests/llmcompressor/modifiers/transform/test_correctness.py index 783338eed..010c7d755 100644 --- a/tests/llmcompressor/modifiers/transform/test_correctness.py +++ b/tests/llmcompressor/modifiers/transform/test_correctness.py @@ -5,7 +5,7 @@ from transformers import AutoModelForCausalLM from llmcompressor.core import State -from llmcompressor.modifiers.transform import SpinQuantModifier +from llmcompressor.modifiers.transform import QuIPModifier, SpinQuantModifier from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( untie_word_embeddings, ) @@ -20,10 +20,10 @@ @pytest.mark.parametrize( "modifier,model_dtype,precision,exp_mse", [ - # (QuIPModifier, torch.bfloat16, torch.bfloat16, 5e-3), # 0.0019 - # (QuIPModifier, torch.bfloat16, torch.float32, 5e-3), # 0.0022 - # (QuIPModifier, torch.float32, torch.float32, 5e-10), # 1.0e-10 - # (QuIPModifier, torch.float32, torch.float64, 5e-11), # 2.7e-11 + (QuIPModifier, torch.bfloat16, torch.bfloat16, 5e-3), # 0.0019 + (QuIPModifier, torch.bfloat16, torch.float32, 5e-3), # 0.0022 + (QuIPModifier, torch.float32, torch.float32, 5e-10), # 1.0e-10 + (QuIPModifier, torch.float32, torch.float64, 5e-11), # 2.7e-11 (SpinQuantModifier, torch.bfloat16, torch.bfloat16, 5e-3), # 0.0030 (SpinQuantModifier, torch.bfloat16, torch.float32, 5e-3), # 0.0029 (SpinQuantModifier, torch.float32, torch.float32, 5e-4), # 4e-4 diff --git a/tests/llmcompressor/modifiers/transform/test_serialization.py b/tests/llmcompressor/modifiers/transform/test_serialization.py index 2a2e8602d..b4be96c0d 100644 --- a/tests/llmcompressor/modifiers/transform/test_serialization.py +++ b/tests/llmcompressor/modifiers/transform/test_serialization.py @@ -1,9 +1,9 @@ import pytest -from llmcompressor.modifiers.transform import SpinQuantModifier +from llmcompressor.modifiers.transform import QuIPModifier, SpinQuantModifier -@pytest.mark.parametrize("modifier", [SpinQuantModifier]) +@pytest.mark.parametrize("modifier", [SpinQuantModifier, QuIPModifier]) def test_reload(modifier): instance = modifier(transform_type="hadamard") dump = instance.model_dump()