diff --git a/src/llmcompressor/core/events/event.py b/src/llmcompressor/core/events/event.py
index 89eb780c85..215828b560 100644
--- a/src/llmcompressor/core/events/event.py
+++ b/src/llmcompressor/core/events/event.py
@@ -44,6 +44,7 @@ class EventType(Enum):
     BATCH_START = "batch_start"
     LOSS_CALCULATED = "loss_calculated"
     BATCH_END = "batch_end"
+    SEQUENTIAL_EPOCH_END = "sequential_epoch_end"
 
     # step lifecycle
     OPTIM_PRE_STEP = "optim_pre_step"
diff --git a/src/llmcompressor/core/session_functions.py b/src/llmcompressor/core/session_functions.py
index 4d12f22ff4..5d5f6163b0 100644
--- a/src/llmcompressor/core/session_functions.py
+++ b/src/llmcompressor/core/session_functions.py
@@ -212,5 +212,16 @@ def batch_end(cls, **kwargs) -> ModifiedState:
         active_session()._log_model_info()
         return cls.event(EventType.BATCH_END, **kwargs)
 
+    @classmethod
+    def sequential_epoch_end(cls, **kwargs) -> ModifiedState:
+        """
+        Invoke a sequential epoch end event for the active session. This event
+        should be called after a sequential layer has been calibrated/trained
+        for one epoch, i.e. after one pass through all calibration batches.
+
+        See `src/llmcompressor/pipelines/sequential/pipeline.py` for a usage example
+        """
+        return cls.event(EventType.SEQUENTIAL_EPOCH_END, **kwargs)
+
 
 callbacks = LifecycleCallbacks
diff --git a/src/llmcompressor/modifiers/obcq/base.py b/src/llmcompressor/modifiers/obcq/base.py
index bcbd610fe7..7c85bd61a6 100644
--- a/src/llmcompressor/modifiers/obcq/base.py
+++ b/src/llmcompressor/modifiers/obcq/base.py
@@ -10,7 +10,7 @@
 from loguru import logger
 from pydantic import PrivateAttr
 
-from llmcompressor.core import State
+from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin
 from llmcompressor.modifiers.obcq.sgpt_sparsify import (
@@ -90,6 +90,14 @@ def calibrate_module(
         args: Tuple[torch.Tensor, ...],
         _output: torch.Tensor,
     ):
+        """
+        Calibration hook used to accumulate the Hessian of the input to the module
+
+        :param module: module being calibrated
+        :param args: inputs to the module, the first element of which is the
+            canonical input
+        :param _output: uncompressed module output, unused
+        """
         # Assume that the first argument is the input
         inp = args[0]
 
@@ -108,10 +116,13 @@ def calibrate_module(
             self._num_samples[module],
         )
 
-    def on_sequential_batch_end(self):
+    def on_event(self, state: State, event: Event, **kwargs):
+        if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
+            self.compress_modules()
+
+    def compress_modules(self):
         """
-        Sparsify modules
-        TODO: implement with event callback
+        Sparsify modules which have been calibrated
         """
         for module in list(self._num_samples.keys()):
             name = self._module_names[module]
@@ -154,6 +165,8 @@ def _maybe_onload_hessian(self, module: torch.nn.Module):
             self._hessians[module] = self._hessians[module].to(device="cpu")
 
     def on_finalize(self, state: State, **kwargs) -> bool:
+        self.compress_modules()  # compress any remaining modules
+
         self.remove_hooks()
         self._hessians = dict()
         self._num_samples = dict()
diff --git a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py
index 6cdf5fda80..7dfc3da57d 100644
--- a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py
+++ b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py
@@ -170,7 +170,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
                 state.data.calib,
                 self.sequential_targets,
                 self.ignore,
-                self,
             )
 
             return True
@@ -186,7 +185,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
                 state.model,
                 state.data.calib,
                 self.sequential_targets,
-                self,
             )
 
             return True
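The SparseGPT modifier above and the WANDA/GPTQ modifiers below all adopt the same event-dispatch pattern in place of the removed `on_sequential_batch_end` callback. A minimal sketch of that pattern, for orientation only (the class name and the `compress_modules` body are illustrative, not part of this diff):

    from llmcompressor.core import Event, EventType, State
    from llmcompressor.modifiers import Modifier

    class ExampleCompressionModifier(Modifier):
        def on_event(self, state: State, event: Event, **kwargs):
            # compress once per sequential layer, after a full epoch of
            # calibration batches has propagated through that layer
            if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
                self.compress_modules()

        def compress_modules(self):
            ...  # consume the statistics accumulated by calibration hooks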
diff --git a/src/llmcompressor/modifiers/pruning/wanda/base.py b/src/llmcompressor/modifiers/pruning/wanda/base.py
index 3b0eb9f584..93828e22ea 100644
--- a/src/llmcompressor/modifiers/pruning/wanda/base.py
+++ b/src/llmcompressor/modifiers/pruning/wanda/base.py
@@ -9,7 +9,7 @@
 from loguru import logger
 from pydantic import PrivateAttr
 
-from llmcompressor.core import State
+from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin
 from llmcompressor.modifiers.pruning.wanda.wanda_sparsify import (
@@ -74,6 +74,14 @@ def calibrate_module(
         args: Tuple[torch.Tensor, ...],
         _output: torch.Tensor,
     ):
+        """
+        Calibration hook used to accumulate the row scalars of the input to the module
+
+        :param module: module being calibrated
+        :param args: inputs to the module, the first element of which is the
+            canonical input
+        :param _output: uncompressed module output, unused
+        """
         # Assume that the first argument is the input
         inp = args[0]
 
@@ -91,12 +99,14 @@ def calibrate_module(
             self._num_samples[module],
         )
 
-    def on_sequential_batch_end(self):
+    def on_event(self, state: State, event: Event, **kwargs):
+        if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
+            self.compress_modules()
+
+    def compress_modules(self):
         """
-        Sparsify modules
-        TODO: implement with event callback
+        Sparsify modules which have been calibrated
         """
-
         for module in list(self._num_samples.keys()):
             name = self._module_names[module]
             sparsity = self._module_sparsities[module]
@@ -122,6 +132,8 @@ def on_sequential_batch_end(self):
             del self._num_samples[module]
 
     def on_finalize(self, state: State, **kwargs) -> bool:
+        self.compress_modules()  # compress any remaining modules
+
         self.remove_hooks()
         self._row_scalars = dict()
         self._num_samples = dict()
diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py
index 525ba1301b..8a3d672a6b 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/base.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -13,7 +13,7 @@
 from loguru import logger
 from pydantic import Field, PrivateAttr, field_validator
 
-from llmcompressor.core import State
+from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier, ModifierFactory
 from llmcompressor.modifiers.quantization.calibration import freeze_module_quantization
 from llmcompressor.modifiers.quantization.gptq.gptq_quantize import (
@@ -236,7 +236,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:
                 state.data.calib,
                 self.sequential_targets,
                 self.ignore,
-                self,
             )
 
             return True
@@ -257,7 +256,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:
                 state.model,
                 state.data.calib,
                 self.sequential_targets,
-                self,
             )
 
             return True
@@ -281,6 +279,8 @@ def on_finalize(self, state: State, **kwargs) -> bool:
 
         :param state: session state storing input model and calibration data
         """
+        self.compress_modules()  # compress any remaining modules
+
         if self._quantization_modifier:
             self._quantization_modifier.finalize(state, **kwargs)
 
@@ -298,13 +298,12 @@ def calibrate_module(
         _output: torch.Tensor,
     ):
         """
-        Quantize a module's weight according to the GPTQ algorithm
-
-        :param name: name of module being quantized
-        :param module: module being quantized
-        :param args: input arguments for module forward pass
+        Calibration hook used to accumulate the Hessian of the input to the module
 
-        :return: total loss from applying weight quantization to this module
+        :param module: module being calibrated
+        :param args: inputs to the module, the first element of which is the
+            canonical input
+        :param _output: uncompressed module output, unused
         """
         # Assume that first argument is the input
         inp = args[0]
@@ -326,10 +325,13 @@ def calibrate_module(
             self._num_samples[module],
         )
 
-    def on_sequential_batch_end(self):
+    def on_event(self, state: State, event: Event, **kwargs):
+        if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
+            self.compress_modules()
+
+    def compress_modules(self):
         """
-        Quantize modules.
-        TODO: implement with event callback
+        Quantize modules which have been calibrated
         """
         for module in list(self._num_samples.keys()):
             name = self._module_names[module]
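Note the pairing between `compress_modules` and `on_finalize` in each modifier: since `compress_modules` deletes a module's calibration entry as soon as that module is compressed, the extra call in `on_finalize` is a safety net that only touches modules calibrated under pipelines that never emit `SEQUENTIAL_EPOCH_END` (e.g. the basic pipeline). A rough sketch of this bookkeeping, with illustrative names:

    # module -> number of calibration samples seen (illustrative)
    _num_samples: dict = {}

    def compress_modules():
        # iterate over a snapshot of the keys, since entries are deleted below
        for module in list(_num_samples.keys()):
            ...  # sparsify/quantize using the accumulated statistics
            del _num_samples[module]  # free calibration data immediately

    def on_finalize():
        compress_modules()  # no-op for modules that were already compressed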
""" - logger.info("Smoothing activation scales...") for mapping in self.resolved_mappings_: + if mapping.smooth_name not in self.scales_: + continue + logger.info(f"Smoothing with {mapping.smooth_name}") + activation_scales = ( # get dynamic range for each activation channel self.scales_[mapping.smooth_name].max_channel_vals - self.scales_[mapping.smooth_name].min_channel_vals @@ -289,22 +300,16 @@ def _apply_smoothing(self, model: Module): @torch.no_grad() def smooth(module): - offloaded = is_module_offloaded(module) - if offloaded: - module._hf_hook.pre_forward(module) - - if module in balance_layers: - module.weight.mul_(scales.view(1, -1)) - elif module == smooth_layer: - if module.weight.ndim == 1: - module.weight.div_(scales) - else: - module.weight.div_(scales.view(-1, 1)) - if hasattr(module, "bias") and module.bias is not None: - module.bias.div_(scales) - - if offloaded: - module._hf_hook.post_forward(module, None) + with align_module_device(module): + if module in balance_layers: + module.weight.mul_(scales.view(1, -1)) + elif module == smooth_layer: + if module.weight.ndim == 1: + module.weight.div_(scales) + else: + module.weight.div_(scales.view(-1, 1)) + if hasattr(module, "bias") and module.bias is not None: + module.bias.div_(scales) parent = get_fsdp_parent(mapping.smooth_name, model) if parent is not None: @@ -315,6 +320,9 @@ def smooth(module): smooth(layer) smooth(smooth_layer) + # clear calibration data + del self.scales_[mapping.smooth_name] + def _calculate_smoothing_scales( self, balance_layers: List[Module], activation_scales: torch.Tensor ) -> List[float]: @@ -329,15 +337,9 @@ def _calculate_smoothing_scales( # get the channel-wise dynamic range for each layer to be balanced weight_scales = [] for layer in balance_layers: - offloaded = is_module_offloaded(layer) - if offloaded: - layer._hf_hook.pre_forward(layer) - - scale = layer.weight.abs().max(dim=0, keepdim=True)[0] - weight_scales.append(scale) - - if offloaded: - layer._hf_hook.post_forward(layer, None) + with align_module_device(layer): + scale = layer.weight.abs().max(dim=0, keepdim=True)[0] + weight_scales.append(scale) weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0] @@ -350,3 +352,5 @@ def _calculate_smoothing_scales( scales = torch.where(weight_scales > 0.0, scales, activation_scales) return scales + + model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index 13a1c9454c..98d9d7bdf5 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -39,7 +39,3 @@ def run_pipeline( batch = apply_pad_mask_to_batch(batch) batch = tensors_to_device(batch, model_device) model(**batch) - - # TODO: replace with a lifecycle event - if callback_modifier: - callback_modifier.on_sequential_batch_end() diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index 9f8adbce4f..9973c589c9 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -1,9 +1,10 @@ -from typing import TYPE_CHECKING, List, Optional +from typing import List import torch import torch.utils.data.dataloader import tqdm +from llmcompressor.core import LifecycleCallbacks from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.cache import IntermediatesCache from 
llmcompressor.pipelines.layer_sequential.helpers import ( @@ -14,9 +15,6 @@ ) from llmcompressor.utils.helpers import calibration_forward_context -if TYPE_CHECKING: - from llmcompressor.modifiers import Modifier - __all__ = ["run_pipeline"] @@ -24,7 +22,6 @@ def run_pipeline( model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, sequential_targets: List[str], - callback_modifier: Optional["Modifier"] = None, ): """ Run a layer-wise sequential data pipeline according to the following steps: @@ -68,9 +65,8 @@ def run_pipeline( inputs = intermediates.fetch(batch_index) layer(**inputs) - # TODO: replace with a lifecycle event - if callback_modifier: - callback_modifier.on_sequential_batch_end() + # trigger compression + LifecycleCallbacks.sequential_epoch_end() # this pass does not trigger modifier hooks # and is only used for capturing outputs from the newly compressed modules diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index c26205e3b4..927cbed332 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -1,18 +1,16 @@ -from typing import TYPE_CHECKING, List, Optional +from typing import List import torch import torch.utils.data.dataloader import tqdm from compressed_tensors.utils import get_execution_device +from llmcompressor.core import LifecycleCallbacks from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.cache import IntermediatesCache from llmcompressor.pipelines.sequential.helpers import trace_subgraphs from llmcompressor.utils.helpers import calibration_forward_context -if TYPE_CHECKING: - from llmcompressor.modifiers import Modifier - __all__ = ["run_pipeline"] @@ -21,7 +19,6 @@ def run_pipeline( dataloader: torch.utils.data.DataLoader, sequential_targets: List[str], ignore: List[str], - callback_modifier: Optional["Modifier"] = None, ): """ Run a sequential data pipeline according to the following steps: @@ -69,9 +66,8 @@ def run_pipeline( inputs = intermediates.fetch(batch_index, subgraph.input_names) forward_function(model, **inputs) - # TODO: replace with a lifecycle event - if callback_modifier: - callback_modifier.on_sequential_batch_end() + # trigger compression + LifecycleCallbacks.sequential_epoch_end() # this pass does not trigger modifier hooks # and is only used for capturing outputs from the newly compressed modules
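Custom calibration loops can emit the same event directly. A minimal sketch, assuming an active session has already been initialized (e.g. inside `oneshot`); `sequential_layers` and `calibration_batches` are hypothetical placeholders for the traced subgraphs and the calibration dataloader:

    from llmcompressor.core import LifecycleCallbacks

    for layer in sequential_layers:        # hypothetical: one entry per subgraph
        for batch in calibration_batches:  # hypothetical: calibration dataloader
            layer(**batch)                 # modifier hooks accumulate statistics
        # one epoch of calibration for this layer is complete; compress it
        LifecycleCallbacks.sequential_epoch_end()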