Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-check-transformers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,4 @@ jobs:
- name: Running KV Cache Tests
if: (success() || failure()) && steps.install.outcome == 'success'
run: |
pytest -v tests/llmcompressor/transformers/kv_cache -k "not test_kv_cache_gptq_model_state_dict_attr"
pytest -v tests/llmcompressor/transformers/kv_cache
9 changes: 9 additions & 0 deletions src/llmcompressor/args/dataset_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from transformers import DefaultDataCollator

from llmcompressor.pipelines.registry import PIPELINES


@dataclass
class DVCDatasetArguments:
Expand Down Expand Up @@ -171,3 +173,10 @@ class DatasetArguments(CustomDatasetArguments):
"will execute code present on the Hub on your local machine."
},
)
pipeline: Optional[str] = field(
default="independent",
metadata={
"help": "Calibration pipeline used to calibrate model. "
f"Options: {PIPELINES.keys()}"
},
)
8 changes: 8 additions & 0 deletions src/llmcompressor/core/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class EventType(Enum):
:param BATCH_START: Event type for the start of a batch.
:param LOSS_CALCULATED: Event type for when loss is calculated.
:param BATCH_END: Event type for the end of a batch.
:param CALIBRATION_EPOCH_START: Event type for the start of a calibration epoch.
:param SEQUENTIAL_EPOCH_END: Event type for the end of a layer calibration epoch
:param CALIBRATION_EPOCH_END: Event type for the end of a calibration epoch.
:param OPTIM_PRE_STEP: Event type for pre-optimization step.
:param OPTIM_POST_STEP: Event type for post-optimization step.
"""
Expand All @@ -45,6 +48,11 @@ class EventType(Enum):
LOSS_CALCULATED = "loss_calculated"
BATCH_END = "batch_end"

# calibration lifecycle (TODO: support batched calibration lifecycles)
CALIBRATION_EPOCH_START = "calibration_epoch_start"
SEQUENTIAL_EPOCH_END = "sequential_epoch_end"
CALIBRATION_EPOCH_END = "calibration_epoch_end"

# step lifecycle
OPTIM_PRE_STEP = "optim_pre_step"
OPTIM_POST_STEP = "optim_post_step"
Expand Down
5 changes: 3 additions & 2 deletions src/llmcompressor/core/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ def initialize(
return

logger.debug("Initializing compression lifecycle")
self.recipe_container.append(recipe, recipe_stage, recipe_args)
self.modifiers = self.recipe_container.get_modifiers()
if not (recipe is recipe_stage is recipe_args is None):
self.recipe_container.append(recipe, recipe_stage, recipe_args)
self.modifiers = self.recipe_container.get_modifiers()
self._set_model_layer_prefix()

mod_data = []
Expand Down
8 changes: 8 additions & 0 deletions src/llmcompressor/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,14 @@ def get_serialized_recipe(self) -> Optional[str]:

logger.warning("Recipe not found in session - it may have been reset")

def get_modifiers(self) -> list:
    """
    Collect every modifier tracked by this session's compression lifecycle.

    The lifecycle stores modifiers grouped per stage; this flattens the
    per-stage groups into a single list, preserving stage order.

    :return: flat list of the modifiers held in each stage's ``modifiers``
        attribute
    """
    stage_modifiers = self.lifecycle.modifiers
    # flatten [stage -> [modifier, ...]] into one ordered list
    return [
        modifier
        for stage_modifier in stage_modifiers
        for modifier in stage_modifier.modifiers
    ]

def _log_model_info(self):
# Log model level logs if cadence reached
current_index = self._lifecycle.global_step
Expand Down
35 changes: 33 additions & 2 deletions src/llmcompressor/core/session_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import threading
from contextlib import contextmanager
from typing import Any, Optional
from typing import Any, Generator, Optional

from llmcompressor.core.events import EventType
from llmcompressor.core.session import CompressionSession
Expand All @@ -21,7 +21,7 @@


@contextmanager
def create_session() -> CompressionSession:
def create_session() -> Generator[CompressionSession, None, None]:
"""
Context manager to create and yield a new session for sparsification.
This will set the active session to the new session for the duration
Expand Down Expand Up @@ -136,5 +136,36 @@ def batch_end(cls, **kwargs) -> ModifiedState:
active_session()._log_model_info()
return cls.event(EventType.BATCH_END, **kwargs)

@classmethod
def calibration_epoch_start(cls, **kwargs) -> ModifiedState:
    """
    Invoke an epoch-start event for the active session during calibration.

    This event should be fired before calibration begins for an epoch.
    See `src/llmcompressor/pipelines/basic/pipeline.py` for a usage example.

    :param kwargs: additional arguments forwarded to the event handlers
    :return: the modified state resulting from invoking the event
    """
    return cls.event(EventType.CALIBRATION_EPOCH_START, **kwargs)

