Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-check-transformers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,4 @@ jobs:
- name: Running KV Cache Tests
if: (success() || failure()) && steps.install.outcome == 'success'
run: |
pytest -v tests/llmcompressor/transformers/kv_cache -k "not test_kv_cache_gptq_model_state_dict_attr"
pytest -v tests/llmcompressor/transformers/kv_cache
8 changes: 8 additions & 0 deletions src/llmcompressor/args/dataset_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,11 @@ class DatasetArguments(CustomDatasetArguments):
"will execute code present on the Hub on your local machine."
},
)
# Calibration pipeline selector; "independent" lets each modifier pick its own
# pipeline. Fixed: the help text previously concatenated without a separator
# ("model""Options") and left the 'independent' option unquoted.
pipeline: Optional[str] = field(
    default="independent",
    metadata={
        "help": "Calibration pipeline used to calibrate the model. "
        "Options: ['basic', 'datafree', 'sequential', 'layer_sequential', "
        "'independent']"
    },
)
8 changes: 8 additions & 0 deletions src/llmcompressor/core/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class EventType(Enum):
:param BATCH_START: Event type for the start of a batch.
:param LOSS_CALCULATED: Event type for when loss is calculated.
:param BATCH_END: Event type for the end of a batch.
:param CALIBRATION_EPOCH_START: Event type for the start of a calibration epoch.
:param SEQUENTIAL_EPOCH_END: Event type for the end of a layer calibration epoch.
:param CALIBRATION_EPOCH_END: Event type for the end of a calibration epoch.
:param OPTIM_PRE_STEP: Event type for pre-optimization step.
:param OPTIM_POST_STEP: Event type for post-optimization step.
"""
Expand All @@ -45,6 +48,11 @@ class EventType(Enum):
LOSS_CALCULATED = "loss_calculated"
BATCH_END = "batch_end"

# calibration lifecycle (TODO: support batched calibration lifecycles)
CALIBRATION_EPOCH_START = "calibration_epoch_start"
SEQUENTIAL_EPOCH_END = "sequential_epoch_end"
CALIBRATION_EPOCH_END = "calibration_epoch_end"

# step lifecycle
OPTIM_PRE_STEP = "optim_pre_step"
OPTIM_POST_STEP = "optim_post_step"
Expand Down
5 changes: 3 additions & 2 deletions src/llmcompressor/core/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ def initialize(
return

logger.debug("Initializing compression lifecycle")
self.recipe_container.append(recipe, recipe_stage, recipe_args)
self.modifiers = self.recipe_container.get_modifiers()
if not (recipe is recipe_stage is recipe_args is None):
self.recipe_container.append(recipe, recipe_stage, recipe_args)
self.modifiers = self.recipe_container.get_modifiers()
self._set_model_layer_prefix()

mod_data = []
Expand Down
8 changes: 8 additions & 0 deletions src/llmcompressor/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,14 @@ def get_serialized_recipe(self) -> Optional[str]:

logger.warning("Recipe not found in session - it may have been reset")

def get_modifiers(self):
    """
    Flatten the lifecycle's per-stage modifier containers into a single list.

    :return: every modifier from every stage of the lifecycle, in stage order
    """
    flattened = []
    for stage_modifier in self.lifecycle.modifiers:
        flattened.extend(stage_modifier.modifiers)
    return flattened

def _log_model_info(self):
# Log model level logs if cadence reached
current_index = self._lifecycle.global_step
Expand Down
35 changes: 33 additions & 2 deletions src/llmcompressor/core/session_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import threading
from contextlib import contextmanager
from typing import Any, Optional
from typing import Any, Generator, Optional

from llmcompressor.core.events import EventType
from llmcompressor.core.session import CompressionSession
Expand All @@ -21,7 +21,7 @@


@contextmanager
def create_session() -> CompressionSession:
def create_session() -> Generator[CompressionSession, None, None]:
"""
Context manager to create and yield a new session for sparsification.
This will set the active session to the new session for the duration
Expand Down Expand Up @@ -136,5 +136,36 @@ def batch_end(cls, **kwargs) -> ModifiedState:
active_session()._log_model_info()
return cls.event(EventType.BATCH_END, **kwargs)

@classmethod
def calibration_epoch_start(cls, **kwargs) -> ModifiedState:
    """
    Invoke an epoch start event for the active session during calibration.
    This event should be called before calibration of the model begins
    for one epoch.

    See `src/llmcompressor/pipelines/basic/pipeline.py` for a usage example.

    :param kwargs: additional arguments forwarded with the event
    :return: the modified state of the active session after the event
    """
    return cls.event(EventType.CALIBRATION_EPOCH_START, **kwargs)

