99from loguru import logger
1010from pydantic import PrivateAttr
1111
12- from llmcompressor .core import State
12+ from llmcompressor .core import Event , EventType , State
1313from llmcompressor .modifiers import Modifier
1414from llmcompressor .modifiers .obcq .sgpt_mixin import SparsityModifierMixin
1515from llmcompressor .modifiers .pruning .wanda .wanda_sparsify import (
@@ -74,6 +74,14 @@ def calibrate_module(
7474 args : Tuple [torch .Tensor , ...],
7575 _output : torch .Tensor ,
7676 ):
77+ """
78+ Calibration hook used to accumulate the row scalars of the input to the module
79+
80+ :param module: module being calibrated
81+ :param args: inputs to the module, the first element of which is the
82+ canonical input
83+ :param _output: uncompressed module output, unused
84+ """
7785 # Assume that the first argument is the input
7886 inp = args [0 ]
7987
@@ -91,12 +99,17 @@ def calibrate_module(
9199 self ._num_samples [module ],
92100 )
93101
94- def on_sequential_batch_end (self ):
102+ def on_event (self , state : State , event : Event , ** kwargs ):
103+ if event .type_ in (
104+ EventType .SEQUENTIAL_EPOCH_END ,
105+ EventType .CALIBRATION_EPOCH_END ,
106+ ):
107+ self .compress_modules ()
108+
109+ def compress_modules (self ):
95110 """
96- Sparsify modules
97- TODO: implement with event callback
111+ Sparsify modules which have been calibrated
98112 """
99-
100113 for module in list (self ._num_samples .keys ()):
101114 name = self ._module_names [module ]
102115 sparsity = self ._module_sparsities [module ]
@@ -120,6 +133,9 @@ def on_sequential_batch_end(self):
120133 del self ._num_samples [module ]
121134
122135 def on_finalize (self , state : State , ** kwargs ) -> bool :
136+ if len (self ._num_samples ) > 0 :
137+ raise ValueError (f"Failed to compress { len (self ._num_samples )} modules" )
138+
123139 self .remove_hooks ()
124140 self ._row_scalars = dict ()
125141 self ._num_samples = dict ()
0 commit comments