@classmethod
def sequential_epoch_end(cls, **kwargs) -> ModifiedState:
    """
    Invoke a sequential epoch-end event for the active session.

    This event should be fired after one sequential layer has been
    calibrated/trained for one epoch.
    See `src/llmcompressor/pipelines/sequential/pipeline.py` for a usage
    example.

    :param kwargs: additional arguments forwarded to the event handlers
    :return: the modified state resulting from invoking the event
    """
    return cls.event(EventType.SEQUENTIAL_EPOCH_END, **kwargs)

@classmethod
def calibration_epoch_end(cls, **kwargs) -> ModifiedState:
    """
    Invoke an epoch-end event for the active session during calibration.

    This event should be fired after the model has been calibrated for one
    epoch. See `src/llmcompressor/pipelines/basic/pipeline.py` for a usage
    example.

    :param kwargs: additional arguments forwarded to the event handlers
    :return: the modified state resulting from invoking the event
    """
    return cls.event(EventType.CALIBRATION_EPOCH_END, **kwargs)


callbacks = LifecycleCallbacks
19 changes: 10 additions & 9 deletions src/llmcompressor/entrypoints/oneshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from llmcompressor.core.session_functions import active_session
from llmcompressor.datasets import get_calibration_dataloader
from llmcompressor.entrypoints.utils import post_process, pre_process
from llmcompressor.pipelines.registry import get_pipeline_fn

__all__ = ["Oneshot", "oneshot"]

Expand Down Expand Up @@ -157,21 +158,21 @@ def apply_recipe_modifiers(
"""

session = active_session()
session.reset()

session_kwargs = dict(
session.initialize(
model=self.model,
start=-1,
recipe=self.recipe,
recipe_args=self.recipe_args.recipe_args,
calib_data=calibration_dataloader,
start=-1, # oneshot-specific argument
copy_data=False,
min_tokens_per_module=getattr(self, "min_tokens_per_module", None),
recipe_stage=recipe_stage,
recipe_args=self.recipe_args.recipe_args,
)

session.reset()
session.initialize(**session_kwargs)
session.finalize(**session_kwargs)
modifiers = session.get_modifiers()
_, pipeline_fn = get_pipeline_fn(self.dataset_args.pipeline, modifiers)
pipeline_fn(self.model, calibration_dataloader, self.dataset_args)

session.finalize()


def oneshot(**kwargs) -> PreTrainedModel:
Expand Down
6 changes: 3 additions & 3 deletions src/llmcompressor/modifiers/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def initialize(self, state: State, **kwargs):

self.initialized_ = self.on_initialize(state=state, **kwargs)

# trigger start
# trigger starts
fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
if self.should_start(fake_start_event):
self.on_start(state, fake_start_event, **kwargs)
Expand All @@ -103,8 +103,8 @@ def finalize(self, state: State, **kwargs):
:param state: The current state of the model
:param kwargs: Additional arguments for finalizing the modifier
"""
if self.finalized_ or not self.initialized_:
return
if self.finalized_:
raise RuntimeError("cannot finalize a modifier twice")

if not self.initialized_:
raise RuntimeError("cannot finalize an uninitialized modifier")
Expand Down
24 changes: 17 additions & 7 deletions src/llmcompressor/modifiers/obcq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ class SparseGPTModifier(SparsityModifierMixin, Modifier):
Lifecycle:
- on_initialize
- register_hook(module, calibrate_module, "forward")
- run_sequential / run_layer_sequential / run_basic
- make_empty_hessian
- accumulate_hessian
- on_sequential_batch_end
- sparsify_weight
- on_finalize
Expand Down Expand Up @@ -90,6 +87,14 @@ def calibrate_module(
args: Tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
"""
Calibration hook used to accumulate the hessian of the input to the module
:param module: module being calibrated
:param args: inputs to the module, the first element of which is the
cannonical input
:param _output: uncompressed module output, unused
"""
# Assume that the first argument is the input
inp = args[0]

Expand All @@ -108,10 +113,9 @@ def calibrate_module(
self._num_samples[module],
)

def on_sequential_batch_end(self):
def compress_modules(self):
"""
Sparsify modules
TODO: implement with event callback
Sparsify modules which have been calibrated
"""
for module in list(self._num_samples.keys()):
name = self._module_names[module]
Expand Down Expand Up @@ -152,7 +156,13 @@ def _maybe_onload_hessian(self, module: torch.nn.Module):
self._hessians[module] = self._hessians[module].to(device="cpu")

def on_finalize(self, state: State, **kwargs) -> bool:
self.remove_hooks()
# TODO: modify lifecycle to end on finalize
if not self.ended_:
self.on_end(state, None) # remove hooks

if len(self._num_samples) > 0:
raise ValueError(f"Failed to compress {len(self._num_samples)} modules")

self._hessians = dict()
self._num_samples = dict()
self._module_names = dict()
Expand Down
Loading
Loading