-
Notifications
You must be signed in to change notification settings - Fork 453
Expand file tree
/
Copy pathsession_functions.py
More file actions
227 lines (190 loc) · 8.1 KB
/
session_functions.py
File metadata and controls
227 lines (190 loc) · 8.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import threading
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Optional, Union

from llmcompressor.core.events import EventType
from llmcompressor.core.session import CompressionSession
from llmcompressor.core.state import ModifiedState
from llmcompressor.recipe import Recipe
# Public API of this module.
__all__ = [
    "create_session",
    "active_session",
    "reset_session",
    "initialize",
    "finalize",
    "callbacks",
    "LifecycleCallbacks",
]

# Process-wide default session, constructed once at import time.
_global_session = CompressionSession()

# Per-thread holder for the currently active session. A threading.local
# attribute set here exists only in the thread that executes this import;
# other threads start without a ``session`` attribute, so code reading it
# must supply a fallback.
_local_storage = threading.local()
_local_storage.session = _global_session
@contextmanager
def create_session() -> Generator[CompressionSession, None, None]:
    """
    Context manager to create and yield a new session for sparsification.

    This sets the active session for the current thread to the new session
    for the duration of the context, then restores the thread's previous
    state on exit (including when the body raises).

    :return: a generator-based context manager that yields the new session
    """
    # The thread-local ``session`` attribute may be absent on threads other
    # than the one that imported this module. Remember whether it existed so
    # that on exit it is removed again instead of being restored as ``None``;
    # a stored ``None`` would defeat a ``getattr(..., default)`` fallback.
    had_session = hasattr(_local_storage, "session")
    orig_session = getattr(_local_storage, "session", None)

    new_session = CompressionSession()
    _local_storage.session = new_session
    try:
        yield new_session
    finally:
        if had_session:
            _local_storage.session = orig_session
        else:
            del _local_storage.session
def active_session() -> CompressionSession:
    """
    Look up the session that sparsification calls should operate on.

    :return: the session stored for the current thread, or the module-wide
        global session when the current thread has not stored one
    """
    try:
        return _local_storage.session
    except AttributeError:
        # this thread never had a session attribute set on it
        return _global_session
def reset_session():
    """
    Return the lifecycle of the currently active session to its initial state.
    """
    active_session()._lifecycle.reset()
def initialize(
    recipe: Union[str, List[str], "Recipe", List["Recipe"], None] = None,
    recipe_stage: Union[str, List[str], None] = None,
    recipe_args: Optional[Dict[str, Any]] = None,
    model: Optional[Any] = None,
    teacher_model: Optional[Any] = None,
    optimizer: Optional[Any] = None,
    attach_optim_callbacks: bool = True,
    train_data: Optional[Any] = None,
    val_data: Optional[Any] = None,
    test_data: Optional[Any] = None,
    calib_data: Optional[Any] = None,
    copy_data: bool = True,
    start: Optional[float] = None,
    steps_per_epoch: Optional[int] = None,
    batches_per_step: Optional[int] = None,
    **kwargs,
) -> ModifiedState:
    """
    Initialize the currently active session for sparsification.

    Every argument is forwarded unchanged, by keyword, to the active session's
    ``initialize`` method.

    :param recipe: the recipe to apply; a path to a recipe file, a raw recipe
        string, a recipe object, or a list of recipe objects
    :param recipe_stage: the recipe stage(s) to target
    :param recipe_args: values that override the recipe's defaults
    :param model: the model to sparsify
    :param teacher_model: the teacher model used for knowledge distillation
    :param optimizer: the optimizer used during sparsification
    :param attach_optim_callbacks: True to hook the optimizer callbacks into the
        sparsification lifecycle, False otherwise
    :param train_data: the training data for sparsification
    :param val_data: the validation data for sparsification
    :param test_data: the testing data for sparsification
    :param calib_data: the calibration data for sparsification
    :param copy_data: True to copy the data, False otherwise
    :param start: the epoch at which sparsification starts
    :param steps_per_epoch: how many steps make up one epoch
    :param batches_per_step: how many batches make up one step
    :param kwargs: extra keyword arguments forwarded to the lifecycle's
        initialize method
    :return: the active session's modified state once initialization is done
    """
    session = active_session()
    return session.initialize(
        recipe=recipe,
        recipe_stage=recipe_stage,
        recipe_args=recipe_args,
        model=model,
        teacher_model=teacher_model,
        optimizer=optimizer,
        attach_optim_callbacks=attach_optim_callbacks,
        train_data=train_data,
        val_data=val_data,
        test_data=test_data,
        calib_data=calib_data,
        copy_data=copy_data,
        start=start,
        steps_per_epoch=steps_per_epoch,
        batches_per_step=batches_per_step,
        **kwargs,
    )
