-
Notifications
You must be signed in to change notification settings - Fork 453
Expand file tree
/
Copy pathlifecycle.py
More file actions
269 lines (225 loc) · 9.46 KB
/
lifecycle.py
File metadata and controls
269 lines (225 loc) · 9.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
"""
Module for managing the compression lifecycle in the LLM Compressor.
This module provides a class for defining and managing the lifecycle of compression
events, including initialization, finalization, and event handling.
"""
from dataclasses import dataclass, field
from typing import Any, List, Optional
from loguru import logger
from llmcompressor.core.events import (
CallbacksEventLifecycle,
EventLifecycle,
EventType,
OptimizerEventLifecycle,
)
from llmcompressor.core.state import State
from llmcompressor.modifiers import StageModifiers
from llmcompressor.recipe import (
RecipeArgsInput,
RecipeContainer,
RecipeInput,
RecipeStageInput,
)
# Public API of this module: only the lifecycle manager class is exported.
__all__ = ["CompressionLifecycle"]
@dataclass
class CompressionLifecycle:
    """
    A class for managing the lifecycle of compression events in the LLM Compressor.

    The expected call order is ``initialize`` -> any number of ``event`` calls
    -> ``finalize``; ``reset`` returns the object to its default state.

    :param state: The current state of the compression process
    :type state: State
    :param recipe_container: The container for the compression recipe
    :type recipe_container: RecipeContainer
    :param modifiers: The list of stage modifiers
    :type modifiers: List[StageModifiers]
    :param event_lifecycle: The event lifecycle manager; created lazily by the
        first call to ``event``
    :type event_lifecycle: Optional[EventLifecycle]
    :param initialized_: Set to True once ``initialize`` has run all modifiers
    :type initialized_: bool
    :param finalized: Set to True once ``finalize`` has run all modifiers
    :type finalized: bool
    """

    state: State = field(default_factory=State)
    recipe_container: RecipeContainer = field(default_factory=RecipeContainer)
    modifiers: List[StageModifiers] = field(default_factory=list)
    event_lifecycle: Optional[EventLifecycle] = None

    initialized_: bool = False
    finalized: bool = False

    def reset(self):
        """
        Reset the compression lifecycle, finalizing any active modifiers
        and resetting all attributes to their defaults.

        Finalization during reset is best-effort: an exception raised by a
        modifier's ``finalize`` is logged as a warning and the reset continues.
        """
        logger.debug("Resetting compression lifecycle")

        for mod in self.modifiers:
            # skip modifiers that were never initialized or are already finalized
            if not mod.initialized or mod.finalized:
                continue
            try:
                mod.finalize(self.state)
                logger.debug("Finalized modifier: {}", mod)
            except Exception as e:
                logger.warning("Exception during finalizing modifier: {}", e)

        # re-run the dataclass-generated __init__ so every field is restored
        # to its declared default (fresh State / RecipeContainer / list)
        self.__init__()
        logger.info("Compression lifecycle reset")

    def initialize(
        self,
        recipe: Optional[RecipeInput] = None,
        recipe_stage: Optional[RecipeStageInput] = None,
        recipe_args: Optional[RecipeArgsInput] = None,
        **kwargs,
    ) -> List[Any]:
        """
        Initialize the compression lifecycle.

        Appends the given recipe to the recipe container, builds the modifiers
        from it and initializes each modifier against the current state.

        :param recipe: The recipe to append to the recipe container
        :param recipe_stage: The recipe stage(s) to apply, passed through to
            ``RecipeContainer.append``
        :param recipe_args: Recipe arguments, passed through to
            ``RecipeContainer.append``
        :param kwargs: Additional arguments to update the state with; they are
            also forwarded to each modifier's ``initialize``
        :return: List of data returned from initialization of modifiers
            (``None`` results are omitted)
        :rtype: List[Any]
        :raises ValueError: If called more than once
        """
        if self.initialized_:
            logger.error("Initialize was called twice")
            raise ValueError(
                "Initialize was called twice. To update state values prior to "
                "initialization, please use `active_session().state.update()`"
            )

        self.state.update(**kwargs)
        logger.debug("Initializing compression lifecycle")
        self.recipe_container.append(recipe, recipe_stage, recipe_args)
        self.modifiers = self.recipe_container.get_modifiers()
        self._set_model_layer_prefix()

        mod_data = []
        for mod in self.modifiers:
            data = mod.initialize(state=self.state, **kwargs)
            logger.debug("Initialized modifier: {}", mod)
            if data is not None:
                mod_data.append(data)

        self.initialized_ = True
        logger.info(
            "Compression lifecycle initialized for {} modifiers", len(self.modifiers)
        )

        return mod_data

    def finalize(self, **kwargs) -> List[Any]:
        """
        Finalize the compression lifecycle.

        Finalizes every modifier, then records the ``unique_id`` of each
        modifier reporting ``applied`` in the recipe container.

        :param kwargs: Additional arguments forwarded to each modifier's
            ``finalize``
        :return: List of data returned from finalizing modifiers
            (``None`` results are omitted)
        :rtype: List[Any]
        :raises ValueError: If called before initialization or more than once
        """
        if not self.initialized_:
            logger.error("Cannot finalize before initializing")
            raise ValueError("Cannot finalize before initializing")

        if self.finalized:
            logger.error("Cannot finalize more than once")
            raise ValueError("Cannot finalize more than once")

        logger.debug("Finalizing compression lifecycle")
        mod_data = []
        for mod in self.modifiers:
            data = mod.finalize(state=self.state, **kwargs)
            logger.debug("Finalized modifier: {}", mod)
            if data is not None:
                mod_data.append(data)

        self.finalized = True
        applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied]
        self.recipe_container.update_applied_stages(applied_stage_names)

        logger.info(
            "Compression lifecycle finalized for {} modifiers", len(self.modifiers)
        )

        return mod_data

    def event(self, event_type: EventType, **kwargs) -> List[Any]:
        """
        Handle a compression event.

        The event lifecycle is created on the first call (see
        ``_check_setup_event_lifecycle``); it may expand ``event_type`` into
        one or more concrete events, each of which is passed to every modifier.

        :param event_type: The type of event to handle
        :type event_type: EventType
        :param kwargs: Additional arguments to pass to the event handlers
        :return: List of data returned from handling the event by modifiers
            (``None`` results are omitted)
        :rtype: List[Any]
        :raises ValueError: If called before initialization, after finalization,
            or for an invalid event type
        :raises AssertionError: If the event lifecycle yields no event for
            ``event_type``
        """
        if not self.initialized_:
            logger.error("Cannot invoke event before initializing")
            raise ValueError("Cannot invoke event before initializing")

        if self.finalized:
            logger.error("Cannot invoke event after finalizing")
            raise ValueError("Cannot invoke event after finalizing")

        if event_type in [EventType.INITIALIZE, EventType.FINALIZE]:
            logger.error(
                "Cannot invoke {} event. Use the corresponding method instead.",
                event_type,
            )
            raise ValueError(
                f"Cannot invoke {event_type} event. "
                f"Use the corresponding method instead."
            )

        if event_type == EventType.LOSS_CALCULATED and kwargs.get("loss") is None:
            logger.error("Loss must be provided for loss calculated event")
            raise ValueError("Loss must be provided for loss calculated event")

        logger.debug("Handling event: {}", event_type)
        self._check_setup_event_lifecycle(event_type)

        event = None
        mod_data = []
        for event in self.event_lifecycle.events_from_type(event_type):
            if self.state.start_event is None:
                self.state.start_event = event

            for mod in self.modifiers:
                data = mod.update_event(state=self.state, event=event, **kwargs)
                logger.debug("Updated event with modifier: {}", mod)
                if data is not None:
                    mod_data.append(data)

        # Explicit check instead of ``assert`` so the invariant still holds when
        # Python runs with -O (asserts stripped); otherwise last_event would be
        # silently overwritten with None.
        if event is None:
            raise AssertionError(
                f"Event lifecycle did not return an event for {event_type}"
            )
        self.state.last_event = event

        return mod_data

    def _check_setup_event_lifecycle(self, event_type: EventType):
        """
        Lazily create ``self.event_lifecycle`` on the first event.

        No-op if an event lifecycle already exists. Otherwise requires a model,
        a start event and a compiled recipe on the state/container, checks that
        every modifier is initialized, and picks the lifecycle implementation
        from the first event type seen.

        :param event_type: The type of the first event being handled
        :raises ValueError: If the model, start event or compiled recipe is
            missing, or if ``event_type`` cannot start an event lifecycle
        """
        if self.event_lifecycle is not None:
            return

        if (
            self.state is None
            or self.state.model is None
            or self.state.start_event is None
            or self.recipe_container.compiled_recipe is None
        ):
            logger.error("Cannot invoke event before recipe, model, and start are set")
            raise ValueError(
                "Cannot invoke event before recipe, model, and start are set"
            )

        logger.debug("Setting up event lifecycle for event type: {}", event_type)

        for mod in self.modifiers:
            logger.debug("Checking if modifier is initialized: {}", mod)
            mod.check_initialized()

        # a callbacks event lifecycle must start with a BATCH_START event;
        # an optimizer event lifecycle starts with a loss or pre-step event
        if event_type == EventType.BATCH_START:
            self.event_lifecycle = CallbacksEventLifecycle(
                type_first=EventType.BATCH_START, start=self.state.start_event
            )
        elif event_type in (EventType.LOSS_CALCULATED, EventType.OPTIM_PRE_STEP):
            self.event_lifecycle = OptimizerEventLifecycle(
                type_first=event_type, start=self.state.start_event
            )
        else:
            logger.error(
                "Invalid event type for initializing event lifecycle: "
                "{}. Must be BATCH_START, LOSS_CALCULATED, or OPTIM_PRE_STEP",
                event_type,
            )
            raise ValueError(
                f"Invalid event type for initializing event lifecycle: "
                f"{event_type}. Must be BATCH_START, LOSS_CALCULATED, or OPTIM_PRE_STEP"
            )

        logger.info(
            "Event lifecycle for compression lifecycle created: "
            "{} with start event type: {}",
            self.event_lifecycle,
            event_type,
        )

    def _set_model_layer_prefix(self) -> bool:
        """
        Copy ``layer_prefix`` from the compiled recipe's target-model metadata
        onto the model held by ``self.state``.

        :return: True if the prefix was set; False if there is no compiled
            recipe, no metadata, no target-model metadata, or no model on the
            state to set it on
        """
        compiled_recipe = self.recipe_container.compiled_recipe
        if (
            compiled_recipe is None
            or (metadata := compiled_recipe.metadata) is None
            or (model_metadata := metadata.target_model) is None
        ):
            return False

        # Without this guard a recipe carrying target-model metadata combined
        # with no model on the state fails with an opaque AttributeError on None.
        if self.state.model is None:
            logger.warning(
                "Recipe specifies model layer prefix {} but no model is set "
                "on the state; skipping",
                model_metadata.layer_prefix,
            )
            return False

        self.state.model.layer_prefix = model_metadata.layer_prefix
        logger.debug("Model layer prefix set to {}", self.state.model.layer_prefix)
        return True