diff --git a/.github/workflows/test-check-transformers.yaml b/.github/workflows/test-check-transformers.yaml index d16107add1..801b9524c2 100644 --- a/.github/workflows/test-check-transformers.yaml +++ b/.github/workflows/test-check-transformers.yaml @@ -103,4 +103,4 @@ jobs: - name: Running KV Cache Tests if: (success() || failure()) && steps.install.outcome == 'success' run: | - pytest -v tests/llmcompressor/transformers/kv_cache -k "not test_kv_cache_gptq_model_state_dict_attr" \ No newline at end of file + pytest -v tests/llmcompressor/transformers/kv_cache diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py index 31f7c73bb3..c76729be4e 100644 --- a/src/llmcompressor/args/dataset_arguments.py +++ b/src/llmcompressor/args/dataset_arguments.py @@ -171,3 +171,11 @@ class DatasetArguments(CustomDatasetArguments): "will execute code present on the Hub on your local machine." }, ) + pipeline: Optional[str] = field( + default="independent", + metadata={ + "help": "Calibration pipeline used to calibrate model" + "Options: ['basic', 'datafree', 'sequential', 'layer_sequential', " + "independent]" + }, + ) diff --git a/src/llmcompressor/core/events/event.py b/src/llmcompressor/core/events/event.py index 89eb780c85..a73903d4cf 100644 --- a/src/llmcompressor/core/events/event.py +++ b/src/llmcompressor/core/events/event.py @@ -32,6 +32,10 @@ class EventType(Enum): :param BATCH_START: Event type for the start of a batch. :param LOSS_CALCULATED: Event type for when loss is calculated. :param BATCH_END: Event type for the end of a batch. + :param CALIBRATION_EPOCH_START: Event type for the start of a calibration epoch. + :param SEQUENTIAL_EPOCH_END: Event type for the end of a layer calibration epoch, + specifically used by `src/llmcompressor/pipelines/sequential/pipeline.py` + :param CALIBRATION_EPOCH_END: Event type for the end of a calibration epoch. :param OPTIM_PRE_STEP: Event type for pre-optimization step. 
:param OPTIM_POST_STEP: Event type for post-optimization step. """ @@ -45,6 +49,11 @@ class EventType(Enum): LOSS_CALCULATED = "loss_calculated" BATCH_END = "batch_end" + # calibration lifecycle + CALIBRATION_EPOCH_START = "calibration_epoch_start" + SEQUENTIAL_EPOCH_END = "sequential_epoch_end" + CALIBRATION_EPOCH_END = "calibration_epoch_end" + # step lifecycle OPTIM_PRE_STEP = "optim_pre_step" OPTIM_POST_STEP = "optim_post_step" diff --git a/src/llmcompressor/core/session.py b/src/llmcompressor/core/session.py index 4f21c3f7ad..756b6181af 100644 --- a/src/llmcompressor/core/session.py +++ b/src/llmcompressor/core/session.py @@ -220,6 +220,17 @@ def get_serialized_recipe(self) -> Optional[str]: logger.warning("Recipe not found in session - it may have been reset") + def get_modifiers(self): + """ + Get all modifiers across all stages + """ + stage_modifiers = self.lifecycle.modifiers + return [ + modifier + for stage_modifier in stage_modifiers + for modifier in stage_modifier.modifiers + ] # noqa: E127 + def _log_model_info(self): # Log model level logs if cadence reached current_index = self._lifecycle.global_step diff --git a/src/llmcompressor/core/session_functions.py b/src/llmcompressor/core/session_functions.py index b280febbe7..47133d14ed 100644 --- a/src/llmcompressor/core/session_functions.py +++ b/src/llmcompressor/core/session_functions.py @@ -1,6 +1,6 @@ import threading from contextlib import contextmanager -from typing import Any, Optional +from typing import Any, Generator, Optional from llmcompressor.core.events import EventType from llmcompressor.core.session import CompressionSession @@ -21,7 +21,7 @@ @contextmanager -def create_session() -> CompressionSession: +def create_session() -> Generator[CompressionSession, None, None]: """ Context manager to create and yield a new session for sparsification. 
This will set the active session to the new session for the duration @@ -136,5 +136,36 @@ def batch_end(cls, **kwargs) -> ModifiedState: active_session()._log_model_info() return cls.event(EventType.BATCH_END, **kwargs) + @classmethod + def calibration_epoch_start(cls, **kwargs) -> ModifiedState: + """ + Invoke an epoch start event for the active session during calibration. This event + should be called before calibration starts for one epoch + + see `src/llmcompressor/pipelines/basic/pipeline.py` for usage example + """ + return cls.event(EventType.CALIBRATION_EPOCH_START, **kwargs) + + @classmethod + def sequential_epoch_end(cls, **kwargs) -> ModifiedState: + """ + Invoke a sequential epoch end event for the active session. This event should be + called after one sequential layer has been calibrated/trained for one epoch + + This is called after a sequential layer has been calibrated with one batch, see + `src/llmcompressor/pipelines/sequential/pipeline.py` for usage example + """ + return cls.event(EventType.SEQUENTIAL_EPOCH_END, **kwargs) + + @classmethod + def calibration_epoch_end(cls, **kwargs) -> ModifiedState: + """ + Invoke an epoch end event for the active session during calibration. 
This event + should be called after the model has been calibrated for one epoch + + see `src/llmcompressor/pipelines/basic/pipeline.py` for usage example + """ + return cls.event(EventType.CALIBRATION_EPOCH_END, **kwargs) + callbacks = LifecycleCallbacks diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index a6df0ee1c9..365032a221 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -7,6 +7,7 @@ from llmcompressor.core.session_functions import active_session from llmcompressor.datasets import get_calibration_dataloader from llmcompressor.entrypoints.utils import post_process, pre_process +from llmcompressor.pipelines.registry import CalibrationPipeline __all__ = ["Oneshot", "oneshot"] @@ -157,21 +158,25 @@ def apply_recipe_modifiers( """ session = active_session() + session.reset() - session_kwargs = dict( + # (Helen INFERENG-661): validate recipe modifiers before initialization + session.initialize( model=self.model, + start=-1, recipe=self.recipe, - recipe_args=self.recipe_args.recipe_args, - calib_data=calibration_dataloader, - start=-1, # oneshot-specific argument - copy_data=False, - min_tokens_per_module=getattr(self, "min_tokens_per_module", None), recipe_stage=recipe_stage, + recipe_args=self.recipe_args.recipe_args, + calib_data=calibration_dataloader, # only used by AWQModifier, remove once + # AWQModifier supports calibration pipelines ) - session.reset() - session.initialize(**session_kwargs) - session.finalize(**session_kwargs) + user_pipeline = self.dataset_args.pipeline + modifiers = session.get_modifiers() + pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline) + pipeline(self.model, calibration_dataloader, self.dataset_args) + + session.finalize() def oneshot(**kwargs) -> PreTrainedModel: diff --git a/src/llmcompressor/modifiers/modifier.py b/src/llmcompressor/modifiers/modifier.py index 38911b5901..e91d881e8e 100644 --- 
a/src/llmcompressor/modifiers/modifier.py +++ b/src/llmcompressor/modifiers/modifier.py @@ -89,7 +89,7 @@ def initialize(self, state: State, **kwargs): self.initialized_ = self.on_initialize(state=state, **kwargs) - # trigger start + # trigger starts fake_start_event = Event(type_=EventType.BATCH_START, global_step=0) if self.should_start(fake_start_event): self.on_start(state, fake_start_event, **kwargs) @@ -103,8 +103,8 @@ def finalize(self, state: State, **kwargs): :param state: The current state of the model :param kwargs: Additional arguments for finalizing the modifier """ - if self.finalized_ or not self.initialized_: - return + if self.finalized_: + raise RuntimeError("cannot finalize a modifier twice") if not self.initialized_: raise RuntimeError("cannot finalize an uninitialized modifier") diff --git a/src/llmcompressor/modifiers/obcq/base.py b/src/llmcompressor/modifiers/obcq/base.py index ddffdecdc5..54e79461b1 100644 --- a/src/llmcompressor/modifiers/obcq/base.py +++ b/src/llmcompressor/modifiers/obcq/base.py @@ -41,9 +41,6 @@ class SparseGPTModifier(SparsityModifierMixin, Modifier): Lifecycle: - on_initialize - register_hook(module, calibrate_module, "forward") - - run_sequential / run_layer_sequential / run_basic - - make_empty_hessian - - accumulate_hessian - on_sequential_batch_end - sparsify_weight - on_finalize @@ -90,6 +87,14 @@ def calibrate_module( args: Tuple[torch.Tensor, ...], _output: torch.Tensor, ): + """ + Calibration hook used to accumulate the hessian of the input to the module + + :param module: module being calibrated + :param args: inputs to the module, the first element of which is the + canonical input + :param _output: uncompressed module output, unused + """ # Assume that the first argument is the input inp = args[0] @@ -108,10 +113,9 @@ def calibrate_module( self._num_samples[module], ) - def on_sequential_batch_end(self): + def compress_modules(self): """ - Sparsify modules - TODO: implement with event callback + Sparsify 
modules which have been calibrated """ for module in list(self._num_samples.keys()): name = self._module_names[module] @@ -152,7 +156,13 @@ def _maybe_onload_hessian(self, module: torch.nn.Module): self._hessians[module] = self._hessians[module].to(device="cpu") def on_finalize(self, state: State, **kwargs) -> bool: - self.remove_hooks() + # TODO: modify lifecycle to end on finalize + if not self.ended_: + self.on_end(state, None) # remove hooks + + if len(self._num_samples) > 0: + raise ValueError(f"Failed to compress {len(self._num_samples)} modules") + self._hessians = dict() self._num_samples = dict() self._module_names = dict() diff --git a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py index 6cdf5fda80..0e2f319c89 100644 --- a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py +++ b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py @@ -9,13 +9,9 @@ from loguru import logger from pydantic import Field, PrivateAttr, field_validator, model_validator -from llmcompressor.core import State +from llmcompressor.core import Event, EventType, State +from llmcompressor.modifiers.modifier import Modifier from llmcompressor.modifiers.utils.hooks import HooksMixin -from llmcompressor.pipelines.basic import run_pipeline as run_basic -from llmcompressor.pipelines.layer_sequential import ( - run_pipeline as run_layer_sequential, -) -from llmcompressor.pipelines.sequential import run_pipeline as run_sequential from llmcompressor.utils.pytorch.module import ( get_layers, get_no_split_params, @@ -24,7 +20,7 @@ ) -class SparsityModifierMixin(HooksMixin): +class SparsityModifierMixin(Modifier): # modifier arguments sparsity: Optional[Union[float, List[float]]] sparsity_profile: Optional[str] = None @@ -42,6 +38,7 @@ class SparsityModifierMixin(HooksMixin): _prune_n: Optional[int] = PrivateAttr(default=None) _prune_m: Optional[int] = PrivateAttr(default=None) _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict) + 
_target_layers: Dict[str, torch.nn.Module] = PrivateAttr(default_factory=dict) _module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict) @field_validator("sequential_update", mode="before") @@ -97,6 +94,10 @@ def calibrate_module( ): raise NotImplementedError() + @abstractmethod + def compress_modules(self): + raise NotImplementedError() + def on_initialize(self, state: "State", **kwargs) -> bool: """ Initialize and run the OBCQ algorithm on the current state @@ -109,7 +110,9 @@ def on_initialize(self, state: "State", **kwargs) -> bool: # infer module and sequential targets self.sequential_targets = self._infer_sequential_targets(model) layers = get_layers(self.sequential_targets, model) - target_layers = get_layers(self.targets, model) # layers containing targets + self._target_layers = get_layers( + self.targets, model + ) # layers containing targets # infer layer sparsities if self.sparsity_profile == "owl": @@ -120,7 +123,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self.sparsity = self._infer_owl_layer_sparsity(model, layers, dataloader) # get layers and validate sparsity - if isinstance(self.sparsity, (list, dict)) and len(target_layers) != len( + if isinstance(self.sparsity, (list, dict)) and len(self._target_layers) != len( self.sparsity ): raise ValueError( @@ -128,8 +131,13 @@ def on_initialize(self, state: "State", **kwargs) -> bool: f"sparsities values, but model has {len(layers)} target layers" ) + return True + + def on_start(self, state: State, event: Event, **kwargs): + self.started_ = True + # register hooks - for index, (layer_name, layer) in enumerate(target_layers.items()): + for index, (layer_name, layer) in enumerate(self._target_layers.items()): if isinstance(self.sparsity, dict): layer_sparsity = self.sparsity[layer_name] elif isinstance(self.sparsity, list): @@ -160,48 +168,23 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self._module_sparsities[module] = layer_sparsity 
self.register_hook(module, self.calibrate_module, "forward") - # infer and run pipeline - model_name = state.model.__class__.__name__ - input_names = dataloader.dataset.column_names - unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError) - try: - run_sequential( - state.model, - state.data.calib, - self.sequential_targets, - self.ignore, - self, - ) - return True - - except Exception as exception: - if isinstance(exception, torch.fx.proxy.TraceError): - warnings.warn(f"Failed to trace {model_name} with inputs {input_names}") - if isinstance(exception, unfixable_errors): - raise exception - - warnings.warn("Falling back to layer_sequential pipeline") - try: - run_layer_sequential( - state.model, - state.data.calib, - self.sequential_targets, - self, - ) - return True - - except Exception as exception: - if isinstance(exception, TypeError): - warnings.warn(f"{model_name} fails layer-wise assumptions") - if isinstance(exception, unfixable_errors): - raise exception - - warnings.warn( - "Falling back to basic pipeline, which requires extra memory and " - "may result in decreased accuracy" - ) - run_basic(state.model, state.data.calib, self) - return True + def on_event(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.CALIBRATION_EPOCH_START: + if not self.started_: + self.on_start(state, None) + + if event.type_ == EventType.SEQUENTIAL_EPOCH_END: + self.compress_modules() + + if event.type_ == EventType.CALIBRATION_EPOCH_END: + self.compress_modules() + + if not self.ended_: + self.on_end(state, None) + + def on_end(self, state: State, event: Event, **kwargs): + self.ended_ = True + self.remove_hooks() def _infer_sequential_targets( self, model: torch.nn.Module @@ -261,6 +244,8 @@ def _infer_owl_layer_sparsity( return sparsities def _get_activations(self, model, dataloader, nsamples=128) -> Dict[str, int]: + from llmcompressor.pipelines.basic import run_calibration + acts = defaultdict(int) def save_acts(_module, input: 
Union[Tuple[Any, ...], torch.Tensor], name: str): @@ -275,7 +260,7 @@ def save_acts(_module, input: Union[Tuple[Any, ...], torch.Tensor], name: str): if isinstance(mod, torch.nn.Linear) and "lm_head" not in name ) with HooksMixin.disable_hooks(keep=hooks): - run_basic(model, dataloader) + run_calibration(model, dataloader) self.remove_hooks(hooks) return acts diff --git a/src/llmcompressor/modifiers/obcq/sgpt_sparsify.py b/src/llmcompressor/modifiers/obcq/sgpt_sparsify.py index 4d89f22496..c43014a72d 100644 --- a/src/llmcompressor/modifiers/obcq/sgpt_sparsify.py +++ b/src/llmcompressor/modifiers/obcq/sgpt_sparsify.py @@ -3,9 +3,6 @@ import torch import transformers -from compressed_tensors.quantization.lifecycle.forward import forward_quantize - -from llmcompressor.utils import getattr_chain SGPT_PRECISION = torch.float32 @@ -69,8 +66,7 @@ def sparsify_weight( preserve_sparsity_mask: bool, ) -> torch.Tensor: """ - Run pruning and quantization(if applicable) on the layer up to the target - sparsity value. + Run pruning on the layer up to the target sparsity value. 
:param module: module with weight being sparsified :param hessian_dict: dictionary containing preaccumulated hessian for sparsification @@ -88,12 +84,6 @@ def sparsify_weight( H = hessians_dict[module] # unfortunately python does not have a `move` keyword del hessians_dict[module] # so we have to delete the original reference manually - # if this module is quantized, perform RTN quantization before sparsifying - args_loc = "quantization_scheme.weights" - weight_quant_args = getattr_chain(module, args_loc, None) - if weight_quant_args is not None: - W = forward_quantize(module, W, "weight", weight_quant_args) - # standardize shape and dtype if isinstance(module, torch.nn.Conv2d): W = W.flatten(1) @@ -217,9 +207,5 @@ def sparsify_weight( W.transpose_(0, 1) W = W.reshape(final_shape).to(final_dtype) - # perform RTN quantization - if weight_quant_args is not None: - W = forward_quantize(module, W, "weight", weight_quant_args) - loss = torch.sum(losses).item() return loss, W diff --git a/src/llmcompressor/modifiers/pruning/wanda/base.py b/src/llmcompressor/modifiers/pruning/wanda/base.py index c77cfff81f..bf194268b8 100644 --- a/src/llmcompressor/modifiers/pruning/wanda/base.py +++ b/src/llmcompressor/modifiers/pruning/wanda/base.py @@ -74,6 +74,14 @@ def calibrate_module( args: Tuple[torch.Tensor, ...], _output: torch.Tensor, ): + """ + Calibration hook used to accumulate the row scalars of the input to the module + + :param module: module being calibrated + :param args: inputs to the module, the first element of which is the + canonical input + :param _output: uncompressed module output, unused + """ # Assume that the first argument is the input inp = args[0] @@ -91,12 +99,10 @@ def calibrate_module( self._num_samples[module], ) - def on_sequential_batch_end(self): + def compress_modules(self): """ - Sparsify modules - TODO: implement with event callback + Sparsify modules which have been calibrated """ - for module in list(self._num_samples.keys()): name = 
self._module_names[module] sparsity = self._module_sparsities[module] @@ -120,7 +126,13 @@ def on_sequential_batch_end(self): del self._num_samples[module] def on_finalize(self, state: State, **kwargs) -> bool: - self.remove_hooks() + # TODO: modify lifecycle to end on finalize + if not self.ended_: + self.on_end(state, None) # remove hooks + + if len(self._num_samples) > 0: + raise ValueError(f"Failed to compress {len(self._num_samples)} modules") + self._row_scalars = dict() self._num_samples = dict() self._module_names = dict() diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index d9c74a496a..53fccf37ee 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -13,7 +13,7 @@ from loguru import logger from pydantic import PrivateAttr, field_validator -from llmcompressor.core import State +from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.quantization.gptq.gptq_quantize import ( accumulate_hessian, @@ -21,13 +21,7 @@ quantize_weight, ) from llmcompressor.modifiers.quantization.quantization import QuantizationMixin -from llmcompressor.pipelines.basic import run_pipeline as run_basic -from llmcompressor.pipelines.layer_sequential import ( - run_pipeline as run_layer_sequential, -) -from llmcompressor.pipelines.sequential import run_pipeline as run_sequential from llmcompressor.utils.metric_logging import CompressionLogger -from llmcompressor.utils.pytorch.module import get_no_split_params __all__ = ["GPTQModifier"] @@ -61,12 +55,11 @@ class GPTQModifier(Modifier, QuantizationMixin): Lifecycle: - on_initialize - - _build_quant_modifier - - register_hook(module, compress_module, "forward") - - run_sequential / run_layer_sequential / run_basic - - make_empty_hessian - - accumulate_hessian - - on_sequential_batch_end + - apply config to model + - 
on_start + - add activation calibration hooks + - add gptq weight calibration hooks + - on_sequential_epoch_end - quantize_weight - on_finalize - remove_hooks() @@ -109,7 +102,6 @@ class GPTQModifier(Modifier, QuantizationMixin): sequential_targets: Union[str, List[str], None] = None block_size: int = 128 dampening_frac: Optional[float] = 0.01 - quantize: Union[bool, Dict] = True offload_hessians: bool = False # private variables @@ -138,16 +130,22 @@ def on_initialize(self, state: State, **kwargs) -> bool: if QuantizationMixin.has_config(self): QuantizationMixin.initialize_quantization(self, state.model) + # prepare module names + self._module_names = {m: name for name, m in state.model.named_modules()} + + return True + + def on_start(self, state: State, event: Event, **kwargs): + self.started_ = True + + # register quantization calibration hooks # assume quantization has been initialized by this modifier or one before it QuantizationMixin.start_calibration(self, state.model) # Unlike qmod, do not quantize as we calibrate # This choice does not seem to have a meaningful impact on accuracy state.model.apply(disable_quantization) - # prepare module names - self._module_names = {m: name for name, m in state.model.named_modules()} - - # register hooks + # register gptq hooks added_hook = False for module in state.model.modules(): if getattr_chain(module, "quantization_scheme.weights", None) is not None: @@ -160,68 +158,31 @@ def on_initialize(self, state: State, **kwargs) -> bool: if not added_hook: raise ValueError( - "GPTQModifier requires a quantization config be specified by this " - "modifier or a modifier preceding it" + "GPTQModifier requires a weight quantization config be specified by " + "this modifier or a modifier preceding it" ) - # infer sequential targets - if self.sequential_targets is None: - self.sequential_targets = get_no_split_params(state.model) - if isinstance(self.sequential_targets, str): - self.sequential_targets = [self.sequential_targets] 
- - # infer pipeline - model_name = state.model.__class__.__name__ - input_names = state.data.calib.dataset.column_names - unfixable_errors = ( - torch.OutOfMemoryError, - torch._C._LinAlgError, - KeyboardInterrupt, - ) - try: - run_sequential( - state.model, - state.data.calib, - self.sequential_targets, - self.ignore, - self, - ) - return True - - except Exception as exception: - if isinstance(exception, torch.fx.proxy.TraceError): - warnings.warn( - f"Failed to trace {model_name} with inputs {input_names}. For more " - "information on tracing with the sequential pipeline, see " - "https://github.com/vllm-project/llm-compressor/blob/main/" - "src/llmcompressor/transformers/tracing/GUIDE.md" - ) - if isinstance(exception, unfixable_errors): - raise exception - - warnings.warn("Falling back to layer_sequential pipeline") - try: - run_layer_sequential( - state.model, - state.data.calib, - self.sequential_targets, - self, - ) - return True - - except Exception as exception: - if isinstance(exception, TypeError): - warnings.warn(f"{model_name} fails layer-wise assumptions") - if isinstance(exception, unfixable_errors): - raise exception - - warnings.warn( - "Falling back to basic pipeline, which requires extra memory and " - "may result in decreased accuracy. 
Consider using " - "`offload_hessians=True`" - ) - run_basic(state.model, state.data.calib, self) - return True + def on_event(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.CALIBRATION_EPOCH_START: + if not self.started_: + self.on_start(state, None) + + if event.type_ == EventType.SEQUENTIAL_EPOCH_END: + self.compress_modules() + + if event.type_ == EventType.CALIBRATION_EPOCH_END: + self.compress_modules() + + if not self.ended_: + self.on_end(state, None) + + def on_end(self, state: State, event: Event, **kwargs): + """ + Finish calibrating by removing observers and calibration hooks + """ + self.ended_ = True + QuantizationMixin.end_calibration(self, state.model) + self.remove_hooks() # remove gptq hooks def on_finalize(self, state: State, **kwargs) -> bool: """ @@ -229,12 +190,15 @@ def on_finalize(self, state: State, **kwargs) -> bool: :param state: session state storing input model and calibration data """ + if not self.ended_: + self.on_end(state, None) + + if len(self._num_samples) > 0: + raise ValueError(f"Failed to compress {len(self._num_samples)} modules") + self._hessians = dict() self._num_samples = dict() - QuantizationMixin.end_calibration(self, state.model) - self.remove_hooks() # remove gptq hooks - return True def calibrate_module( @@ -244,13 +208,12 @@ def calibrate_module( _output: torch.Tensor, ): """ - Quantize a module's weight according to the GPTQ algorithm - - :param name: name of module being quantized - :param module: module being quantized - :param args: input arguments for module forward pass + Calibration hook used to accumulate the hessian of the input to the module - :return: total loss from applying weight quantization to this module + :param module: module being calibrated + :param args: inputs to the module, the first element of which is the + canonical input + :param _output: uncompressed module output, unused """ # Assume that first argument is the input inp = args[0] @@ -272,10 +235,9 @@ def 
calibrate_module( self._num_samples[module], ) - def on_sequential_batch_end(self): + def compress_modules(self): """ - Quantize modules. - TODO: implement with event callback + Quantize modules which have been calibrated """ for module in list(self._num_samples.keys()): name = self._module_names[module] diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py index 3c309a074a..498290119a 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/base.py +++ b/src/llmcompressor/modifiers/quantization/quantization/base.py @@ -1,13 +1,9 @@ -import torch import tqdm -from loguru import logger -from llmcompressor.core import Event, State +from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.quantization.calibration import update_weight_zp_scale from llmcompressor.modifiers.quantization.quantization.mixin import QuantizationMixin -from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward -from llmcompressor.utils.helpers import calibration_forward_context __all__ = ["QuantizationModifier"] @@ -56,72 +52,41 @@ def on_initialize(self, state: State, **kwargs) -> bool: """ if not QuantizationMixin.has_config(self): raise ValueError( - "QuantizationModifier requires that quantization fields to be specified" + "QuantizationModifier requires that quantization fields be specified" ) - QuantizationMixin.initialize_quantization(self, state.model) - # FUTURE: modify oneshot lifecycle to trigger on_start for on initialize - if self.calculate_start() == -1: # one shot - self.on_start(state) - return True - def on_start(self, state: State): + def on_start(self, state: State, event: Event, **kwargs): """ Begin calibrating activations and weights. 
Calibrate weights only once on start """ + self.started_ = True QuantizationMixin.start_calibration(self, state.model) modules = list(state.model.modules()) for module in tqdm.tqdm(modules, desc="Calibrating weights"): update_weight_zp_scale(module) - # FUTURE: below will be removed after pipeline extraction - if self.calculate_start() == -1: # one shot - self._calibrate_if_possible(state) + def on_event(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.CALIBRATION_EPOCH_START: + if not self.started_: + self.on_start(state, None) + + if event.type_ == EventType.CALIBRATION_EPOCH_END: + if not self.ended_: + self.on_end(state, None) def on_end(self, state: State, event: Event, **kwargs): """ Finish calibrating by removing observers and calibration hooks """ + self.ended_ = True QuantizationMixin.end_calibration( self, state.model ) # keep quantization enabled def on_finalize(self, state: State, **kwargs) -> bool: - # TODO: modify lifecycle so modifiers end on finalize if not self.ended_: self.on_end(state, None) - - def _calibrate_if_possible(self, state: State): - model = state.model - calibration_dataloader = state.data.calib - config = QuantizationMixin.resolve_quantization_config(self) - - has_calibration_data = calibration_dataloader is not None - requires_calibration = config.requires_calibration_data() - if requires_calibration and not has_calibration_data: - raise ValueError( - "The provided quantization configuration requires calibration data " - "but none was provided. Calibration data is required for static " - "quantization of input or output activations." - ) - if not requires_calibration and has_calibration_data: - logger.info( - "Skipping QuantizationModifier calibration, it is not required for " - "the provided quantization config." 
- ) - return - - if not requires_calibration: - return - - self._calibrate(model, calibration_dataloader) - - def _calibrate(self, module: torch.nn.Module, data: torch.utils.data.DataLoader): - class_name = self.__class__.__name__.replace("PyTorch", "") - logger.info(f"Running {class_name} calibration with {len(data)} samples...") - - with calibration_forward_context(module): - run_calibration_forward(module, data) diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py index a08d570eb1..96f4acccf6 100644 --- a/src/llmcompressor/modifiers/smoothquant/base.py +++ b/src/llmcompressor/modifiers/smoothquant/base.py @@ -7,15 +7,13 @@ from pydantic import ConfigDict, Field from torch.nn import Module -from llmcompressor.core import State +from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.smoothquant.utils import ( get_layer_mappings_from_architecture, handle_mapping_resolution_errors, ) -from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.utils.fsdp.helpers import get_fsdp_parent -from llmcompressor.utils.helpers import calibration_forward_context from llmcompressor.utils.pytorch.module import ( get_layers, get_matching_layer, @@ -134,21 +132,40 @@ def on_initialize(self, state: State, **kwargs) -> bool: self.resolved_mappings_ = self._resolve_mappings(state.model) self.scales_ = {} - calibration_dataloader = state.data.calib + return True + def on_start(self, state: State, event: Event, **kwargs): + self.started_ = True self._setup_scale_hooks() - self._calibrate(state.model, calibration_dataloader) - self._apply_smoothing(state.model) - return True + def on_event(self, state: State, event: Event, **kwargs): + if event.type_ == EventType.CALIBRATION_EPOCH_START: + if not self.started_: + self.on_start(state, None) + + if event.type_ == EventType.SEQUENTIAL_EPOCH_END: + 
self._apply_smoothing(state.model) + + if event.type_ == EventType.CALIBRATION_EPOCH_END: + self._apply_smoothing(state.model) + + if not self.ended_: + self.on_end(state, None) + + def on_end(self, state: State, event: Event, **kwargs): + self.ended_ = True + self.remove_hooks() # remove hooks def on_finalize(self, state: State, **kwargs) -> bool: """ Clean up by clearing the scale and mapping data - - :param state: unused - :return: True """ + if not self.ended_: + self.on_end(state, None) + + if len(self.scales_) > 0: + raise ValueError(f"Failed to compress {len(self.scales_)} modules") + if self.scales_ is not None: self.scales_.clear() if self.resolved_mappings_ is not None: @@ -237,34 +254,6 @@ def hook_fn(module, inp, out): layer = mapping.smooth_layer self.register_hook(layer, create_hook_fn(name), "forward") - @torch.no_grad() - def _calibrate(self, model: Module, calibration_dataloader: List): - """ - Catch the output dynamic ranges of each layer that will be smoothed by running - forward passes with calibration_dataloader - """ - class_name = self.__class__.__name__.replace("PyTorch", "") - logger.info( - f"Running {class_name} calibration with " - f"{len(calibration_dataloader)} samples..." - ) - if not calibration_dataloader: - raise ValueError( - "Calibration data loader not set, must populate the calib_data field of" - " CompressionSession to run the SmoothQuant modifier" - ) - - with calibration_forward_context(model): - run_calibration_forward( - model, - calibration_dataloader, - self.num_calibration_steps, - self.calibration_function, - ) - - # remove the hooks now that we are done calibrating - self.remove_hooks() - @torch.no_grad() def _apply_smoothing(self, model: Module): """ @@ -276,8 +265,11 @@ def _apply_smoothing(self, model: Module): This modifies the weights of the model in-place. 
""" - logger.info("Smoothing activation scales...") for mapping in self.resolved_mappings_: + if mapping.smooth_name not in self.scales_: + continue + logger.info(f"Smoothing with {mapping.smooth_name}") + activation_scales = ( # get dynamic range for each activation channel self.scales_[mapping.smooth_name].max_channel_vals - self.scales_[mapping.smooth_name].min_channel_vals @@ -312,6 +304,9 @@ def smooth(module): smooth(layer) smooth(smooth_layer) + # clear calibration data + del self.scales_[mapping.smooth_name] + def _calculate_smoothing_scales( self, balance_layers: List[Module], activation_scales: torch.Tensor ) -> List[float]: diff --git a/src/llmcompressor/modifiers/stage.py b/src/llmcompressor/modifiers/stage.py index 75a11ffc50..ae27fdbfe8 100644 --- a/src/llmcompressor/modifiers/stage.py +++ b/src/llmcompressor/modifiers/stage.py @@ -80,7 +80,8 @@ def initialize(self, state: "State", **kwargs): accelerator = kwargs.get("accelerator", None) for modifier in self.modifiers: - modifier.initialize(state, **kwargs) + if not modifier.initialized: + modifier.initialize(state, **kwargs) if accelerator: accelerator.wait_for_everyone() state.loggers.system.info(tag="stage", string="Modifiers initialized") diff --git a/src/llmcompressor/pipelines/__init__.py b/src/llmcompressor/pipelines/__init__.py index e69de29bb2..411280a998 100644 --- a/src/llmcompressor/pipelines/__init__.py +++ b/src/llmcompressor/pipelines/__init__.py @@ -0,0 +1,8 @@ +# flake8: noqa +# populate registry +from .basic import * +from .data_free import * +from .independent import * +from .layer_sequential import * +from .registry import * +from .sequential import * diff --git a/src/llmcompressor/pipelines/basic/__init__.py b/src/llmcompressor/pipelines/basic/__init__.py index fc60475ca8..7c726f6c40 100644 --- a/src/llmcompressor/pipelines/basic/__init__.py +++ b/src/llmcompressor/pipelines/basic/__init__.py @@ -1,2 +1,2 @@ # flake8: noqa -from .pipeline import run_pipeline +from .pipeline import 
* diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index 13a1c9454c..15b94786a1 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -1,45 +1,55 @@ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Union import torch -import torch.utils.data.dataloader import tqdm from compressed_tensors.utils import get_execution_device +from torch.utils.data.dataloader import DataLoader +from llmcompressor.core import LifecycleCallbacks from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch +from llmcompressor.pipelines.registry import CalibrationPipeline from llmcompressor.pytorch.utils.helpers import tensors_to_device from llmcompressor.utils.helpers import calibration_forward_context if TYPE_CHECKING: - from llmcompressor.modifiers import Modifier - -__all__ = ["run_pipeline"] - - -def run_pipeline( - model: torch.nn.Module, - dataloader: torch.utils.data.DataLoader, - callback_modifier: Optional["Modifier"] = None, -): - """ - Run a basic data pipeline. - - Batches are fetched from the data loader and are used to perform forward passes - through the model. 
This pipeline is typically used for basic model calibration - and, unlike the sequential pipelines, does not propagate compression error when - used to calibrate model compression - - :param model: model being calibrated - :param dataloader: loads data for calibration - :param callback_modifier: Temporary HACK which should be replaced by event callback - """ - model_device = get_execution_device(model) - - with calibration_forward_context(model): - for batch in tqdm.tqdm(dataloader, desc="Calibrating"): - batch = apply_pad_mask_to_batch(batch) - batch = tensors_to_device(batch, model_device) - model(**batch) - - # TODO: replace with a lifecycle event - if callback_modifier: - callback_modifier.on_sequential_batch_end() + from llmcompressor.args.dataset_arguments import DatasetArguments + +__all__ = ["BasicPipeline", "run_calibration"] + + +@CalibrationPipeline.register("basic") +class BasicPipeline(CalibrationPipeline): + @staticmethod + def __call__( + model: torch.nn.Module, + dataloader: DataLoader, + dataset_args: Union["DatasetArguments", None], + ): + """ + Run a basic data pipeline. + + Batches are fetched from the data loader and are used to perform forward passes + through the model. 
This pipeline is typically used for basic model calibration + and, unlike the sequential pipelines, does not propagate compression error when + used to calibrate model compression + + :param model: model being calibrated + :param dataloader: loads data for calibration + :param dataset_args: dataset arguments relevant to pipelines + """ + model_device = get_execution_device(model) + + LifecycleCallbacks.calibration_epoch_start() + + with calibration_forward_context(model): + for batch in tqdm.tqdm(dataloader, desc="Calibrating"): + batch = apply_pad_mask_to_batch(batch) + batch = tensors_to_device(batch, model_device) + model(**batch) + + LifecycleCallbacks.calibration_epoch_end() + + +def run_calibration(model: torch.nn.Module, dataloader: DataLoader): + pipeline = BasicPipeline() + pipeline(model, dataloader, None) diff --git a/src/llmcompressor/pipelines/data_free/__init__.py b/src/llmcompressor/pipelines/data_free/__init__.py new file mode 100644 index 0000000000..7c726f6c40 --- /dev/null +++ b/src/llmcompressor/pipelines/data_free/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline import * diff --git a/src/llmcompressor/pipelines/data_free/pipeline.py b/src/llmcompressor/pipelines/data_free/pipeline.py new file mode 100644 index 0000000000..587f7ca693 --- /dev/null +++ b/src/llmcompressor/pipelines/data_free/pipeline.py @@ -0,0 +1,31 @@ +from typing import TYPE_CHECKING, Optional + +import torch +from torch.utils.data.dataloader import DataLoader + +from llmcompressor.core.session_functions import LifecycleCallbacks +from llmcompressor.pipelines.registry import CalibrationPipeline + +if TYPE_CHECKING: + from llmcompressor.args.dataset_arguments import DatasetArguments + +__all__ = ["DataFreePipeline"] + + +@CalibrationPipeline.register("datafree") +class DataFreePipeline(CalibrationPipeline): + @staticmethod + def __call__( + model: torch.nn.Module, + dataloader: Optional[DataLoader], + dataset_args: "DatasetArguments", + ): + """ + A pipeline for 
data-free calibration + + :param model: model being calibrated + :param dataloader: loads data for calibration + :param dataset_args: dataset arguments relevant to pipelines + """ + LifecycleCallbacks.calibration_epoch_start() + LifecycleCallbacks.calibration_epoch_end() diff --git a/src/llmcompressor/pipelines/independent/__init__.py b/src/llmcompressor/pipelines/independent/__init__.py new file mode 100644 index 0000000000..7c726f6c40 --- /dev/null +++ b/src/llmcompressor/pipelines/independent/__init__.py @@ -0,0 +1,2 @@ +# flake8: noqa +from .pipeline import * diff --git a/src/llmcompressor/pipelines/independent/pipeline.py b/src/llmcompressor/pipelines/independent/pipeline.py new file mode 100644 index 0000000000..797cf27998 --- /dev/null +++ b/src/llmcompressor/pipelines/independent/pipeline.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +import torch +from loguru import logger +from torch.utils.data.dataloader import DataLoader + +from llmcompressor.core import active_session +from llmcompressor.modifiers.stage import StageModifiers +from llmcompressor.pipelines.registry import CalibrationPipeline +from llmcompressor.utils.helpers import patch_attr + +if TYPE_CHECKING: + from llmcompressor.args.dataset_arguments import DatasetArguments + +__all__ = ["IndependentPipeline"] + + +@CalibrationPipeline.register("independent") +class IndependentPipeline(CalibrationPipeline): + @staticmethod + def __call__( + model: torch.nn.Module, + dataloader: DataLoader, + dataset_args: "DatasetArguments", + ): + """ + Data pipeline where each modifier is assigned its own calibration epoch and data + pipeline + + :param model: model being calibrated + :param dataloader: loads data for calibration + :param dataset_args: dataset arguments relevant to pipelines + """ + _logger = logger.patch(lambda r: r.update(function="IndependentPipeline")) + + session = active_session() + modifiers = session.get_modifiers() + with patch_attr(session.lifecycle, "modifiers", None): + for 
index, modifier in enumerate(modifiers): + mod_type = str(type(modifier).__name__) + session.lifecycle.modifiers = [ + StageModifiers(modifiers=[modifier], group=mod_type, index=index) + ] + + pipeline = CalibrationPipeline.from_modifiers([modifier]) + pipeline_name = pipeline.__class__.__name__ + _logger.info(f"Inferred `{pipeline_name}` for `{mod_type}`") + + pipeline(model, dataloader, dataset_args) + + # restore modifiers on exit so model can be compressed based on recipe diff --git a/src/llmcompressor/pipelines/layer_sequential/__init__.py b/src/llmcompressor/pipelines/layer_sequential/__init__.py index fc60475ca8..7c726f6c40 100644 --- a/src/llmcompressor/pipelines/layer_sequential/__init__.py +++ b/src/llmcompressor/pipelines/layer_sequential/__init__.py @@ -1,2 +1,2 @@ # flake8: noqa -from .pipeline import run_pipeline +from .pipeline import * diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index 6f7ec81b66..79e37ff6eb 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -1,9 +1,10 @@ -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING import torch -import torch.utils.data.dataloader import tqdm +from torch.utils.data.dataloader import DataLoader +from llmcompressor.core import LifecycleCallbacks, active_session from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.cache import IntermediatesCache from llmcompressor.pipelines.layer_sequential.helpers import ( @@ -12,77 +13,90 @@ maybe_inject_pos_embeddings, to_next_layer_kwargs, ) +from llmcompressor.pipelines.registry import CalibrationPipeline +from llmcompressor.pipelines.sequential.helpers import get_targets_from_modifiers from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context if TYPE_CHECKING: - from llmcompressor.modifiers import Modifier - 
-__all__ = ["run_pipeline"] - - -def run_pipeline( - model: torch.nn.Module, - dataloader: torch.utils.data.DataLoader, - sequential_targets: List[str], - callback_modifier: Optional["Modifier"] = None, -): - """ - Run a layer-wise sequential data pipeline according to the following steps: - - 1. Layers are identified according to `sequential_targets` - 2. A hook is attached to the first layer. This hook raises an exception which is - then caught and used to capture the input arguments to the first layer - 3. The inputs to the first layer are used to calibrate the first layer, and the - output of the previous layer is used as inputs to calibrate the next layer - - This pipeline requires that the model have distinct layers defined in its - architecture and that the outputs of the previous layer are exactly the inputs - to the next layer. This is violated by encoder-decoder architectures among others. - - If your model architecture violates these assumptions, consider using the sequential - pipeline (see llmcompressor.pipelines.sequential). 
Architectures which are known to - fail these assumptions include GPT-J and most vision language models - - :param model: model being calibrated - :param dataloader: loads data for calibration - :param sequential_targets: patterns which match to the layer modules of the model - :param callback_modifier: Temporary HACK which should be replaced by event callback - """ - # find layers - layers = match_modules(model, sequential_targets) - - with calibration_forward_context(model), DisableQuantization(model): - # prepare intermediates cache - intermediates: IntermediatesCache = capture_first_layer_intermediates( - model, layers[0], dataloader - ) - - num_layers = len(layers) - for layer_index, layer in enumerate(layers): - # prepare tqdm description texts - calib_desc = f"({layer_index + 1}/{num_layers}): Calibrating" - prop_desc = f"({layer_index + 1}/{num_layers}): Propagating" - - # do an preliminary pass to trigger modifier hooks - for batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc): - inputs = intermediates.fetch(batch_index) - layer(**inputs) - - # TODO: replace with a lifecycle event - if callback_modifier: - callback_modifier.on_sequential_batch_end() - - # this pass does not trigger modifier hooks - # and is only used for capturing outputs from the newly compressed modules - with HooksMixin.disable_hooks(): - for batch_index in tqdm.tqdm(range(len(dataloader)), desc=prop_desc): - inputs = intermediates.fetch(batch_index) - output = layer(**inputs) - - if layer_index < num_layers - 1: - next_layer = layers[layer_index + 1] - output = to_next_layer_kwargs(output, next_layer) - output = maybe_inject_pos_embeddings(output, next_layer, inputs) - - intermediates.delete(batch_index) - intermediates.update(batch_index, output) + from llmcompressor.args.dataset_arguments import DatasetArguments + + +__all__ = ["LayerSequentialPipeline"] + + +@CalibrationPipeline.register("layer_sequential") +class LayerSequentialPipeline(CalibrationPipeline): + 
@staticmethod + def __call__( + model: torch.nn.Module, dataloader: DataLoader, dataset_args: "DatasetArguments" + ): + """ + Run a layer-wise sequential data pipeline according to the following steps: + + 1. Layers are identified according to `sequential_targets` + 2. A hook is attached to the first layer. This hook raises an exception which is + then caught and used to capture the input arguments to the first layer + 3. The inputs to the first layer are used to calibrate the first layer, and the + output of the previous layer is used as inputs to calibrate the next layer + + This pipeline requires that the model have distinct layers defined in its + architecture and that the outputs of the previous layer are exactly the inputs + to the next layer. This is violated by encoder-decoder architectures, among + others. + + If your model architecture violates these assumptions, consider using the + sequential pipeline (see llmcompressor.pipelines.sequential). Architectures + which are known to fail these assumptions include GPT-J and most vision models + + :param model: model being calibrated + :param dataloader: loads data for calibration + :param dataset_args: dataset arguments relevant to pipelines + """ + session = active_session() + + # find layers + modifiers = session.get_modifiers() + sequential_targets, _ = get_targets_from_modifiers(modifiers, model) + layers = match_modules(model, sequential_targets) + + LifecycleCallbacks.calibration_epoch_start() + + with calibration_forward_context(model), DisableQuantization(model): + # prepare intermediates cache + intermediates: IntermediatesCache = capture_first_layer_intermediates( + model, layers[0], dataloader + ) + + num_layers = len(layers) + for layer_index, layer in enumerate(layers): + # prepare tqdm description texts + calib_desc = f"({layer_index + 1}/{num_layers}): Calibrating" + prop_desc = f"({layer_index + 1}/{num_layers}): Propagating" + + # do a preliminary pass to trigger modifier hooks + for
batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc): + inputs = intermediates.fetch(batch_idx) + layer(**inputs) + + # trigger compression + LifecycleCallbacks.sequential_epoch_end() + + # this pass does not trigger modifier hooks + # and is only used for capturing outputs from newly compressed modules + with HooksMixin.disable_hooks(): + for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=prop_desc): + inputs = intermediates.fetch(batch_idx) + output = layer(**inputs) + + if layer_index < num_layers - 1: + next_layer = layers[layer_index + 1] + output = to_next_layer_kwargs(output, next_layer) + output = maybe_inject_pos_embeddings( + output, next_layer, inputs + ) + + intermediates.delete(batch_idx) + intermediates.update(batch_idx, output) + + # redundant, finish any remaining compression + LifecycleCallbacks.calibration_epoch_end() diff --git a/src/llmcompressor/pipelines/registry.py b/src/llmcompressor/pipelines/registry.py new file mode 100644 index 0000000000..3501df9751 --- /dev/null +++ b/src/llmcompressor/pipelines/registry.py @@ -0,0 +1,102 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Optional + +import torch +from compressed_tensors.registry import RegistryMixin, standardize_lookup_name +from loguru import logger +from torch.utils.data.dataloader import DataLoader + +from llmcompressor.modifiers import Modifier +from llmcompressor.modifiers.awq import AWQModifier +from llmcompressor.modifiers.obcq.sgpt_mixin import SparsityModifierMixin +from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationMixin +from llmcompressor.modifiers.smoothquant import SmoothQuantModifier + +if TYPE_CHECKING: + from llmcompressor.args.dataset_arguments import DatasetArguments + +__all__ = ["CalibrationPipeline"] + +SEQUENTIAL_MODIFIERS = (GPTQModifier, SparsityModifierMixin) + + +class CalibrationPipeline(ABC, RegistryMixin): + @staticmethod + @abstractmethod + def __call__( + model: torch.nn.Module, +
dataloader: DataLoader, + dataset_args: "DatasetArguments", + ): + raise NotImplementedError() + + @classmethod + def from_modifiers( + cls, modifiers: List[Modifier], user: Optional[str] = None + ) -> "CalibrationPipeline": + """ + Infer which calibration pipeline to use based on the available modifiers and + any user specifications + + :param modifiers: modifiers to apply to model + :param user: pipeline name passed by user + :return: CalibrationPipeline instance to be called with data (if not datafree) + """ + user = standardize_lookup_name(user) if user else None + inferred = standardize_lookup_name(cls._validate_infer_pipeline(modifiers)) + independent = standardize_lookup_name("independent") + + if user == independent: + inferred = independent + + if user is not None and user != inferred: + logger.warning( + f"Calibration pipeline is set to `{user}`, but it is recommended to " + f"use `{inferred}`" + ) + + pipeline = user or inferred + return cls.load_from_registry(pipeline) + + @staticmethod + def _validate_infer_pipeline(modifiers: List[Modifier]) -> str: + if any(isinstance(modifier, AWQModifier) for modifier in modifiers): + if len(modifiers) > 1: + logger.warning( + "AWQ does not currently support sharing a data pipeline with other " + "modifiers. Inferring `independent` calibration pipeline" + ) + return "independent" + return "datafree" + + if any(isinstance(modifier, SEQUENTIAL_MODIFIERS) for modifier in modifiers): + return "sequential" + + active_qmods = _get_active_quant_modifiers(modifiers) + if len(active_qmods) > 1: + raise ValueError( + f"Recipe contains more than one active quantization config " + f"({active_qmods}). 
These configs may be conflicting, Please modify " + "your recipe to use at most one quantization config" + ) + + if len(active_qmods) == 1: + quant_modifier = active_qmods[0] + config = quant_modifier.resolve_quantization_config() + if config.requires_calibration_data(): + return "basic" + else: + return "datafree" + + if any(isinstance(modifier, SmoothQuantModifier) for modifier in modifiers): + return "basic" + + return "datafree" + + +def _get_active_quant_modifiers(modifiers: List[Modifier]) -> List[QuantizationMixin]: + return [ + modifier + for modifier in modifiers + if isinstance(modifier, QuantizationMixin) and modifier.has_config() + ] diff --git a/src/llmcompressor/pipelines/sequential/__init__.py b/src/llmcompressor/pipelines/sequential/__init__.py index fc60475ca8..d96ee6987c 100644 --- a/src/llmcompressor/pipelines/sequential/__init__.py +++ b/src/llmcompressor/pipelines/sequential/__init__.py @@ -1,2 +1,3 @@ # flake8: noqa -from .pipeline import run_pipeline +from .helpers import get_targets_from_modifiers +from .pipeline import * diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index d96352a66c..a1a716e856 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -1,7 +1,7 @@ import inspect from collections import deque from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from compressed_tensors import has_offloaded_params from compressed_tensors.quantization import find_name_or_class_matches @@ -14,10 +14,12 @@ from transformers.configuration_utils import PretrainedConfig from transformers.utils.fx import HFTracer +from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.utils.helpers import calibration_forward_context, patch_attr 
+from llmcompressor.utils.pytorch.module import get_no_split_params -__all__ = ["trace_subgraphs", "Subgraph"] +__all__ = ["trace_subgraphs", "Subgraph", "get_targets_from_modifiers"] @dataclass @@ -403,6 +405,46 @@ def match_modules(model: Module, target_names: List[str]) -> Set[Module]: ) +def get_targets_from_modifiers( + modifiers: List[Modifier], model: PreTrainedModel +) -> Tuple[List[str], List[str]]: + """ + Infer sequential targets and ignore list from modifiers list + + :param model: model being calibrated + :param modifiers: list of modifiers being applied during calibration + :return: list of sequential targets and list of modules to ignore for tracing + """ + # avoid circular import + from llmcompressor.pipelines.registry import SEQUENTIAL_MODIFIERS + + sequential_modifiers = [ + modifier for modifier in modifiers if isinstance(modifier, SEQUENTIAL_MODIFIERS) + ] + + if len(sequential_modifiers) >= 2: + types = [type(modifier) for modifier in sequential_modifiers] + logger.warning( + "Cannot infer sequential targets from multiple sequential modifiers " + f"({types}). 
Defaulting to {types[0]}" + ) + elif len(sequential_modifiers) <= 0: + types = [type(modifier) for modifier in modifiers] + raise ValueError(f"Cannot infer sequential targets from list of {types}") + + modifier = sequential_modifiers[0] + + # infer sequential targets + if modifier.sequential_targets is None: + sequential_targets = get_no_split_params(model) + elif isinstance(modifier.sequential_targets, str): + sequential_targets = [modifier.sequential_targets] + else: + sequential_targets = modifier.sequential_targets + + return sequential_targets, modifier.ignore + + def add_line_numbers(text: str) -> str: lines = text.splitlines() numbered_lines = [f"{i + 1} {line}" for i, line in enumerate(lines)] diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index c4ad9035df..80eb3739d6 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -1,83 +1,95 @@ -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING import torch -import torch.utils.data.dataloader import tqdm from compressed_tensors.utils import get_execution_device +from torch.utils.data.dataloader import DataLoader +from llmcompressor.core import LifecycleCallbacks, active_session from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.cache import IntermediatesCache -from llmcompressor.pipelines.sequential.helpers import trace_subgraphs +from llmcompressor.pipelines.registry import CalibrationPipeline +from llmcompressor.pipelines.sequential.helpers import ( + get_targets_from_modifiers, + trace_subgraphs, +) from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context if TYPE_CHECKING: - from llmcompressor.modifiers import Modifier - -__all__ = ["run_pipeline"] - - -def run_pipeline( - model: torch.nn.Module, - dataloader: torch.utils.data.DataLoader, - sequential_targets: List[str], - 
ignore: List[str], - callback_modifier: Optional["Modifier"] = None, -): - """ - Run a sequential data pipeline according to the following steps: - - 1. The model is partitioned into subgraphs according to `sequential_targets` - 2. Data passes through each subgraph sequentially. Data is passed through each - subgraph twice, once to trigger calibration hooks, then a second time in order - to capture activations after quantization has occurred through the hooks. - 3. The intermediate activations between each subgraph are cached and offloaded to - the cpu between each batch in order to save memory - - This pipeline requires that the model be traceable with respect to data from the - data loader. This may be an issue for vision language models with vision datasets, - due to specialized input processing in the model. - - In the event that tracing fails, a torch.fx.proxy.TraceError will be raised. A model - can be made traceable by wrapping the untraceable functions (see - llmcompressor.transformers.tracing) - - :param model: model being calibrated - :param dataloader: loads data for calibration - :param sequential_targets: patterns which match to the layer modules of the model - :param ignore: TODO: unused, in the future will specify functions and methods to - skip during tracing - """ - # trace subgraphs - sample_input = next(iter(dataloader)) - subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore) - - with calibration_forward_context(model), DisableQuantization(model): - # prepare intermediates cache - model_device = get_execution_device(model) - intermediates = IntermediatesCache.from_dataloader(dataloader, model_device) - - num_subgraphs = len(subgraphs) - for subgraph_index, subgraph in enumerate(subgraphs): - # prepare tqdm description texts - calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating" - prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating" - - # do an preliminary pass to trigger modifier hooks - for 
batch_index in tqdm.tqdm(range(len(dataloader)), desc=calib_desc): - inputs = intermediates.fetch(batch_index, subgraph.input_names) - subgraph.forward(model, **inputs) - - # TODO: replace with a lifecycle event - if callback_modifier: - callback_modifier.on_sequential_batch_end() - - # this pass does not trigger modifier hooks - # and is only used for capturing outputs from the newly compressed modules - with HooksMixin.disable_hooks(): - for batch_index in tqdm.tqdm(range(len(dataloader)), desc=prop_desc): - inputs = intermediates.fetch(batch_index, subgraph.input_names) - output = subgraph.forward(model, **inputs) - - if subgraph_index < num_subgraphs - 1: - intermediates.update(batch_index, output) - intermediates.delete(batch_index, subgraph.consumed_names) + from llmcompressor.args.dataset_arguments import DatasetArguments + +__all__ = ["SequentialPipeline"] + + +@CalibrationPipeline.register("sequential") +class SequentialPipeline(CalibrationPipeline): + @staticmethod + def __call__( + model: torch.nn.Module, dataloader: DataLoader, dataset_args: "DatasetArguments" + ): + """ + Run a sequential data pipeline according to the following steps: + + 1. The model is partitioned into subgraphs according to `sequential_targets` + 2. Data passes through each subgraph sequentially. Data is passed through each + subgraph twice, once to trigger calibration hooks, then a second time in + order to capture activations after quantization has occurred through hooks. + 3. The intermediate activations between each subgraph are cached and offloaded + to the cpu between each batch in order to save memory + + This pipeline requires that the model be traceable with respect to data from the + data loader. This may be an issue for vision models with vision datasets, due + to specialized input processing in the model. + + In the event that tracing fails, a torch.fx.proxy.TraceError will be raised. 
A + model can be made traceable by wrapping the untraceable functions (see + llmcompressor.transformers.tracing) + + :param model: model being calibrated + :param dataloader: loads data for calibration + :param dataset_args: dataset arguments relevant to pipelines + """ + session = active_session() + + # infer sequential targets + modifiers = session.get_modifiers() + sequential_targets, ignore = get_targets_from_modifiers(modifiers, model) + + # trace subgraphs + sample_input = next(iter(dataloader)) + subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore) + + LifecycleCallbacks.calibration_epoch_start() + + with calibration_forward_context(model), DisableQuantization(model): + # prepare intermediates cache + model_device = get_execution_device(model) + intermediates = IntermediatesCache.from_dataloader(dataloader, model_device) + + num_subgraphs = len(subgraphs) + for subgraph_index, subgraph in enumerate(subgraphs): + # prepare tqdm description texts + calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating" + prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating" + + # do a preliminary pass to trigger modifier hooks + for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc): + inputs = intermediates.fetch(batch_idx, subgraph.input_names) + subgraph.forward(model, **inputs) + + # trigger compression + LifecycleCallbacks.sequential_epoch_end() + + # this pass does not trigger modifier hooks + # and is only used for capturing outputs from newly compressed modules + with HooksMixin.disable_hooks(): + for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=prop_desc): + inputs = intermediates.fetch(batch_idx, subgraph.input_names) + output = subgraph.forward(model, **inputs) + + if subgraph_index < num_subgraphs - 1: + intermediates.update(batch_idx, output) + intermediates.delete(batch_idx, subgraph.consumed_names) + + # redundant, finish any remaining compression + LifecycleCallbacks.calibration_epoch_end()
diff --git a/tests/e2e/vLLM/recipes/kv_cache/gptq.yaml b/tests/e2e/vLLM/recipes/kv_cache/gptq.yaml index 8c76de33ac..1bb37fd578 100644 --- a/tests/e2e/vLLM/recipes/kv_cache/gptq.yaml +++ b/tests/e2e/vLLM/recipes/kv_cache/gptq.yaml @@ -1,8 +1,5 @@ quant_stage: quant_modifiers: - QuantizationModifier: - kv_cache_scheme: - {num_bits: 8, type: float, symmetric: true, strategy: tensor} GPTQModifier: sequential_update: false ignore: ["lm_head"] @@ -15,3 +12,5 @@ quant_stage: strategy: "channel" actorder: False targets: ["Linear"] + kv_cache_scheme: + {num_bits: 8, type: float, symmetric: true, strategy: tensor} \ No newline at end of file diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index 14fd7dcb89..0e180e8f00 100644 --- a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -51,7 +51,8 @@ def test_successful_layerwise_recipe(self): sparsity=sparsities, block_size=128, targets=targets ) testing_harness = LifecyleTestingHarness(model=LinearNet(), start=-1) - modifier.initialize(testing_harness.get_state()) # falls back to basic pipeline + modifier.initialize(testing_harness.get_state()) + modifier.on_start(testing_harness.get_state(), None) model = testing_harness.state.model num_hooks = len(modifier._hooks) @@ -69,6 +70,7 @@ def test_create_default_quant_modifier(self): testing_harness = LifecyleTestingHarness(model=LinearNet(), start=-1) modifier.initialize(testing_harness.get_state()) + modifier.on_start(testing_harness.get_state(), None) model = testing_harness.state.model for module in model.modules(): diff --git a/tests/llmcompressor/transformers/kv_cache/test_kv_cache.py b/tests/llmcompressor/transformers/kv_cache/test_kv_cache.py index 1e72eb7183..5110f07629 100644 --- a/tests/llmcompressor/transformers/kv_cache/test_kv_cache.py +++ 
b/tests/llmcompressor/transformers/kv_cache/test_kv_cache.py @@ -7,6 +7,7 @@ from compressed_tensors.quantization.utils.helpers import iter_named_quantizable_modules from datasets import load_dataset from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from transformers.utils.quantization_config import CompressedTensorsConfig from llmcompressor import oneshot from llmcompressor.core import reset_session @@ -214,13 +215,6 @@ def test_kv_cache_gptq_model_state_dict_attr(kv_cache_fixture, tmp_path): recipe = """ quant_stage: quant_modifiers: - QuantizationModifier: - kv_cache_scheme: - num_bits: {num_bits} - type: {_type} - strategy: {strategy} - dynamic: {dynamic} - symmetric: {symmetric} GPTQModifier: ignore: ["lm_head"] config_groups: @@ -232,12 +226,24 @@ def test_kv_cache_gptq_model_state_dict_attr(kv_cache_fixture, tmp_path): strategy: "channel" actorder: False targets: ["Linear"] + kv_cache_scheme: + num_bits: {num_bits} + type: {_type} + strategy: {strategy} + dynamic: {dynamic} + symmetric: {symmetric} """ output_dir, _ = next(kv_cache_fixture(recipe, tmp_path)) with init_empty_weights(): - model = AutoModelForCausalLM.from_pretrained(output_dir) + # TODO: There is a bug in `apply_quantization_config` which means that, if using + # CompressedLinears, the compression status is inferred to `compressed` and + # therefore the attention kvcache parameters never undergo initializations + model = AutoModelForCausalLM.from_pretrained( + output_dir, + quantization_config=CompressedTensorsConfig(run_compressed=False), + ) counts = 0 for name, submodule in iter_named_quantizable_modules( diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py b/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py index f87cae28d2..130e811ddd 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_lm_head.py @@ -44,7 +44,8 @@ def test_no_lm_head_target(self): state = State() 
state.update(model=self.model, device=self.device, calib_data=self.dataloader) - modifier.on_initialize(state) + modifier.initialize(state) + modifier.on_start(state, None) assert len(self.model.lm_head._forward_hooks) <= 0 @@ -56,7 +57,8 @@ def test_lm_head_target(self): state = State() state.update(model=self.model, device=self.device, calib_data=self.dataloader) - modifier.on_initialize(state) + modifier.initialize(state) + modifier.on_start(state, None) assert len(self.model.lm_head._forward_hooks) == 1