1818)
1919from llmcompressor .core .state import State
2020from llmcompressor .modifiers import StageModifiers
21- from llmcompressor .recipe import RecipeContainer
21+ from llmcompressor .recipe import (
22+ RecipeArgsInput ,
23+ RecipeContainer ,
24+ RecipeInput ,
25+ RecipeStageInput ,
26+ )
2227
2328__all__ = ["CompressionLifecycle" ]
2429
@@ -38,7 +43,7 @@ class CompressionLifecycle:
3843 :type event_lifecycle: Optional[EventLifecycle]
3944 """
4045
41- state : Optional [ State ] = None
46+ state : State = field ( default_factory = State )
4247 recipe_container : RecipeContainer = field (default_factory = RecipeContainer )
4348 modifiers : List [StageModifiers ] = field (default_factory = list )
4449 event_lifecycle : Optional [EventLifecycle ] = None
@@ -62,63 +67,35 @@ def reset(self):
6267 except Exception as e :
6368 logger .warning (f"Exception during finalizing modifier: { e } " )
6469
65- self .state = None
66- self .recipe_container = RecipeContainer ()
67- self .modifiers = []
68- self .event_lifecycle = None
69-
70- self .initialized_ = False
71- self .finalized = False
70+ self .__init__ ()
7271 logger .info ("Compression lifecycle reset" )
7372
74- def pre_initialize_structure (self , ** kwargs ) -> List [Any ]:
75- """
76- Pre-initialize the structure of the compression lifecycle.
77-
78- :param kwargs: Additional arguments to update the state with
79- :return: List of data returned from pre-initialization of modifiers
80- :rtype: List[Any]
81- """
82- logger .debug ("Pre-initializing structure" )
83- self ._check_create_state ()
84- extras = self .state .update (** kwargs )
85- extras = self .recipe_container .update (** extras )
86-
87- self ._check_compile_recipe ()
88- mod_data = []
89- for mod in self .modifiers :
90- data = mod .pre_initialize_structure (state = self .state , ** extras )
91- logger .debug ("Pre-initialized modifier: {}" , mod )
92- if data is not None :
93- mod_data .append (data )
94-
95- applied_stage_names = [mod .unique_id for mod in self .modifiers if mod .applied ]
96- self .recipe_container .update_applied_stages (applied_stage_names )
97- logger .info (
98- "Compression lifecycle structure pre-initialized for {} modifiers" ,
99- len (self .modifiers ),
100- )
101-
102- return mod_data
103-
104- def initialize (self , ** kwargs ) -> List [Any ]:
73+ def initialize (
74+ self ,
75+ recipe : Optional [RecipeInput ] = None ,
76+ recipe_stage : Optional [RecipeStageInput ] = None ,
77+ recipe_args : Optional [RecipeArgsInput ] = None ,
78+ ** kwargs ,
79+ ) -> List [Any ]:
10580 """
10681 Initialize the compression lifecycle.
10782
10883 :param kwargs: Additional arguments to update the state with
10984 :return: List of data returned from initialization of modifiers
11085 :rtype: List[Any]
11186 """
112- logger .debug ("Initializing compression lifecycle" )
113- self ._check_create_state ()
114- extras = self .state .update (** kwargs )
115- extras = self .recipe_container .update (** extras )
87+ self .state .update (** kwargs )
88+ if self .initialized_ : # TODO: do not initialize twice
89+ return
11690
117- self ._check_compile_recipe ()
91+ logger .debug ("Initializing compression lifecycle" )
92+ self .recipe_container .append (recipe , recipe_stage , recipe_args )
93+ self .modifiers = self .recipe_container .get_modifiers ()
11894 self ._set_model_layer_prefix ()
95+
11996 mod_data = []
12097 for mod in self .modifiers :
121- data = mod .initialize (state = self .state , ** extras )
98+ data = mod .initialize (state = self .state , ** kwargs )
12299 logger .debug ("Initialized modifier: {}" , mod )
123100 if data is not None :
124101 mod_data .append (data )
@@ -185,7 +162,7 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
185162 logger .error ("Cannot invoke event after finalizing" )
186163 raise ValueError ("Cannot invoke event after finalizing" )
187164
188- if event_type in [EventType .PRE_INIT , EventType . INITIALIZE , EventType .FINALIZE ]:
165+ if event_type in [EventType .INITIALIZE , EventType .FINALIZE ]:
189166 logger .error (
190167 "Cannot invoke {} event. Use the corresponding method instead." ,
191168 event_type ,
@@ -223,30 +200,6 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
223200
224201 return mod_data
225202
226- def _check_create_state (self ):
227- if self .state is not None :
228- return
229-
230- logger .debug ("Creating new State instance for compression lifecycle" )
231- self .state = State ()
232- logger .info ("State created for compression lifecycle" )
233-
234- def _check_compile_recipe (self ):
235- if not self .recipe_container .check_compile_recipe ():
236- return
237-
238- logger .debug (
239- "Compiling recipe and creating modifiers for compression lifecycle"
240- )
241- self .modifiers = self .recipe_container .compiled_recipe .create_modifier ()
242- for mod in self .modifiers :
243- if mod .unique_id in self .recipe_container .applied_stages :
244- mod .applied = True
245- logger .info (
246- "Recipe compiled and {} modifiers created" ,
247- len (self .modifiers ),
248- )
249-
250203 def _check_setup_event_lifecycle (self , event_type : EventType ):
251204 if self .event_lifecycle is not None :
252205 return
0 commit comments