Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-check-transformers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,4 @@ jobs:
- name: Running KV Cache Tests
if: (success() || failure()) && steps.install.outcome == 'success'
run: |
pytest -v tests/llmcompressor/transformers/kv_cache
pytest -v tests/llmcompressor/transformers/kv_cache -k "not test_kv_cache_gptq_model_state_dict_attr"
16 changes: 16 additions & 0 deletions src/llmcompressor/args/dataset_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from transformers import DefaultDataCollator

from llmcompressor.pipelines.registry import PIPELINES


@dataclass
class DVCDatasetArguments:
Expand Down Expand Up @@ -171,3 +173,17 @@ class DatasetArguments(CustomDatasetArguments):
"will execute code present on the Hub on your local machine."
},
)
pipeline: Optional[str] = field(
default="independent",
metadata={
"help": "Calibration pipeline used to calibrate model. "
f"Options: {PIPELINES.keys()}"
},
)
tracing_ignore: List[str] = field(
default_factory=lambda: ["_update_causal_mask"],
metadata={
"help": "List of functions to ignore during tracing, either "
"{module}.{method_name} or {function_name}"
},
)
2 changes: 2 additions & 0 deletions src/llmcompressor/core/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class EventType(Enum):
BATCH_START = "batch_start"
LOSS_CALCULATED = "loss_calculated"
BATCH_END = "batch_end"
SEQUENTIAL_EPOCH_END = "sequential_epoch_end"
CALIBRATION_EPOCH_END = "calibration_epoch_end"

# step lifecycle
OPTIM_PRE_STEP = "optim_pre_step"
Expand Down
15 changes: 11 additions & 4 deletions src/llmcompressor/core/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ def reset(self):
self.__init__()
logger.info("Compression lifecycle reset")

def initialize_recipe(
self,
recipe: Optional[RecipeInput] = None,
recipe_stage: Optional[RecipeStageInput] = None,
recipe_args: Optional[RecipeArgsInput] = None,
):
self.recipe_container.append(recipe, recipe_stage, recipe_args)
self.modifiers = self.recipe_container.get_modifiers()

