-
Notifications
You must be signed in to change notification settings - Fork 453
Expand file tree
/
Copy pathlifecycle.py
More file actions
245 lines (200 loc) · 8.41 KB
/
lifecycle.py
File metadata and controls
245 lines (200 loc) · 8.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
"""
Module for managing the compression lifecycle in the LLM Compressor.
This module provides a class for defining and managing the lifecycle of compression
events, including initialization, finalization, and event handling.
"""
from dataclasses import dataclass, field
from typing import Any, List, Optional
from loguru import logger
from llmcompressor.core.events import Event, EventType
from llmcompressor.core.state import State
from llmcompressor.modifiers import StageModifiers
from llmcompressor.recipe import (
RecipeArgsInput,
RecipeContainer,
RecipeInput,
RecipeStageInput,
)
__all__ = ["CompressionLifecycle"]
@dataclass
class CompressionLifecycle:
    """
    A class for managing the lifecycle of compression events in the LLM Compressor.

    The lifecycle has three phases:

    * ``initialize`` builds the modifiers from the recipe container and
      initializes each of them.
    * ``event`` dispatches training-loop events to every modifier, enforcing
      the batch event order.
    * ``finalize`` finalizes every modifier and records the applied stages.

    ``reset`` finalizes any still-active modifiers and restores every field to
    its default.

    :param state: The current state of the compression process
    :type state: Optional[State]
    :param recipe_container: The container for the compression recipe
    :type recipe_container: RecipeContainer
    :param modifiers: The list of stage modifiers
    :type modifiers: List[StageModifiers]
    """

    state: State = field(default_factory=State)
    recipe_container: RecipeContainer = field(default_factory=RecipeContainer)
    modifiers: List[StageModifiers] = field(default_factory=list)
    initialized_: bool = False
    finalized: bool = False

    # event order validation
    # Seeded with BATCH_END so the lifecycle starts as if a batch just finished:
    # the next tracked event must be BATCH_START (or another BATCH_END).
    _last_event_type: Optional[EventType] = EventType.BATCH_END
    # Expected order of the tracked per-batch events; events not listed here
    # are not order-checked.
    _event_order: List[EventType] = field(
        default_factory=lambda: [
            EventType.BATCH_START,
            EventType.LOSS_CALCULATED,
            EventType.OPTIM_PRE_STEP,
            EventType.OPTIM_POST_STEP,
            EventType.BATCH_END,
        ]
    )

    # track global step in training (could be epoch/batch)
    global_step: int = 0

    def reset(self):
        """
        Reset the compression lifecycle, finalizing any active modifiers
        and resetting all attributes.

        Modifiers that are initialized but not yet finalized are finalized on a
        best-effort basis. An exception raised by one modifier is logged as a
        warning and does not stop the reset. Afterwards every field is restored
        to its dataclass default by re-running ``__init__``.
        """
        logger.debug("Resetting compression lifecycle")
        for mod in self.modifiers:
            if not mod.initialized or mod.finalized:
                continue
            try:
                mod.finalize(self.state)
                logger.debug("Finalized modifier: {}", mod)
            except Exception as e:
                # deliberate best-effort cleanup: one failing modifier must not
                # prevent the lifecycle from being reset
                logger.warning("Exception during finalizing modifier: {}", e)

        # re-run the dataclass-generated __init__ to restore every default
        self.__init__()
        logger.info("Compression lifecycle reset")

    def initialize(
        self,
        recipe: Optional[RecipeInput] = None,
        recipe_stage: Optional[RecipeStageInput] = None,
        recipe_args: Optional[RecipeArgsInput] = None,
        **kwargs,
    ) -> List[Any]:
        """
        Initialize the compression lifecycle.

        The state is always updated with ``kwargs``, even when the lifecycle is
        already initialized. In that case nothing else happens and an empty
        list is returned.

        :param recipe: Recipe to append to the recipe container, if any
        :param recipe_stage: Recipe stage input forwarded to the container
            together with ``recipe``
        :param recipe_args: Recipe arguments forwarded to the container
            together with ``recipe``
        :param kwargs: Additional arguments to update the state with; they are
            also forwarded to each modifier's ``initialize``
        :return: List of data returned from initialization of modifiers
            (``None`` results are omitted)
        :rtype: List[Any]
        """
        self.state.update(**kwargs)
        if self.initialized_:  # TODO: do not initialize twice
            return []

        logger.debug("Initializing compression lifecycle")
        # only touch the recipe container when the caller supplied recipe input
        if any(arg is not None for arg in (recipe, recipe_stage, recipe_args)):
            self.recipe_container.append(recipe, recipe_stage, recipe_args)
        self.modifiers = self.recipe_container.get_modifiers()
        self._set_model_layer_prefix()

        mod_data = []
        for mod in self.modifiers:
            data = mod.initialize(state=self.state, **kwargs)
            logger.debug("Initialized modifier: {}", mod)
            if data is not None:
                mod_data.append(data)

        self.initialized_ = True
        logger.info(
            "Compression lifecycle initialized for {} modifiers", len(self.modifiers)
        )
        return mod_data

    def finalize(self, **kwargs) -> List[Any]:
        """
        Finalize the compression lifecycle.

        Every modifier is finalized, then the unique ids of the modifiers that
        report ``applied`` are passed to the recipe container as applied stages.

        :param kwargs: Additional arguments forwarded to each modifier's
            ``finalize``
        :return: List of data returned from finalizing modifiers
            (``None`` results are omitted)
        :rtype: List[Any]
        :raises ValueError: If called before initialization or more than once
        """
        if not self.initialized_:
            logger.error("Cannot finalize before initializing")
            raise ValueError("Cannot finalize before initializing")

        if self.finalized:
            logger.error("Cannot finalize more than once")
            raise ValueError("Cannot finalize more than once")

        logger.debug("Finalizing compression lifecycle")
        mod_data = []
        for mod in self.modifiers:
            data = mod.finalize(state=self.state, **kwargs)
            logger.debug("Finalized modifier: {}", mod)
            if data is not None:
                mod_data.append(data)

        self.finalized = True
        applied_stage_names = [mod.unique_id for mod in self.modifiers if mod.applied]
        self.recipe_container.update_applied_stages(applied_stage_names)

        logger.info(
            "Compression lifecycle finalized for {} modifiers", len(self.modifiers)
        )

        return mod_data

    def event(
        self, event_type: EventType, global_step: Optional[int] = 0, **kwargs
    ) -> List[Any]:
        """
        Handle a compression event.

        :param event_type: The type of event to handle
        :type event_type: EventType
        :param global_step: Value to store as the lifecycle's ``global_step``.
            It defaults to 0, so a call that omits it resets the tracked step to
            0. Passing ``None`` explicitly leaves the stored step unchanged.
        :param kwargs: Additional arguments to pass to the event handlers; a
            non-None ``loss`` is required for ``LOSS_CALCULATED`` events
        :return: List of data returned from handling the event by modifiers
            (``None`` results are omitted)
        :rtype: List[Any]
        :raises ValueError: If called before initialization, after finalization,
            for an invalid event type, without a loss for a loss event, or out
            of the expected batch event order
        """
        if not self.initialized_:
            logger.error("Cannot invoke event before initializing")
            raise ValueError("Cannot invoke event before initializing")

        if self.finalized:
            logger.error("Cannot invoke event after finalizing")
            raise ValueError("Cannot invoke event after finalizing")

        if event_type in [EventType.INITIALIZE, EventType.FINALIZE]:
            logger.error(
                "Cannot invoke {} event. Use the corresponding method instead.",
                event_type,
            )
            raise ValueError(
                f"Cannot invoke {event_type} event. "
                f"Use the corresponding method instead."
            )

        # The loss check must run before order validation:
        # _validate_event_order records the event as the last one seen, so an
        # event rejected afterwards would still corrupt the order tracking.
        if event_type == EventType.LOSS_CALCULATED and kwargs.get("loss") is None:
            logger.error("Loss must be provided for loss calculated event")
            raise ValueError("Loss must be provided for loss calculated event")

        if not self._validate_event_order(event_type):
            raise ValueError(
                f"Lifecycle events must appear following order: {self._event_order}. "
                f"Instead, {self._last_event_type} was called before {event_type}"
            )

        logger.debug("Handling event: {}", event_type)

        # update global step
        if global_step is not None:
            self.global_step = global_step

        event = Event(type_=event_type)
        mod_data = []
        for mod in self.modifiers:
            data = mod.update_event(state=self.state, event=event, **kwargs)
            logger.debug("Updated event with modifier: {}", mod)
            if data is not None:
                mod_data.append(data)

        return mod_data

    def _validate_event_order(self, event_type: EventType) -> bool:
        """
        Check ``event_type`` against the expected batch event order.

        On success, records the event as the last one seen. Events not listed
        in ``_event_order`` are always accepted and never recorded.
        BATCH_START is accepted unless the previous tracked event was also
        BATCH_START. Any other tracked event is accepted only if it does not
        come earlier in ``_event_order`` than the previous tracked event, so
        repeating an event is allowed.

        :param event_type: The event about to be dispatched
        :return: True if the event is in a valid position, False otherwise
        """
        if event_type not in self._event_order:
            # for unhandled events, do not save last event
            return True

        if event_type == EventType.BATCH_START:
            valid = self._last_event_type != EventType.BATCH_START
        else:
            last_event_index = self._event_order.index(self._last_event_type)
            curr_event_index = self._event_order.index(event_type)
            valid = last_event_index <= curr_event_index

        if valid:
            self._last_event_type = event_type
        return valid

    def _set_model_layer_prefix(self) -> bool:
        """
        Copy the target model's ``layer_prefix`` from the compiled recipe's
        metadata onto ``state.model``.

        NOTE(review): assumes ``self.state.model`` is set whenever the recipe
        carries target-model metadata; an unset model would raise here.

        :return: True if the prefix was set, False if the compiled recipe, its
            metadata, or its target-model metadata is missing
        """
        compiled_recipe = self.recipe_container.compiled_recipe
        if (
            compiled_recipe is None
            or (metadata := compiled_recipe.metadata) is None
            or (model_metadata := metadata.target_model) is None
        ):
            return False

        self.state.model.layer_prefix = model_metadata.layer_prefix
        logger.debug("Model layer prefix set to {}", self.state.model.layer_prefix)
        return True