Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-check-transformers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,4 @@ jobs:
- name: Running KV Cache Tests
if: (success() || failure()) && steps.install.outcome == 'success'
run: |
pytest -v tests/llmcompressor/transformers/kv_cache -k "not test_kv_cache_gptq_model_state_dict_attr"
pytest -v tests/llmcompressor/transformers/kv_cache
8 changes: 8 additions & 0 deletions src/llmcompressor/args/dataset_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,11 @@ class DatasetArguments(CustomDatasetArguments):
"will execute code present on the Hub on your local machine."
},
)
pipeline: Optional[str] = field(
default="independent",
metadata={
        "help": "Calibration pipeline used to calibrate model. "
        "Options: ['basic', 'datafree', 'sequential', 'layer_sequential', "
        "'independent']"
},
)
9 changes: 9 additions & 0 deletions src/llmcompressor/core/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ class EventType(Enum):
:param BATCH_START: Event type for the start of a batch.
:param LOSS_CALCULATED: Event type for when loss is calculated.
:param BATCH_END: Event type for the end of a batch.
:param CALIBRATION_EPOCH_START: Event type for the start of a calibration epoch.
:param SEQUENTIAL_EPOCH_END: Event type for the end of a layer calibration epoch,
specifically used by `src/llmcompressor/pipelines/sequential/pipeline.py`
:param CALIBRATION_EPOCH_END: Event type for the end of a calibration epoch.
:param OPTIM_PRE_STEP: Event type for pre-optimization step.
:param OPTIM_POST_STEP: Event type for post-optimization step.
"""
Expand All @@ -45,6 +49,11 @@ class EventType(Enum):
LOSS_CALCULATED = "loss_calculated"
BATCH_END = "batch_end"

# calibration lifecycle
CALIBRATION_EPOCH_START = "calibration_epoch_start"
SEQUENTIAL_EPOCH_END = "sequential_epoch_end"
CALIBRATION_EPOCH_END = "calibration_epoch_end"

# step lifecycle
OPTIM_PRE_STEP = "optim_pre_step"
OPTIM_POST_STEP = "optim_post_step"
Expand Down
11 changes: 11 additions & 0 deletions src/llmcompressor/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,17 @@ def get_serialized_recipe(self) -> Optional[str]:

logger.warning("Recipe not found in session - it may have been reset")

def get_modifiers(self):
"""
Get all modifiers across all stages
"""
stage_modifiers = self.lifecycle.modifiers
return [
modifier
for stage_modifier in stage_modifiers
for modifier in stage_modifier.modifiers
] # noqa: E127

def _log_model_info(self):
# Log model level logs if cadence reached
current_index = self._lifecycle.global_step
Expand Down
35 changes: 33 additions & 2 deletions src/llmcompressor/core/session_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import threading
from contextlib import contextmanager
from typing import Any, Optional
from typing import Any, Generator, Optional

from llmcompressor.core.events import EventType
from llmcompressor.core.session import CompressionSession
Expand All @@ -21,7 +21,7 @@


@contextmanager
def create_session() -> CompressionSession:
def create_session() -> Generator[CompressionSession, None, None]:
"""
Context manager to create and yield a new session for sparsification.
This will set the active session to the new session for the duration
Expand Down Expand Up @@ -136,5 +136,36 @@ def batch_end(cls, **kwargs) -> ModifiedState:
active_session()._log_model_info()
return cls.event(EventType.BATCH_END, **kwargs)

@classmethod
def calibration_epoch_start(cls, **kwargs) -> ModifiedState:
"""
        Invoke an epoch start event for the active session during calibration. This event
should be called before calibration starts for one epoch

see `src/llmcompressor/pipelines/basic/pipeline.py` for usage example
"""
return cls.event(EventType.CALIBRATION_EPOCH_START, **kwargs)

@classmethod
def sequential_epoch_end(cls, **kwargs) -> ModifiedState:
"""
Invoke a sequential epoch end event for the active session. This event should be
called after one sequential layer has been calibrated/trained for one epoch

This is called after a sequential layer has been calibrated with one batch, see
`src/llmcompressor/pipelines/sequential/pipeline.py` for usage example
"""
return cls.event(EventType.SEQUENTIAL_EPOCH_END, **kwargs)

@classmethod
def calibration_epoch_end(cls, **kwargs) -> ModifiedState:
"""
        Invoke an epoch end event for the active session during calibration. This event
should be called after the model has been calibrated for one epoch

see `src/llmcompressor/pipelines/basic/pipeline.py` for usage example
"""
return cls.event(EventType.CALIBRATION_EPOCH_END, **kwargs)


