diff --git a/compress_model.py b/compress_model.py new file mode 100644 index 000000000..fa67bead0 --- /dev/null +++ b/compress_model.py @@ -0,0 +1,60 @@ +# python3 compress_model.py --model_id meta-llama/Llama-3.2-1B-Instruct --transform_type random-hadamard +import argparse +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.modifiers.transform import SpinQuantModifier +from llmcompressor.utils import dispatch_for_generation + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_id", type=str, help="Model stub to compress") + parser.add_argument("--transform_type", type=str, default=None, help="Type of transform used in SpinQuantModifier") + parser.add_argument("--scheme", type=str, default=None, help="Quantization scheme (e.g. W4A16)") + return parser.parse_args() + +if __name__ == "__main__": + args = parse_args() + + # Select model and load it. + MODEL_ID = args.model_id + model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto") + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + + # Select number of samples. 512 samples is a good place to start. + # Increasing the number of samples can improve accuracy. + NUM_CALIBRATION_SAMPLES = 512 + MAX_SEQUENCE_LENGTH = 2048 + + # Configure the quantization algorithm to run. + recipe = [] + if args.transform_type: + recipe.append(SpinQuantModifier(rotations=["R1", "R2"], transform_type=args.transform_type)) + + if args.scheme: + recipe.append(QuantizationModifier(targets="Linear", scheme=args.scheme, ignore=["lm_head"])) + + # Apply algorithms. + oneshot( + model=model, + recipe=recipe, + dataset="ultrachat_200k", + splits={"calibration": f"train_sft[:{NUM_CALIBRATION_SAMPLES}]"}, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + ) + + # Confirm generations of the quantized model look sane. + print("\n\n") + print("========== SAMPLE GENERATION ==============") + dispatch_for_generation(model) + input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") + output = model.generate(input_ids, max_new_tokens=100) + print(tokenizer.decode(output[0])) + print("==========================================\n\n") + + # Save to disk compressed. + SAVE_DIR = MODEL_ID.split("/")[1] + f"-{args.transform_type}-{args.scheme}" + model.save_pretrained(SAVE_DIR, save_compressed=True) + tokenizer.save_pretrained(SAVE_DIR) diff --git a/examples/transform/spinquant_example.py b/examples/transform/spinquant_example.py new file mode 100644 index 000000000..6671af923 --- /dev/null +++ b/examples/transform/spinquant_example.py @@ -0,0 +1,91 @@ +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.modifiers.transform import SpinQuantModifier +from llmcompressor.utils import dispatch_for_generation + +# Select model and load it. +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" + +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, + torch_dtype="auto", +) +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, attn_implementation="eager") + +# Select calibration dataset. +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" + +# Select number of samples. 512 samples is a good place to start. 
+# Increasing the number of samples can improve accuracy. +NUM_CALIBRATION_SAMPLES = 512 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") +ds = ds.shuffle(seed=42) + + +def preprocess(example): + return { + "text": tokenizer.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return tokenizer( + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# Configure the quantization algorithm to run. +# * apply spinquant transforms to model in order to make quantization easier +# * quantize the weights to 4 bit with GPTQ with a group size 128 +recipe = [ + SpinQuantModifier( + rotations=["R1", "R2", "R3", "R4"], transform_type="random-hadamard" + ), + # QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]), +] + +# Apply algorithms. +oneshot( + model=model, + recipe=recipe, + dataset=ds, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, +) + +# Confirm generations of the quantized model look sane. +print("\n\n") +print("========== SAMPLE GENERATION ==============") +dispatch_for_generation(model) +from llmcompressor.utils import calibration_forward_context + +with calibration_forward_context(model): + input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") + output = model.generate(input_ids, max_new_tokens=100) + print(tokenizer.decode(output[0])) +print("==========================================\n\n") + +# Save to disk compressed. +SAVE_DIR = MODEL_ID.split("/")[1] + "-transformed-w4a16" +model.save_pretrained(SAVE_DIR, save_compressed=True) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/entrypoints/utils.py b/src/llmcompressor/entrypoints/utils.py index 5647e4d06..95ec832fb 100644 --- a/src/llmcompressor/entrypoints/utils.py +++ b/src/llmcompressor/entrypoints/utils.py @@ -20,7 +20,7 @@ from llmcompressor.pytorch.model_load.helpers import parse_dtype from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( modify_save_pretrained, - patch_tied_tensors_bug, + untie_word_embeddings, ) from llmcompressor.transformers.utils.helpers import ( detect_last_checkpoint, @@ -61,7 +61,8 @@ def pre_process(model_args: "ModelArguments"): ) # untie tie_word_embeddings weights - patch_tied_tensors_bug(model_args.model) + if not model_args.tie_word_embeddings: + untie_word_embeddings(model_args.model) # wrap model.save_pretrained modify_save_pretrained(model_args.model) @@ -143,7 +144,6 @@ def initialize_model_from_path( cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, - tie_word_embeddings=model_args.tie_word_embeddings, trust_remote_code=model_args.trust_remote_code_model, ) @@ -156,7 +156,6 @@ def initialize_model_from_path( AutoConfig.from_pretrained( model_args.distill_teacher, use_auth_token=True if model_args.use_auth_token else None, - tie_word_embeddings=model_args.tie_word_embeddings, trust_remote_code=model_args.trust_remote_code_model, ) if model_args.distill_teacher diff --git a/src/llmcompressor/modeling/__init__.py b/src/llmcompressor/modeling/__init__.py index e2c22ed1f..76b6b0391 100644 --- a/src/llmcompressor/modeling/__init__.py +++ 
b/src/llmcompressor/modeling/__init__.py
@@ -1,3 +1,4 @@
 # flake8: noqa
 
+from .fuse import *
 from .prepare import *
diff --git a/src/llmcompressor/modeling/fuse.py b/src/llmcompressor/modeling/fuse.py
new file mode 100644
index 000000000..33e91601c
--- /dev/null
+++ b/src/llmcompressor/modeling/fuse.py
@@ -0,0 +1,58 @@
+from typing import Iterable
+
+import torch
+from compressed_tensors import (
+    align_module_device,
+    get_execution_device,
+    update_offload_parameter,
+)
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
+
+__all__ = ["normalize_embedding", "fuse_norm_linears"]
+
+
+PRECISION = torch.float64
+
+
+def normalize_embedding(embedding: torch.nn.Module):
+    if isinstance(embedding, torch.nn.Embedding):
+        with align_module_device(embedding):
+            weight_dtype = embedding.weight.dtype
+            weight = embedding.weight.to(PRECISION)
+            new_weight = weight - weight.mean(dim=-1, keepdim=True)
+            new_weight = new_weight.to(weight_dtype)
+
+        update_offload_parameter(embedding, "weight", new_weight)
+
+    else:
+        raise ValueError(f"Cannot normalize embedding of type {type(embedding)}")
+
+
+def fuse_norm_linears(norm: torch.nn.Module, linears: Iterable[torch.nn.Linear]):
+    """
+    Fuse a norm layer into subsequent linear layers. This is useful for ensuring
+    transform invariance between norm and linear layers.
+
+    Note that a model cannot be properly trained after its norms have been fused
+
+    :param norm: norm layer whose weight will be fused into subsequent linears
+    :param linears: linear layers which directly follow the norm layer
+    """
+    if isinstance(norm, (torch.nn.RMSNorm, LlamaRMSNorm)):
+        for linear in linears:
+            # NOTE: spinquant does this op in float64
+            exec_device = get_execution_device(norm)
+            with align_module_device(norm, exec_device), align_module_device(
+                linear, exec_device
+            ):
+                weight_dtype = linear.weight.dtype
+                new_weight = linear.weight.to(PRECISION) * norm.weight.to(PRECISION)
+                new_weight = new_weight.to(weight_dtype)
+
+            update_offload_parameter(linear, "weight", new_weight)
+
+        new_norm_weight = torch.ones_like(norm.weight, device="cpu")
+        update_offload_parameter(norm, "weight", new_norm_weight)
+
+    else:
+        raise ValueError(f"Cannot fuse norm of type {type(norm)}")
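A note for reviewers on why `fuse_norm_linears` is safe: an RMSNorm's elementwise scale acts on the linear's input dimension, so folding it into the linear's columns and resetting the norm weight to ones leaves the composed computation unchanged; `fuse.py` does the multiplication in float64 to keep the fusion error negligible relative to the model dtype. Below is a minimal, self-contained check of that identity using plain `torch` modules only; it is illustrative and does not use the llm-compressor helpers.

```python
# Minimal sketch of the identity behind fuse_norm_linears:
# Linear(W) @ RMSNorm(g)(x) == Linear(W * g) @ RMSNorm(1)(x)
import torch

torch.manual_seed(0)
hidden, out = 64, 32
x = torch.randn(2, hidden, dtype=torch.float64)

# RMSNorm with a non-trivial elementwise scale, followed by a linear layer
norm = torch.nn.RMSNorm(hidden, dtype=torch.float64)
linear = torch.nn.Linear(hidden, out, bias=False, dtype=torch.float64)
with torch.no_grad():
    norm.weight.copy_(torch.rand(hidden, dtype=torch.float64) + 0.5)

reference = linear(norm(x))

# "fuse": fold the norm scale into the linear's input dim, reset the norm to ones
with torch.no_grad():
    linear.weight.mul_(norm.weight)  # broadcasts across columns (the input dim)
    norm.weight.fill_(1.0)

assert torch.allclose(reference, linear(norm(x)), atol=1e-12)
print("norm fusing preserves the linear output")
```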
diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index b10a4cb31..fe824695e 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -1,4 +1,5 @@
-from typing import Any, Dict, Optional, Tuple
+from functools import partial
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Tuple
 
 import torch
 from compressed_tensors.quantization import (
@@ -13,11 +14,18 @@ from compressed_tensors.utils import align_module_device, update_parameter_data
 from loguru import logger
 from torch.nn import Module
+from torch.utils.hooks import RemovableHandle
 
 from llmcompressor.modifiers.quantization.cache import QuantizedKVParameterCache
 from llmcompressor.observers import Observer
 from llmcompressor.utils.helpers import getattr_chain
 
+if TYPE_CHECKING:
+    from compressed_tensors.modeling.attention import CompressedAttentionImpl
+
+    from llmcompressor.modifiers.utils.hooks import HooksMixin
+
+
 DEFAULT_MAXSHRINK = 0.20
 DEFAULT_PATIENCE = 5
 DEFAULT_AVERAGING_CONSTANT = 0.01
@@ -25,6 +33,7 @@ DEFAULT_NORM = 2.4
 
 __all__ = [
+    "register_calibrate_attn_hooks",
     "initialize_observer",
     "update_weight_zp_scale",
     "calibrate_input_hook",
@@ -205,14 +214,30 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
     )
 
 
-def calibrate_input_hook(module: Module, args: Any):
+def register_calibrate_attn_hooks(
+    modifier: "HooksMixin", attention_impl: "CompressedAttentionImpl"
+) -> Set[RemovableHandle]:
+    return {
+        modifier.register_hook(
+            attention_impl, partial(calibrate_input_hook, base_name="q"), "query"
+        ),
+        modifier.register_hook(
+            attention_impl, partial(calibrate_input_hook, base_name="k"), "key"
+        ),
+        modifier.register_hook(
+            attention_impl, partial(calibrate_input_hook, base_name="v"), "value"
+        ),
+    }
+
+
+def calibrate_input_hook(module: Module, args: Any, base_name: str = "input"):
     """
     Hook to calibrate input activations.
     Will call the observers to update the scales/zp before applying
     input QDQ in the module's forward pass.
     """
     args = args[0] if isinstance(args, tuple) else args
-    calibrate_activations(module, value=args, base_name="input")
+    calibrate_activations(module, value=args, base_name=base_name)
 
 
 def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
@@ -282,6 +307,14 @@ def initialize_quantized_kv_cache(module: Module):
     setattr(module, "kv_cache", quantized_kv_cache)
 
 
+def initialize_attention_observers(module: Module):
+    input_args = getattr_chain(module, "quantization_scheme.input_activations", None)
+    if input_args is not None:
+        initialize_observer(module, "q", input_args)
+        initialize_observer(module, "k", input_args)
+        initialize_observer(module, "v", input_args)
+
+
 def apply_calibration_status(module: Module):
     scheme = getattr(module, "quantization_scheme", None)
     if not scheme:
diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
index d193d85a1..7c7a41033 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
@@ -232,6 +232,7 @@ def _initialize_observers(self, module: torch.nn.Module):
         # kv_cache activations. Within `apply_quantization_config`, the config is
         # modified to use attention output quantization if a kv_cache_scheme exists
         if is_attention and output:
+            # initialize_attention_observers(module)  # TODO: attnq
             initialize_quantized_kv_cache(module)
 
         # output activations
@@ -240,6 +241,7 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()
+
         for module in model.modules():
             if not hasattr(module, "quantization_scheme"):
                 continue
@@ -258,6 +260,11 @@ def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
                     self.register_hook(module, calibrate_input_hook, "forward_pre")
                 )
 
+            # TODO: attnq
+            # if is_attention:
+            #     attention_impl = CompressedAttentionImpl.from_module(module)
+            #     hooks |= register_calibrate_attn_hooks(self, attention_impl)
+
             # kv_cache activations. Within `apply_quantization_config`, the config is
             # modified to use attention output quantization if a kv_cache_scheme exists
             if is_attention and output:
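The attention-calibration plumbing above generalizes `calibrate_input_hook` to take a `base_name`, then binds `"q"`/`"k"`/`"v"` with `functools.partial` when registering hooks on the attention implementation. A small standalone sketch of that pattern, using plain `torch` hooks and illustrative names rather than the llm-compressor registration path:

```python
# Illustration only: one generic pre-forward hook, specialized per stream via partial.
# The production code routes through HooksMixin.register_hook and observer calibration.
from functools import partial
from typing import Any

import torch


def calibrate_input_hook(module: torch.nn.Module, args: Any, base_name: str = "input"):
    # stand-in for the real hook body, which calls calibrate_activations()
    value = args[0] if isinstance(args, tuple) else args
    print(f"would calibrate '{base_name}' observer with shape {tuple(value.shape)}")


attn_stub = torch.nn.Linear(8, 8)  # stands in for the attention implementation
for name in ("q", "k", "v"):
    # same hook body, bound to a different observer base name per stream;
    # the real code registers these on separate "query"/"key"/"value" call sites
    attn_stub.register_forward_pre_hook(partial(calibrate_input_hook, base_name=name))

attn_stub(torch.randn(2, 8))  # triggers one calibration message per bound hook
```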
diff --git a/src/llmcompressor/modifiers/transform/__init__.py b/src/llmcompressor/modifiers/transform/__init__.py
new file mode 100644
index 000000000..9956d0340
--- /dev/null
+++ b/src/llmcompressor/modifiers/transform/__init__.py
@@ -0,0 +1,3 @@
+# flake8: noqa
+
+from .spinquant import SpinQuantModifier
diff --git a/src/llmcompressor/modifiers/transform/spinquant/__init__.py b/src/llmcompressor/modifiers/transform/spinquant/__init__.py
new file mode 100644
index 000000000..8bdc93d14
--- /dev/null
+++ b/src/llmcompressor/modifiers/transform/spinquant/__init__.py
@@ -0,0 +1,3 @@
+# flake8: noqa
+
+from .base import *
diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py
new file mode 100644
index 000000000..bd78525d3
--- /dev/null
+++ b/src/llmcompressor/modifiers/transform/spinquant/base.py
@@ -0,0 +1,260 @@
+from enum import Enum
+from typing import Iterable, List, Literal, Optional
+
+from compressed_tensors import match_modules_set, match_named_modules
+from compressed_tensors.transform import (
+    TransformArgs,
+    TransformConfig,
+    TransformScheme,
+    apply_transform_config,
+)
+from pydantic import Field, ValidationInfo, field_validator
+from transformers import PreTrainedModel
+
+from llmcompressor.core import Event, EventType, State
+from llmcompressor.modeling import fuse_norm_linears, normalize_embedding
+from llmcompressor.modifiers import Modifier
+
+from .mappings import SpinQuantMapping, infer_mapping_from_model
+from .norm_mappings import NormMapping, infer_norm_mapping_from_model
+
+
+class SpinquantRotation(str, Enum):
+    R1 = "R1"
+    R2 = "R2"
+    R3 = "R3"
+    R4 = "R4"
+
+
+class SpinQuantModifier(Modifier, use_enum_values=True):
+    """
+    Implements the transforms according to
+    [SpinQuant: LLM quantization with learned rotations](https://arxiv.org/abs/2405.16406) # noqa: E501
+
+    Transforms (rotations) are extra layers added to a model which reduce the accuracy
+    loss induced by quantization. This is achieved through "rotating" weights and
+    activations into a space with a smaller dynamic range of values, thus decreasing
+    the range of scales required for quantization.
+
+    The SpinQuant authors describe four different rotations which can be applied to a
+    model. R1 and R2 are "offline" rotations, meaning that they can be fused into
+    existing weights and therefore do not induce runtime cost. R3 and R4 are "online"
+    rotations, meaning that they require additional computation at runtime.
+
+    :param rotations: A list containing the names of rotations to apply to the model.
+        Possible rotations include R1, R2, R3, and R4
+    :param transform_type: The type of transform to apply to the model.
+        `"hadamard"` has the least performance cost but only supports sizes which are
+        powers of two.
+        `"random-hadamard"` has more performance cost, but supports a larger set of
+        sizes.
+        `"random-matrix"` has the greatest performance cost, but supports any size
+    :param randomize: if True, create distinct transforms for each application
+    :param learnable: if True, attach gradients to transform weights for training
+    :param mappings: Specifies layers within a model to target for transforms.
+        A mapping will be inferred if None is provided
+    :param norm_mappings: Specifies layers within a model to target for norm fusing.
+ A mapping will be inferred if None is provided + :param transform_config: Optional transform config for overriding provided arguments + """ + + rotations: List[SpinquantRotation] = Field( + default_factory=lambda: ["R1", "R2"], exclude=True + ) + transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field( + default="hadamard", exclude=True + ) + randomize: bool = Field(default=False, exclude=True) + learnable: bool = Field(default=False, exclude=True) + + # norm mappings separate from spinquant mappings to allow users to + # override spinquant mappings with transform_config without overriding norms + mappings: Optional[SpinQuantMapping] = Field( + default=None, + repr=False, + exclude=True, + ) + norm_mappings: Optional[List[NormMapping]] = Field( + default=None, + repr=False, + exclude=True, + ) + + # optional override for more fine-grained control + # also included in recipe serialization + transform_config: Optional[TransformConfig] = Field(default=None, repr=False) + + @field_validator("randomize", "learnable", mode="before") + def validate_not_implemented(cls, value, info: ValidationInfo): + raise NotImplementedError(f"{info.field_name} is not supported right now") + + @field_validator("rotations", mode="before") + def validate_rotations(cls, value): + if isinstance(value, Iterable): + return tuple(v.upper() for v in value) + return value + + def on_initialize(self, state: State, **kwargs) -> bool: + if self.transform_config is not None: + return True + + self.mappings = infer_mapping_from_model(state.model) + self.norm_mappings = infer_norm_mapping_from_model(state.model) + + config_groups = {} + if SpinquantRotation.R1 in self.rotations: + config_groups["R1"] = self._create_r1_scheme() + + if SpinquantRotation.R2 in self.rotations: + config_groups["R2"] = self._create_r2_scheme(state.model) + + if SpinquantRotation.R3 in self.rotations: + config_groups["R3"] = self._create_r3_scheme(state.model) + + if SpinquantRotation.R4 in self.rotations: + config_groups["R4"] = self._create_r4_scheme() + + self.transform_config = TransformConfig(config_groups=config_groups) + + return True + + def on_start(self, state: State, event: Event, **kwargs): + self.started_ = True + + # needs to happen after the model has been hooked to execute on the GPU + # otherwise we're applying weight transforms on CPU + self._prenormalize_embeddings(state.model) + self._fuse_norms(state.model) + apply_transform_config(state.model, self.transform_config) + + def on_event(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.CALIBRATION_EPOCH_START: + if not self.started_: + self.on_start(state, None) + + elif event.type_ == EventType.SEQUENTIAL_EPOCH_END: + pass + + elif event.type_ == EventType.CALIBRATION_EPOCH_END: + if not self.ended_: + self.on_end(state, None) + + def on_end(self, state: State, event: Event, **kwargs): + self.ended_ = True + + def on_finalize(self, state: State, **kwargs) -> bool: + if not self.ended_: + self.on_end(state, None) + + return True + + def _prenormalize_embeddings(self, model: PreTrainedModel): + for _, embedding in match_named_modules( + model, [self.mappings.embedding], warn_on_fail=True + ): + normalize_embedding(embedding) + + def _fuse_norms(self, model: PreTrainedModel): + for mapping in self.norm_mappings: + for norm, *linears in match_modules_set( + model, (mapping.norm, *mapping.linears) + ): + fuse_norm_linears(norm, linears) + + def _create_r1_scheme(self) -> TransformScheme: + return TransformScheme( + 
type=self.transform_type, + randomize=self.randomize, + requires_grad=self.learnable, + apply=[ + TransformArgs( + targets=[ + self.mappings.embedding, + self.mappings.attn_o, + *self.mappings.mlp_out, + ], + location="weight_output", + ), + TransformArgs( + targets=[ + self.mappings.attn_q, + self.mappings.attn_k, + self.mappings.attn_v, + *self.mappings.mlp_in, + self.mappings.lm_head, + ], + location="weight_input", + inverse=True, + ), + ], + ) + + def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme: + config = model.config + + if hasattr(config, "head_dim"): + head_dim = config.head_dim + elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"): + head_dim = config.hidden_size // config.num_attention_heads + else: + raise NotImplementedError() + + return TransformScheme( + type=self.transform_type, + randomize=self.randomize, + requires_grad=self.learnable, + head_dim=head_dim, + apply=[ + TransformArgs(targets=[self.mappings.attn_v], location="weight_output"), + TransformArgs( + targets=[self.mappings.attn_o], + location="weight_input", + inverse=True, + ), + ], + ) + + def _create_r3_scheme(self, model: PreTrainedModel) -> TransformScheme: + config = model.config + + if hasattr(config, "head_dim"): + head_dim = config.head_dim + elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"): + head_dim = config.hidden_size // config.num_attention_heads + else: + raise NotImplementedError() + + return TransformScheme( + type=self.transform_type, + randomize=self.randomize, + requires_grad=self.learnable, + head_dim=head_dim, + apply=[ + TransformArgs( + targets=[self.mappings.attn], + location="attn_q", + ), + TransformArgs( + targets=[self.mappings.attn], + location="attn_k", + ), + ], + ) + + def _create_r4_scheme(self) -> TransformScheme: + return TransformScheme( + type=self.transform_type, + randomize=self.randomize, + requires_grad=self.learnable, + apply=[ + TransformArgs( + targets=[*self.mappings.mlp_out], + location="input", + ), + TransformArgs( + targets=[*self.mappings.mlp_out], + location="weight_input", + inverse=True, + ), + ], + ) diff --git a/src/llmcompressor/modifiers/transform/spinquant/mappings.py b/src/llmcompressor/modifiers/transform/spinquant/mappings.py new file mode 100644 index 000000000..36102b975 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/spinquant/mappings.py @@ -0,0 +1,59 @@ +from typing import Dict, List, Optional + +from loguru import logger +from pydantic import BaseModel, Field, field_validator +from transformers import PreTrainedModel + +__all__ = ["SpinQuantMapping", "infer_mapping_from_model"] + + +class SpinQuantMapping(BaseModel): + embedding: str + + attn: str + attn_q: str + attn_k: str + attn_v: str + attn_o: str + attn_head_dim: Optional[int] = Field(default=None) + + mlp_in: List[str] # up_proj, gate_proj + mlp_out: List[str] # down_proj + + lm_head: str + + @field_validator("mlp_in", "mlp_out", mode="before") + def cast_to_list(cls, value): + if isinstance(value, str): + return [value] + + return value + + +_default_mappings = SpinQuantMapping( + embedding="re:.*embed_tokens$", + attn="re:.*self_attn$", + attn_q="re:.*q_proj$", + attn_k="re:.*k_proj$", + attn_v="re:.*v_proj$", + attn_o="re:.*o_proj$", + mlp_in=["re:.*up_proj$", "re:.*gate_proj$"], + mlp_out="re:.*down_proj$", + lm_head="lm_head", +) + + +SPINQUANT_MAPPING_REGISTRY: Dict[str, SpinQuantMapping] = { + "LlamaForCausalLM": _default_mappings, +} + + +def infer_mapping_from_model(model: PreTrainedModel) -> 
SpinQuantMapping: + architecture = model.__class__.__name__ + if architecture not in SPINQUANT_MAPPING_REGISTRY: + logger.info( + f"Unrecognized model architecture {architecture}. " + "Falling back to default mappings" + ) + + return SPINQUANT_MAPPING_REGISTRY.get(architecture, _default_mappings) diff --git a/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py new file mode 100644 index 000000000..0752f6986 --- /dev/null +++ b/src/llmcompressor/modifiers/transform/spinquant/norm_mappings.py @@ -0,0 +1,50 @@ +from typing import Dict, List + +from loguru import logger +from pydantic import BaseModel, field_validator +from transformers import PreTrainedModel + +__all__ = ["infer_norm_mapping_from_model"] + + +class NormMapping(BaseModel): + norm: str + linears: List[str] + + @field_validator("linears", mode="before") + def cast_to_list(cls, value): + if isinstance(value, str): + return [value] + + return value + + +_default_mappings = [ + NormMapping( + norm="re:.*input_layernorm$", + linears=["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"], + ), + NormMapping( + norm="re:.*post_attention_layernorm$", + linears=["re:.*up_proj$", "re:.*gate_proj$"], + ), + NormMapping( + norm="model.norm", + linears=["lm_head"], + ), +] + +NORM_MAPPING_REGISTRY: Dict[str, NormMapping] = { + "LlamaForCausalLM": _default_mappings, +} + + +def infer_norm_mapping_from_model(model: PreTrainedModel) -> List[NormMapping]: + architecture = model.__class__.__name__ + if architecture not in NORM_MAPPING_REGISTRY: + logger.info( + f"Unrecognized model architecture {architecture}. " + "Falling back to default mappings" + ) + + return NORM_MAPPING_REGISTRY.get(architecture, _default_mappings) diff --git a/src/llmcompressor/pipelines/data_free/pipeline.py b/src/llmcompressor/pipelines/data_free/pipeline.py index 587f7ca69..7ad6d56dc 100644 --- a/src/llmcompressor/pipelines/data_free/pipeline.py +++ b/src/llmcompressor/pipelines/data_free/pipeline.py @@ -5,6 +5,7 @@ from llmcompressor.core.session_functions import LifecycleCallbacks from llmcompressor.pipelines.registry import CalibrationPipeline +from llmcompressor.utils.dev import dispatch_for_generation if TYPE_CHECKING: from llmcompressor.args.dataset_arguments import DatasetArguments @@ -27,5 +28,9 @@ def __call__( :param dataloader: loads data for calibration :param dataset_args: dataset arguments relevant to pipelines """ + # some ops are still performed on the model by modifiers + # we want those ops to occur on the GPU + dispatch_for_generation(model) + LifecycleCallbacks.calibration_epoch_start() LifecycleCallbacks.calibration_epoch_end() diff --git a/src/llmcompressor/pipelines/registry.py b/src/llmcompressor/pipelines/registry.py index 2c1a54cf5..67d510d13 100644 --- a/src/llmcompressor/pipelines/registry.py +++ b/src/llmcompressor/pipelines/registry.py @@ -8,6 +8,7 @@ from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.quantization import QuantizationModifier +from llmcompressor.modifiers.transform import SpinQuantModifier if TYPE_CHECKING: from llmcompressor.args.dataset_arguments import DatasetArguments @@ -61,4 +62,8 @@ def _infer_pipeline(modifiers: List[Modifier]) -> str: if not config.requires_calibration_data(): return "datafree" + # TODO: Remove hardcode + if len(modifiers) == 1 and isinstance(modifiers[0], SpinQuantModifier): + return "datafree" + return "sequential" diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py 
b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index 69b0e3f28..1495f6d06 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -9,11 +9,11 @@ CompressionFormat, ModelCompressor, SparsityCompressionConfig, + delete_offload_parameter, is_module_offloaded, - update_offload_parameter, + register_offload_parameter, ) from loguru import logger -from safetensors.torch import storage_ptr from transformers import PreTrainedModel from llmcompressor.core import active_session @@ -27,7 +27,7 @@ from llmcompressor.transformers.utils import RECIPE_FILE_NAME from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path -__all__ = ["modify_save_pretrained"] +__all__ = ["modify_save_pretrained", "untie_word_embeddings"] def modify_save_pretrained(model: PreTrainedModel): @@ -120,7 +120,7 @@ def save_pretrained_wrapper( model.save_pretrained = save_pretrained_compressed(model.save_pretrained) -def patch_tied_tensors_bug(model: torch.nn.Module): +def untie_word_embeddings(model: PreTrainedModel): """ Patches bug where HF transformers will fail to untie weights under specific circumstances (https://github.com/huggingface/transformers/issues/33689). @@ -129,28 +129,27 @@ def patch_tied_tensors_bug(model: torch.nn.Module): :param model: model to fix """ - if ( - hasattr(model.config, "tie_word_embeddings") - and not model.config.tie_word_embeddings - ): - input_embed = model.get_input_embeddings() - output_embed = model.get_output_embeddings() - - if input_embed is None or output_embed is None: - # some models fail to properly override the abstract methods - return - - if storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight): - for module in (input_embed, output_embed): - if not is_module_offloaded(module): - # create new storage ptr for onloaded weight - untied_data = module.weight.data.clone() - module.weight.data = untied_data - else: - # create new storage ptr for offloaded weight - # note `update_offload_parameter` does not create a new storage ptr - untied_data = module._hf_hook.weights_map["weight"].clone() - update_offload_parameter(module, "weight", untied_data) + input_embed = model.get_input_embeddings() + output_embed = model.get_output_embeddings() + + for module in (input_embed, output_embed): + if module is None or not hasattr(module, "weight"): + logger.warning(f"Cannot untie {module} which does not have weight param") + continue + + # this could be replaced by a `get_offloaded_parameter` util + if not is_module_offloaded(module): + untied_data = module.weight.data.clone() + else: + untied_data = module._hf_hook.weights_map["weight"].clone() + + requires_grad = module.weight.requires_grad + new_parameter = torch.nn.Parameter(untied_data, requires_grad=requires_grad) + delete_offload_parameter(module, "weight") + register_offload_parameter(module, "weight", new_parameter) + + if hasattr(model.config, "tie_word_embeddings"): + model.config.tie_word_embeddings = False def get_model_compressor( diff --git a/tests/llmcompressor/modifiers/transform/test_correctness.py b/tests/llmcompressor/modifiers/transform/test_correctness.py new file mode 100644 index 000000000..660bab0ef --- /dev/null +++ b/tests/llmcompressor/modifiers/transform/test_correctness.py @@ -0,0 +1,34 @@ +import pytest +import torch +from compressed_tensors.transform import apply_transform_config +from transformers import AutoModelForCausalLM + +from 
llmcompressor.modifiers.transform.template.quip import QUIP + + +@pytest.mark.parametrize( + "dtype,exp_max,exp_mse", + [ + ( + torch.bfloat16, + 1.1, + 0.012, + ), # constructing and running transforms in float32 can improve to (~0.6562, ~0.0055) # noqa: E501 + (torch.float32, 4e-4, 2e-9), + ], +) +def test_apply_correctness(dtype, exp_max, exp_mse): + model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Meta-Llama-3-8B-Instruct", device_map="cuda", torch_dtype=dtype + ) + + input = {k: v.to("cuda") for k, v in model.dummy_inputs.items()} + with torch.no_grad(): + true_output = model(**input) + + apply_transform_config(model, QUIP) + with torch.no_grad(): + output = model(**input) + + assert torch.max(true_output.logits - output.logits) <= exp_max + assert torch.nn.MSELoss()(output.logits, true_output.logits) <= exp_mse diff --git a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py index 140e706d1..aad551ff8 100644 --- a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py +++ b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py @@ -28,7 +28,7 @@ from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( get_model_compressor, modify_save_pretrained, - patch_tied_tensors_bug, + untie_word_embeddings, ) from tests.testing_utils import requires_gpu @@ -224,8 +224,6 @@ def test_quant_model_reload(format, dtype, tmp_path): shutil.rmtree(tmp_path) -# technically only tie_word_embeddings=False is supported right now -# setting to True is discouraged @pytest.mark.parametrize( "offload,torch_dtype,tie_word_embeddings,device", [ @@ -237,25 +235,23 @@ def test_quant_model_reload(format, dtype, tmp_path): # offloading (True, torch.float16, False, "cpu"), (True, torch.float32, False, "cpu"), - # (True, torch.float16, True, "cpu"), # TODO: fails - # (True, torch.float32, True, "cpu"), # TODO: fails + (True, torch.float16, True, "cpu"), + (True, torch.float32, True, "cpu"), ], ) def test_model_reload(offload, torch_dtype, tie_word_embeddings, device, tmp_path): model_path = "nm-testing/llama2.c-stories15M" save_path = tmp_path / "save_path" - model = AutoModelForCausalLM.from_pretrained( - model_path, - tie_word_embeddings=tie_word_embeddings, - torch_dtype=torch_dtype, - ) + model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype) if offload: model = dispatch_model(model, {"": device}, force_hooks=True) else: model = model.to(device) - patch_tied_tensors_bug(model) + if not tie_word_embeddings: + untie_word_embeddings(model) + modify_save_pretrained(model) model.save_pretrained(save_path, safe_serialization=True) @@ -294,22 +290,18 @@ def test_model_reload_gpu(offload, torch_dtype, tie_word_embeddings, device, tmp (True, torch.float32, True, "cpu"), ], ) -def test_model_shared_tensors( - offload, torch_dtype, tie_word_embeddings, device, tmp_path -): +def test_model_shared_tensors(offload, torch_dtype, tie_word_embeddings, device): # load model - model = AutoModelForCausalLM.from_pretrained( - "nm-testing/llama2.c-stories15M", - torch_dtype=torch_dtype, - tie_word_embeddings=tie_word_embeddings, - ) - patch_tied_tensors_bug(model) - + model_path = "nm-testing/llama2.c-stories15M" + model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype) if offload: model = dispatch_model(model, {"": device}, force_hooks=True) else: model = model.to(device) + if not 
tie_word_embeddings: + untie_word_embeddings(model) + # modify lm head with torch.no_grad(), align_module_device(model.lm_head): update_offload_parameter(model.lm_head, "weight", model.lm_head.weight + 1) @@ -332,12 +324,8 @@ def test_model_shared_tensors( (False, torch.float32, True, "cuda:0"), ], ) -def test_model_shared_tensors_gpu( - offload, torch_dtype, tie_word_embeddings, device, tmp_path -): - test_model_shared_tensors( - offload, torch_dtype, tie_word_embeddings, device, tmp_path - ) +def test_model_shared_tensors_gpu(offload, torch_dtype, tie_word_embeddings, device): + test_model_shared_tensors(offload, torch_dtype, tie_word_embeddings, device) @requires_gpu
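The new `test_correctness.py` exercises the QUIP transform template but not the SpinQuant modifier itself. A possible follow-up check, sketched below against the APIs added in this diff, would confirm that a transform-only `SpinQuantModifier` recipe (R1/R2, no quantization) leaves logits close to the original model, since those rotations are fused in mutually inverse pairs. The model choice, GPU setup, and the decision to print diagnostics instead of asserting a fixed tolerance are assumptions for illustration, not part of the PR.

```python
# Hypothetical reviewer sketch (not part of this diff): apply only the offline
# R1/R2 rotations and compare logits before and after. Embedding centering and
# float fusion error account for the small residual difference.
import torch
from transformers import AutoModelForCausalLM

from llmcompressor import oneshot
from llmcompressor.modifiers.transform import SpinQuantModifier

# same model and setup as the QUIP correctness test above
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct", device_map="cuda", torch_dtype=torch.float32
)
inputs = {k: v.to("cuda") for k, v in model.dummy_inputs.items()}

with torch.no_grad():
    reference = model(**inputs).logits

# transform-only recipe; routed through the data-free pipeline added in this PR
oneshot(
    model=model,
    recipe=SpinQuantModifier(rotations=["R1", "R2"], transform_type="hadamard"),
)

with torch.no_grad():
    transformed = model(**inputs).logits

# thresholds are left to the reader; errors in the same ballpark as the QUIP
# float32 test (well below 1.0) would indicate the rotations are wired correctly
print("max abs diff:", (reference - transformed).abs().max().item())
print("mse:", torch.nn.MSELoss()(transformed, reference).item())
```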