def initialize(
self,
recipe: Optional[RecipeInput] = None,
Expand All @@ -92,12 +101,10 @@ def initialize(
:rtype: List[Any]
"""
self.state.update(**kwargs)
if self.initialized_: # TODO: do not initialize twice
return

logger.debug("Initializing compression lifecycle")
self.recipe_container.append(recipe, recipe_stage, recipe_args)
self.modifiers = self.recipe_container.get_modifiers()
if not (recipe is recipe_stage is recipe_args is None):
self.initialize_recipe(recipe, recipe_stage, recipe_args)
self._set_model_layer_prefix()

mod_data = []
Expand Down
8 changes: 8 additions & 0 deletions src/llmcompressor/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,14 @@ def get_serialized_recipe(self) -> Optional[str]:

logger.warning("Recipe not found in session - it may have been reset")

def get_modifiers(self):
stage_modifiers = self.lifecycle.modifiers
return [
modifier
for stage_modifier in stage_modifiers
for modifier in stage_modifier.modifiers
] # noqa: E127

def _log_model_info(self):
# Log model level logs if cadence reached
current_index = self._lifecycle.global_step
Expand Down
25 changes: 23 additions & 2 deletions src/llmcompressor/core/session_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import threading
from contextlib import contextmanager
from typing import Any, Optional
from typing import Any, Generator, Optional

from llmcompressor.core.events import EventType
from llmcompressor.core.session import CompressionSession
Expand All @@ -21,7 +21,7 @@


@contextmanager
def create_session() -> CompressionSession:
def create_session() -> Generator[CompressionSession, None, None]:
"""
Context manager to create and yield a new session for sparsification.
This will set the active session to the new session for the duration
Expand Down Expand Up @@ -136,5 +136,26 @@ def batch_end(cls, **kwargs) -> ModifiedState:
active_session()._log_model_info()
return cls.event(EventType.BATCH_END, **kwargs)

@classmethod
def sequential_epoch_end(cls, **kwargs) -> ModifiedState:
"""
Invoke a sequential epoch end event for the active session. This event should be
called after one sequential layer has been calibrated/trained for one epoch

This is called after a sequential layer has been calibrated with one batch, see
`src/llmcompressor/pipelines/sequential/pipeline.py` for usage example
"""
return cls.event(EventType.SEQUENTIAL_EPOCH_END, **kwargs)

@classmethod
def calibration_epoch_end(cls, **kwargs) -> ModifiedState:
"""
Invoke a epoch end event for the active session during calibration. This event
should be called after the model has been calibrated for one epoch

see `src/llmcompressor/pipelines/basic/pipeline.py` for usage example
"""
return cls.event(EventType.CALIBRATION_EPOCH_END, **kwargs)


callbacks = LifecycleCallbacks
20 changes: 10 additions & 10 deletions src/llmcompressor/entrypoints/oneshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from llmcompressor.core.session_functions import active_session
from llmcompressor.datasets import get_calibration_dataloader
from llmcompressor.entrypoints.utils import post_process, pre_process
from llmcompressor.pipelines.registry import get_pipeline_fn

__all__ = ["Oneshot", "oneshot"]

Expand Down Expand Up @@ -157,21 +158,20 @@ def apply_recipe_modifiers(
"""

session = active_session()
session.reset()

session_kwargs = dict(
model=self.model,
session.lifecycle.state.update(model=self.model, start=-1)
session.lifecycle.initialize_recipe(
recipe=self.recipe,
recipe_args=self.recipe_args.recipe_args,
calib_data=calibration_dataloader,
start=-1, # oneshot-specific argument
copy_data=False,
min_tokens_per_module=getattr(self, "min_tokens_per_module", None),
recipe_stage=recipe_stage,
recipe_args=self.recipe_args.recipe_args,
)

session.reset()
session.initialize(**session_kwargs)
session.finalize(**session_kwargs)
modifiers = session.get_modifiers()
_, pipeline_fn = get_pipeline_fn(self.dataset_args.pipeline, modifiers)
pipeline_fn(self.model, calibration_dataloader, self.dataset_args)

session.finalize()


def oneshot(**kwargs) -> PreTrainedModel:
Expand Down
6 changes: 3 additions & 3 deletions src/llmcompressor/modifiers/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def initialize(self, state: State, **kwargs):

self.initialized_ = self.on_initialize(state=state, **kwargs)

# trigger start
# trigger starts
fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
if self.should_start(fake_start_event):
self.on_start(state, fake_start_event, **kwargs)
Expand All @@ -103,8 +103,8 @@ def finalize(self, state: State, **kwargs):
:param state: The current state of the model
:param kwargs: Additional arguments for finalizing the modifier
"""
if self.finalized_ or not self.initialized_:
return
if self.finalized_:
raise RuntimeError("cannot finalize a modifier twice")

if not self.initialized_:
raise RuntimeError("cannot finalize an uninitialized modifier")
Expand Down
24 changes: 17 additions & 7 deletions src/llmcompressor/modifiers/obcq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ class SparseGPTModifier(SparsityModifierMixin, Modifier):
Lifecycle:
- on_initialize
- register_hook(module, calibrate_module, "forward")
- run_sequential / run_layer_sequential / run_basic
- make_empty_hessian
- accumulate_hessian
- on_sequential_batch_end
- sparsify_weight
- on_finalize
Expand Down Expand Up @@ -90,6 +87,14 @@ def calibrate_module(
args: Tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
"""
Calibration hook used to accumulate the hessian of the input to the module

:param module: module being calibrated
:param args: inputs to the module, the first element of which is the
cannonical input
:param _output: uncompressed module output, unused
"""
# Assume that the first argument is the input
inp = args[0]

Expand All @@ -108,10 +113,9 @@ def calibrate_module(
self._num_samples[module],
)

def on_sequential_batch_end(self):
def compress_modules(self):
"""
Sparsify modules
TODO: implement with event callback
Sparsify modules which have been calibrated
"""
for module in list(self._num_samples.keys()):
name = self._module_names[module]
Expand Down Expand Up @@ -152,7 +156,13 @@ def _maybe_onload_hessian(self, module: torch.nn.Module):
self._hessians[module] = self._hessians[module].to(device="cpu")

def on_finalize(self, state: State, **kwargs) -> bool:
self.remove_hooks()
# TODO: modify lifecycle to end on finalize
if not self.ended_:
self.on_end(state, None) # remove hooks

if len(self._num_samples) > 0:
raise ValueError(f"Failed to compress {len(self._num_samples)} modules")

self._hessians = dict()
self._num_samples = dict()
self._module_names = dict()
Expand Down
71 changes: 23 additions & 48 deletions src/llmcompressor/modifiers/obcq/sgpt_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,10 @@
from loguru import logger
from pydantic import Field, PrivateAttr, field_validator, model_validator

from llmcompressor.core import State
from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers.modifier import Modifier
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.basic import run_pipeline as run_basic
from llmcompressor.pipelines.layer_sequential import (
run_pipeline as run_layer_sequential,
)
from llmcompressor.pipelines.sequential import run_pipeline as run_sequential
from llmcompressor.utils.pytorch.module import (
get_layers,
get_no_split_params,
Expand All @@ -24,7 +21,7 @@
)


class SparsityModifierMixin(HooksMixin):
class SparsityModifierMixin(Modifier):
# modifier arguments
sparsity: Optional[Union[float, List[float]]]
sparsity_profile: Optional[str] = None
Expand Down Expand Up @@ -97,6 +94,10 @@ def calibrate_module(
):
raise NotImplementedError()

@abstractmethod
def compress_modules(self):
raise NotImplementedError()

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
Initialize and run the OBCQ algorithm on the current state
Expand Down Expand Up @@ -160,48 +161,22 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
self._module_sparsities[module] = layer_sparsity
self.register_hook(module, self.calibrate_module, "forward")

# infer and run pipeline
model_name = state.model.__class__.__name__
input_names = dataloader.dataset.column_names
unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
try:
run_sequential(
state.model,
state.data.calib,
self.sequential_targets,
self.ignore,
self,
)
return True

except Exception as exception:
if isinstance(exception, torch.fx.proxy.TraceError):
warnings.warn(f"Failed to trace {model_name} with inputs {input_names}")
if isinstance(exception, unfixable_errors):
raise exception

warnings.warn("Falling back to layer_sequential pipeline")
try:
run_layer_sequential(
state.model,
state.data.calib,
self.sequential_targets,
self,
)
return True

except Exception as exception:
if isinstance(exception, TypeError):
warnings.warn(f"{model_name} fails layer-wise assumptions")
if isinstance(exception, unfixable_errors):
raise exception

warnings.warn(
"Falling back to basic pipeline, which requires extra memory and "
"may result in decreased accuracy"
)
run_basic(state.model, state.data.calib, self)
return True
return True

def on_event(self, state: State, event: Event, **kwargs):
if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
self.compress_modules()

if event.type_ == EventType.CALIBRATION_EPOCH_END:
self.compress_modules()

# TODO: modify lifecycle to end on calibration epoch end
if not self.ended_:
self.on_end(state, None)

def on_end(self, state: State, event: Event, **kwargs):
self.ended_ = True # TODO: move to super call
self.remove_hooks()

def _infer_sequential_targets(
self, model: torch.nn.Module
Expand Down
22 changes: 17 additions & 5 deletions src/llmcompressor/modifiers/pruning/wanda/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ def calibrate_module(
args: Tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
"""
Calibration hook used to accumulate the row scalars of the input to the module

:param module: module being calibrated
:param args: inputs to the module, the first element of which is the
cannonical input
:param _output: uncompressed module output, unused
"""
# Assume that the first argument is the input
inp = args[0]

Expand All @@ -91,12 +99,10 @@ def calibrate_module(
self._num_samples[module],
)

def on_sequential_batch_end(self):
def compress_modules(self):
"""
Sparsify modules
TODO: implement with event callback
Sparsify modules which have been calibrated
"""

for module in list(self._num_samples.keys()):
name = self._module_names[module]
sparsity = self._module_sparsities[module]
Expand All @@ -120,7 +126,13 @@ def on_sequential_batch_end(self):
del self._num_samples[module]

def on_finalize(self, state: State, **kwargs) -> bool:
self.remove_hooks()
# TODO: modify lifecycle to end on finalize
if not self.ended_:
self.on_end(state, None) # remove hooks

if len(self._num_samples) > 0:
raise ValueError(f"Failed to compress {len(self._num_samples)} modules")

self._row_scalars = dict()
self._num_samples = dict()
self._module_names = dict()
Expand Down
Loading