@classmethod
def sequential_epoch_end(cls, **kwargs) -> ModifiedState:
    """
    Invoke a sequential epoch end event for the active session. This event
    should be called after one sequential layer has been calibrated/trained
    for one epoch, i.e. after that layer has finished processing its
    calibration batches.

    See `src/llmcompressor/pipelines/sequential/pipeline.py` for a usage
    example.

    :param kwargs: additional arguments forwarded with the event
    :return: the modified state of the active session after the event
    """
    return cls.event(EventType.SEQUENTIAL_EPOCH_END, **kwargs)

@classmethod
def calibration_epoch_end(cls, **kwargs) -> ModifiedState:
    """
    Invoke an epoch end event for the active session during calibration.
    This event should be called after the model has been calibrated for
    one epoch.

    See `src/llmcompressor/pipelines/basic/pipeline.py` for a usage example.

    :param kwargs: additional arguments forwarded with the event
    :return: the modified state of the active session after the event
    """
    return cls.event(EventType.CALIBRATION_EPOCH_END, **kwargs)


callbacks = LifecycleCallbacks
21 changes: 12 additions & 9 deletions src/llmcompressor/entrypoints/oneshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from llmcompressor.core.session_functions import active_session
from llmcompressor.datasets import get_calibration_dataloader
from llmcompressor.entrypoints.utils import post_process, pre_process
from llmcompressor.pipelines.registry import CalibrationPipeline

__all__ = ["Oneshot", "oneshot"]

Expand Down Expand Up @@ -157,21 +158,23 @@ def apply_recipe_modifiers(
"""

session = active_session()
session.reset()

session_kwargs = dict(
session.initialize(
model=self.model,
start=-1,
recipe=self.recipe,
recipe_args=self.recipe_args.recipe_args,
calib_data=calibration_dataloader,
start=-1, # oneshot-specific argument
copy_data=False,
min_tokens_per_module=getattr(self, "min_tokens_per_module", None),
recipe_stage=recipe_stage,
recipe_args=self.recipe_args.recipe_args,
calib_data=calibration_dataloader, # TODO: only used by AWQ modifier
)

session.reset()
session.initialize(**session_kwargs)
session.finalize(**session_kwargs)
user_pipeline = self.dataset_args.pipeline
modifiers = session.get_modifiers()
pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline)
pipeline(self.model, calibration_dataloader, self.dataset_args)

session.finalize()


def oneshot(**kwargs) -> PreTrainedModel:
Expand Down
6 changes: 3 additions & 3 deletions src/llmcompressor/modifiers/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def initialize(self, state: State, **kwargs):

self.initialized_ = self.on_initialize(state=state, **kwargs)

# trigger start
# trigger starts
fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
if self.should_start(fake_start_event):
self.on_start(state, fake_start_event, **kwargs)
Expand All @@ -103,8 +103,8 @@ def finalize(self, state: State, **kwargs):
:param state: The current state of the model
:param kwargs: Additional arguments for finalizing the modifier
"""
if self.finalized_ or not self.initialized_:
return
if self.finalized_:
raise RuntimeError("cannot finalize a modifier twice")

if not self.initialized_:
raise RuntimeError("cannot finalize an uninitialized modifier")
Expand Down
24 changes: 17 additions & 7 deletions src/llmcompressor/modifiers/obcq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ class SparseGPTModifier(SparsityModifierMixin, Modifier):
Lifecycle:
- on_initialize
- register_hook(module, calibrate_module, "forward")
- run_sequential / run_layer_sequential / run_basic
- make_empty_hessian
- accumulate_hessian
- on_sequential_batch_end
- sparsify_weight
- on_finalize
Expand Down Expand Up @@ -90,6 +87,14 @@ def calibrate_module(
args: Tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
"""
Calibration hook used to accumulate the hessian of the input to the module

:param module: module being calibrated
:param args: inputs to the module, the first element of which is the
canonical input
:param _output: uncompressed module output, unused
"""
# Assume that the first argument is the input
inp = args[0]

Expand All @@ -108,10 +113,9 @@ def calibrate_module(
self._num_samples[module],
)

def on_sequential_batch_end(self):
def compress_modules(self):
"""
Sparsify modules
TODO: implement with event callback
Sparsify modules which have been calibrated
"""
for module in list(self._num_samples.keys()):
name = self._module_names[module]
Expand Down Expand Up @@ -152,7 +156,13 @@ def _maybe_onload_hessian(self, module: torch.nn.Module):
self._hessians[module] = self._hessians[module].to(device="cpu")

def on_finalize(self, state: State, **kwargs) -> bool:
self.remove_hooks()
# TODO: modify lifecycle to end on finalize
if not self.ended_:
self.on_end(state, None) # remove hooks

if len(self._num_samples) > 0:
raise ValueError(f"Failed to compress {len(self._num_samples)} modules")

self._hessians = dict()
self._num_samples = dict()
self._module_names = dict()
Expand Down
Loading