def finalize(**kwargs) -> ModifiedState:
    """
    Finalize the currently active session for sparsification.

    :param kwargs: extra keyword arguments forwarded to the lifecycle's
        finalize method
    :return: the active session's modified state once finalization is done
    """
    session = active_session()
    return session.finalize(**kwargs)
class LifecycleCallbacks:
    """
    A class for invoking lifecycle events for the active session.

    All methods are classmethods, so the class is used directly (or through
    the ``callbacks`` alias defined below) without being instantiated.
    """

    @classmethod
    def event(cls, event_type: EventType, **kwargs) -> Optional[ModifiedState]:
        """
        Invoke an event for the active session

        :param event_type: the event type to invoke
        :param kwargs: additional kwargs to pass to the current session's event method
        :return: the modified state of the active session after invoking the event,
            or None when the active session has no recipe (the event is skipped)
        :raises ValueError: if event_type is INITIALIZE or FINALIZE; those events
            are invoked through their own methods instead
        """
        if event_type in (EventType.INITIALIZE, EventType.FINALIZE):
            raise ValueError(
                f"Cannot invoke {event_type} event. "
                "Use the corresponding method instead."
            )

        # resolve the session once so the recipe check and the dispatch below
        # are guaranteed to target the same session object
        session = active_session()

        # skip event callbacks if no recipe was provided
        if not session.lifecycle.recipe_container.check_any_recipe_exists():
            return None

        return session.event(event_type, **kwargs)

    @classmethod
    def batch_start(
        cls, batch_data: Optional[Any] = None, **kwargs
    ) -> Optional[ModifiedState]:
        """
        Invoke a batch start event for the active session

        :param batch_data: the batch data to use for the event
        :param kwargs: additional kwargs to pass to the current session's event method
        :return: the modified state of the active session after invoking the event,
            or None when the event was skipped
        """
        return cls.event(EventType.BATCH_START, batch_data=batch_data, **kwargs)

    @classmethod
    def loss_calculated(
        cls, loss: Optional[Any] = None, **kwargs
    ) -> Optional[ModifiedState]:
        """
        Invoke a loss calculated event for the active session

        The loss is logged first, even when the event itself is skipped because
        no recipe exists.

        :param loss: the loss to use for the event
        :param kwargs: additional kwargs to pass to the current session's event method
        :return: the modified state of the active session after invoking the event,
            or None when the event was skipped
        """
        # log loss if loss calculated
        active_session()._log_loss(event_type=EventType.LOSS_CALCULATED, loss=loss)
        return cls.event(EventType.LOSS_CALCULATED, loss=loss, **kwargs)

    @classmethod
    def optim_pre_step(cls, **kwargs) -> Optional[ModifiedState]:
        """
        Invoke an optimizer pre-step event for the active session

        :param kwargs: additional kwargs to pass to the current session's event method
        :return: the modified state of the active session after invoking the event,
            or None when the event was skipped
        """
        return cls.event(EventType.OPTIM_PRE_STEP, **kwargs)

    @classmethod
    def optim_post_step(cls, **kwargs) -> Optional[ModifiedState]:
        """
        Invoke an optimizer post-step event for the active session

        :param kwargs: additional kwargs to pass to the current session's event method
        :return: the modified state of the active session after invoking the event,
            or None when the event was skipped
        """
        return cls.event(EventType.OPTIM_POST_STEP, **kwargs)

    @classmethod
    def batch_end(cls, **kwargs) -> Optional[ModifiedState]:
        """
        Invoke a batch end event for the active session

        Model info is logged first, even when the event itself is skipped
        because no recipe exists.

        :param kwargs: additional kwargs to pass to the current session's event method
        :return: the modified state of the active session after invoking the event,
            or None when the event was skipped
        """
        active_session()._log_model_info()
        return cls.event(EventType.BATCH_END, **kwargs)

    @classmethod
    def sequential_epoch_end(cls, **kwargs) -> Optional[ModifiedState]:
        """
        Invoke a sequential epoch end event for the active session. This event
        should be called after one sequential layer has been calibrated/trained
        for one epoch; see `src/llmcompressor/pipelines/sequential/pipeline.py`
        for a usage example.

        :param kwargs: additional kwargs to pass to the current session's event method
        :return: the modified state of the active session after invoking the event,
            or None when the event was skipped
        """
        return cls.event(EventType.SEQUENTIAL_EPOCH_END, **kwargs)


# Short module-level alias, e.g. ``callbacks.batch_start(...)``.
callbacks = LifecycleCallbacks