-
Notifications
You must be signed in to change notification settings - Fork 453
Expand file tree
/
Copy pathstate.py
More file actions
280 lines (249 loc) · 10.2 KB
/
state.py
File metadata and controls
280 lines (249 loc) · 10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
"""
Module for managing the state of the LLM Compressor.
This module provides classes for holding and updating the state information
related to data, hardware, and model compression.
"""
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from loguru import logger
from llmcompressor.core.events import Event
from llmcompressor.metrics import BaseLogger, LoggerManager
# Public names exported by ``from ... import *`` on this module.
__all__ = ["State", "Data", "Hardware", "ModifiedState"]
@dataclass
class Data:
    """
    Container for the data sets used during compression.

    Each slot is optional and defaults to ``None``; no specific data set
    type is enforced by this class.

    :param train: data set used for training
    :type train: Optional[Any]
    :param val: data set used for validation
    :type val: Optional[Any]
    :param test: data set used for testing
    :type test: Optional[Any]
    :param calib: data set used for calibration
    :type calib: Optional[Any]
    """

    # Declaration order fixes the positional order of the generated __init__.
    train: Optional[Any] = field(default=None)
    val: Optional[Any] = field(default=None)
    test: Optional[Any] = field(default=None)
    calib: Optional[Any] = field(default=None)
@dataclass
class Hardware:
    """
    Description of the hardware the compression run is using.

    Every attribute is optional and defaults to ``None`` until populated.

    :param device: device currently in use for training
    :type device: Optional[str]
    :param devices: every device to be used for training
    :type devices: Optional[List[str]]
    :param rank: rank of the current device
    :type rank: Optional[int]
    :param world_size: total number of devices in use
    :type world_size: Optional[int]
    :param local_rank: rank of the current device on the local machine
    :type local_rank: Optional[int]
    :param local_world_size: number of devices in use on the local machine
    :type local_world_size: Optional[int]
    :param distributed: whether distributed training is in use
    :type distributed: Optional[bool]
    :param distributed_strategy: name of the distributed strategy in use
    :type distributed_strategy: Optional[str]
    """

    # Declaration order fixes the positional order of the generated __init__.
    device: Optional[str] = field(default=None)
    devices: Optional[List[str]] = field(default=None)
    rank: Optional[int] = field(default=None)
    world_size: Optional[int] = field(default=None)
    local_rank: Optional[int] = field(default=None)
    local_world_size: Optional[int] = field(default=None)
    distributed: Optional[bool] = field(default=None)
    distributed_strategy: Optional[str] = field(default=None)