callbacks = LifecycleCallbacks
23 changes: 14 additions & 9 deletions src/llmcompressor/entrypoints/oneshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from llmcompressor.core.session_functions import active_session
from llmcompressor.datasets import get_calibration_dataloader
from llmcompressor.entrypoints.utils import post_process, pre_process
from llmcompressor.pipelines.registry import CalibrationPipeline

__all__ = ["Oneshot", "oneshot"]

Expand Down Expand Up @@ -157,21 +158,25 @@ def apply_recipe_modifiers(
"""

session = active_session()
session.reset()

session_kwargs = dict(
        # (Helen INFERENG-661): validate recipe modifiers before initialization
session.initialize(
model=self.model,
start=-1,
recipe=self.recipe,
recipe_args=self.recipe_args.recipe_args,
calib_data=calibration_dataloader,
start=-1, # oneshot-specific argument
copy_data=False,
min_tokens_per_module=getattr(self, "min_tokens_per_module", None),
recipe_stage=recipe_stage,
recipe_args=self.recipe_args.recipe_args,
calib_data=calibration_dataloader, # only used by AWQModifier, remove once
# AWQModifier supports calibration pipelines
)

session.reset()
session.initialize(**session_kwargs)
session.finalize(**session_kwargs)
user_pipeline = self.dataset_args.pipeline
modifiers = session.get_modifiers()
pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline)
pipeline(self.model, calibration_dataloader, self.dataset_args)

session.finalize()


def oneshot(**kwargs) -> PreTrainedModel:
Expand Down
6 changes: 3 additions & 3 deletions src/llmcompressor/modifiers/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def initialize(self, state: State, **kwargs):

self.initialized_ = self.on_initialize(state=state, **kwargs)

# trigger start
# trigger starts
fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
if self.should_start(fake_start_event):
self.on_start(state, fake_start_event, **kwargs)
Expand All @@ -103,8 +103,8 @@ def finalize(self, state: State, **kwargs):
:param state: The current state of the model
:param kwargs: Additional arguments for finalizing the modifier
"""
if self.finalized_ or not self.initialized_:
return
if self.finalized_:
raise RuntimeError("cannot finalize a modifier twice")

if not self.initialized_:
raise RuntimeError("cannot finalize an uninitialized modifier")
Expand Down
24 changes: 17 additions & 7 deletions src/llmcompressor/modifiers/obcq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ class SparseGPTModifier(SparsityModifierMixin, Modifier):
Lifecycle:
- on_initialize
- register_hook(module, calibrate_module, "forward")
- run_sequential / run_layer_sequential / run_basic
- make_empty_hessian
- accumulate_hessian
- on_sequential_batch_end
- sparsify_weight
- on_finalize
Expand Down Expand Up @@ -90,6 +87,14 @@ def calibrate_module(
args: Tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
"""
Calibration hook used to accumulate the hessian of the input to the module

:param module: module being calibrated
:param args: inputs to the module, the first element of which is the
            canonical input
:param _output: uncompressed module output, unused
"""
# Assume that the first argument is the input
inp = args[0]

Expand All @@ -108,10 +113,9 @@ def calibrate_module(
self._num_samples[module],
)

def on_sequential_batch_end(self):
def compress_modules(self):
"""
Sparsify modules
TODO: implement with event callback
Sparsify modules which have been calibrated
"""
for module in list(self._num_samples.keys()):
name = self._module_names[module]
Expand Down Expand Up @@ -152,7 +156,13 @@ def _maybe_onload_hessian(self, module: torch.nn.Module):
self._hessians[module] = self._hessians[module].to(device="cpu")

def on_finalize(self, state: State, **kwargs) -> bool:
self.remove_hooks()
# TODO: modify lifecycle to end on finalize
if not self.ended_:
self.on_end(state, None) # remove hooks

if len(self._num_samples) > 0:
raise ValueError(f"Failed to compress {len(self._num_samples)} modules")

self._hessians = dict()
self._num_samples = dict()
self._module_names = dict()
Expand Down
91 changes: 38 additions & 53 deletions src/llmcompressor/modifiers/obcq/sgpt_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,9 @@
from loguru import logger
from pydantic import Field, PrivateAttr, field_validator, model_validator

from llmcompressor.core import State
from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers.modifier import Modifier
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.basic import run_pipeline as run_basic
from llmcompressor.pipelines.layer_sequential import (
run_pipeline as run_layer_sequential,
)
from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
from llmcompressor.utils.pytorch.module import (
get_layers,
get_no_split_params,
Expand All @@ -24,7 +20,7 @@
)


class SparsityModifierMixin(HooksMixin):
class SparsityModifierMixin(Modifier):
# modifier arguments
sparsity: Optional[Union[float, List[float]]]
sparsity_profile: Optional[str] = None
Expand All @@ -42,6 +38,7 @@ class SparsityModifierMixin(HooksMixin):
_prune_n: Optional[int] = PrivateAttr(default=None)
_prune_m: Optional[int] = PrivateAttr(default=None)
_module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
_target_layers: Dict[str, torch.nn.Module] = PrivateAttr(default_factory=dict)
_module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)

@field_validator("sequential_update", mode="before")
Expand Down Expand Up @@ -97,6 +94,10 @@ def calibrate_module(
):
raise NotImplementedError()

@abstractmethod
def compress_modules(self):
raise NotImplementedError()

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
Initialize and run the OBCQ algorithm on the current state
Expand All @@ -109,7 +110,9 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
# infer module and sequential targets
self.sequential_targets = self._infer_sequential_targets(model)
layers = get_layers(self.sequential_targets, model)
target_layers = get_layers(self.targets, model) # layers containing targets
self._target_layers = get_layers(
self.targets, model
) # layers containing targets

# infer layer sparsities
if self.sparsity_profile == "owl":
Expand All @@ -120,16 +123,21 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
self.sparsity = self._infer_owl_layer_sparsity(model, layers, dataloader)

# get layers and validate sparsity
if isinstance(self.sparsity, (list, dict)) and len(target_layers) != len(
if isinstance(self.sparsity, (list, dict)) and len(self._target_layers) != len(
self.sparsity
):
raise ValueError(
f"{self.__repr_name__} was initialized with {len(self.sparsity)} "
f"sparsities values, but model has {len(layers)} target layers"
)

return True

def on_start(self, state: State, event: Event, **kwargs):
self.started_ = True

# register hooks
for index, (layer_name, layer) in enumerate(target_layers.items()):
for index, (layer_name, layer) in enumerate(self._target_layers.items()):
if isinstance(self.sparsity, dict):
layer_sparsity = self.sparsity[layer_name]
elif isinstance(self.sparsity, list):
Expand Down Expand Up @@ -160,48 +168,23 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
self._module_sparsities[module] = layer_sparsity
self.register_hook(module, self.calibrate_module, "forward")

# infer and run pipeline
model_name = state.model.__class__.__name__
input_names = dataloader.dataset.column_names
unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
try:
run_sequential(
state.model,
state.data.calib,
self.sequential_targets,
self.ignore,
self,
)
return True

except Exception as exception:
if isinstance(exception, torch.fx.proxy.TraceError):
warnings.warn(f"Failed to trace {model_name} with inputs {input_names}")
if isinstance(exception, unfixable_errors):
raise exception

warnings.warn("Falling back to layer_sequential pipeline")
try:
run_layer_sequential(
state.model,
state.data.calib,
self.sequential_targets,
self,
)
return True

except Exception as exception:
if isinstance(exception, TypeError):
warnings.warn(f"{model_name} fails layer-wise assumptions")
if isinstance(exception, unfixable_errors):
raise exception

warnings.warn(
"Falling back to basic pipeline, which requires extra memory and "
"may result in decreased accuracy"
)
run_basic(state.model, state.data.calib, self)
return True
def on_event(self, state: State, event: Event, **kwargs):
if event.type_ == EventType.CALIBRATION_EPOCH_START:
if not self.started_:
self.on_start(state, None)

if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
self.compress_modules()

if event.type_ == EventType.CALIBRATION_EPOCH_END:
self.compress_modules()

if not self.ended_:
self.on_end(state, None)

def on_end(self, state: State, event: Event, **kwargs):
self.ended_ = True
self.remove_hooks()

def _infer_sequential_targets(
self, model: torch.nn.Module
Expand Down Expand Up @@ -261,6 +244,8 @@ def _infer_owl_layer_sparsity(
return sparsities

def _get_activations(self, model, dataloader, nsamples=128) -> Dict[str, int]:
from llmcompressor.pipelines.basic import run_calibration

acts = defaultdict(int)

def save_acts(_module, input: Union[Tuple[Any, ...], torch.Tensor], name: str):
Expand All @@ -275,7 +260,7 @@ def save_acts(_module, input: Union[Tuple[Any, ...], torch.Tensor], name: str):
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name
)
with HooksMixin.disable_hooks(keep=hooks):
run_basic(model, dataloader)
run_calibration(model, dataloader)
self.remove_hooks(hooks)

return acts
Expand Down
Loading