@dataclass
class State:
    """
    State class holds information about the current compression state.

    :param model: The model being used for compression
    :type model: Any
    :param teacher_model: The teacher model being used for compression
    :type teacher_model: Any
    :param optimizer: The optimizer being used for training
    :type optimizer: Any
    :param optim_wrapped: Whether or not the optimizer has been wrapped;
        ``None`` unless set at construction or by ``update`` when an
        optimizer is passed
    :type optim_wrapped: Optional[bool]
    :param loss: The loss function being used for training
    :type loss: Any
    :param batch_data: The current batch of data being used for compression
    :type batch_data: Any
    :param data: The data sets being used for training, validation, testing,
        and/or calibration, wrapped in a Data instance
    :type data: Data
    :param hardware: Hardware instance holding info about the target hardware
        being used
    :type hardware: Hardware
    :param start_event: The start event to begin compression
    :type start_event: Optional[Event]
    :param last_event: The last compression event that occurred
    :type last_event: Optional[Event]
    :param loggers: LoggerManager instance holding all the loggers to log
    :type loggers: Optional[LoggerManager]
    :param model_log_cadence: The cadence to log model information w.r.t epochs.
        If 1, logs every epoch. If 2, logs every other epoch, etc. The field
        defaults to ``None`` (unset) on the state itself.
    :type model_log_cadence: Optional[float]
    """

    model: Any = None
    teacher_model: Any = None
    optimizer: Any = None
    optim_wrapped: Optional[bool] = None
    loss: Any = None
    batch_data: Any = None
    data: Data = field(default_factory=Data)
    hardware: Hardware = field(default_factory=Hardware)
    start_event: Optional[Event] = None
    last_event: Optional[Event] = None
    loggers: Optional[LoggerManager] = None
    model_log_cadence: Optional[float] = None
    # Private bookkeeping; not read or written by this class itself.
    # NOTE(review): presumably maintained by logging code elsewhere — confirm.
    _last_log_step: Union[float, int, None] = None

    def update(
        self,
        model: Any = None,
        teacher_model: Any = None,
        optimizer: Any = None,
        attach_optim_callbacks: bool = True,
        train_data: Any = None,
        val_data: Any = None,
        test_data: Any = None,
        calib_data: Any = None,
        copy_data: bool = True,
        start: Optional[float] = None,
        steps_per_epoch: Optional[int] = None,
        batches_per_step: Optional[int] = None,
        loggers: Union[None, LoggerManager, List[BaseLogger]] = None,
        model_log_cadence: Optional[float] = None,
        **kwargs,
    ) -> Dict:
        """
        Update the state with the given parameters.

        Arguments left as ``None`` do not overwrite the corresponding existing
        state, so repeated partial updates accumulate rather than reset.

        :param model: The model to update the state with
        :type model: Any
        :param teacher_model: The teacher model to update the state with
        :type teacher_model: Any
        :param optimizer: The optimizer to update the state with
        :type optimizer: Any
        :param attach_optim_callbacks: Whether or not to attach optimizer
            callbacks; stored in ``optim_wrapped`` only when ``optimizer``
            is given
        :type attach_optim_callbacks: bool
        :param train_data: The training data to update the state with
        :type train_data: Any
        :param val_data: The validation data to update the state with
        :type val_data: Any
        :param test_data: The testing data to update the state with
        :type test_data: Any
        :param calib_data: The calibration data to update the state with
        :type calib_data: Any
        :param copy_data: Whether or not to deep-copy the data before storing it
        :type copy_data: bool
        :param start: The start index to update the state with
        :type start: Optional[float]
        :param steps_per_epoch: The steps per epoch to update the state with
        :type steps_per_epoch: Optional[int]
        :param batches_per_step: The batches per step to update the state with
        :type batches_per_step: Optional[int]
        :param loggers: The metrics manager to setup logging important info and
            milestones to, also accepts a list of BaseLogger(s). If ``None``,
            any existing loggers are kept; if none have been set yet, an empty
            ``LoggerManager`` is installed.
        :type loggers: Union[None, LoggerManager, List[BaseLogger]]
        :param model_log_cadence: The cadence to log model information w.r.t
            epochs. If 1, logs every epoch. If 2, logs every other epoch, etc.
        :type model_log_cadence: Optional[float]
        :param kwargs: Additional keyword arguments; a ``device`` entry is
            applied to ``self.hardware.device``
        :return: The ``kwargs`` dict that was passed in (a ``device`` entry,
            if present, is applied but not removed)
        :rtype: Dict
        """
        logger.debug(
            "Updating state with provided parameters: {}",
            {
                "model": model,
                "teacher_model": teacher_model,
                "optimizer": optimizer,
                "attach_optim_callbacks": attach_optim_callbacks,
                "train_data": train_data,
                "val_data": val_data,
                "test_data": test_data,
                "calib_data": calib_data,
                "copy_data": copy_data,
                "start": start,
                "steps_per_epoch": steps_per_epoch,
                "batches_per_step": batches_per_step,
                "loggers": loggers,
                "model_log_cadence": model_log_cadence,
                "kwargs": kwargs,
            },
        )

        if model is not None:
            self.model = model
        if teacher_model is not None:
            self.teacher_model = teacher_model
        if optimizer is not None:
            self.optim_wrapped = attach_optim_callbacks
            self.optimizer = optimizer

        # Deep-copy by default so the state does not alias the caller's data.
        if train_data is not None:
            self.data.train = deepcopy(train_data) if copy_data else train_data
        if val_data is not None:
            self.data.val = deepcopy(val_data) if copy_data else val_data
        if test_data is not None:
            self.data.test = deepcopy(test_data) if copy_data else test_data
        if calib_data is not None:
            self.data.calib = deepcopy(calib_data) if copy_data else calib_data

        if "device" in kwargs:
            self.hardware.device = kwargs["device"]

        if (
            start is not None
            or steps_per_epoch is not None
            or batches_per_step is not None
        ):
            # Lazily create the start event the first time any timing info
            # arrives, then patch only the fields that were provided.
            if self.start_event is None:
                self.start_event = Event()
            if start is not None:
                self.start_event.current_index = start
            if steps_per_epoch is not None:
                self.start_event.steps_per_epoch = steps_per_epoch
            if batches_per_step is not None:
                self.start_event.batches_per_step = batches_per_step

        if loggers is not None:
            # A bare list of BaseLogger(s) is wrapped in a manager.
            if isinstance(loggers, list):
                loggers = LoggerManager(loggers)
            self.loggers = loggers
        elif self.loggers is None:
            # Keep the guarantee that a LoggerManager is present after the
            # first update, without clobbering loggers configured earlier.
            self.loggers = LoggerManager([])

        if model_log_cadence is not None:
            self.model_log_cadence = model_log_cadence

        return kwargs
@dataclass
class ModifiedState:
    """
    A dataclass to represent a modified model, optimizer, and loss.

    All fields default to ``None``, so an instance can be built positionally
    (``ModifiedState(model, optimizer, loss, modifier_data)``), by keyword,
    or from only a subset of the fields.

    :param model: The modified model
    :type model: Optional[Any]
    :param optimizer: The modified optimizer
    :type optimizer: Optional[Any]
    :param loss: The modified loss
    :type loss: Optional[Any]
    :param modifier_data: The modifier data used to modify the
        model, optimizer, and loss
    :type modifier_data: Optional[List[Dict[str, Any]]]
    """

    # No hand-written __init__: the one generated by @dataclass takes
    # (model, optimizer, loss, modifier_data) in this order and honors the
    # None defaults below. A custom __init__ would shadow it and make those
    # defaults unreachable.
    model: Optional[Any] = None
    optimizer: Optional[Any] = None
    loss: Optional[Any] = None
    modifier_data: Optional[List[Dict[str, Any]]] = None