-
Notifications
You must be signed in to change notification settings - Fork 467
Expand file tree
/
Copy pathrunner.py
More file actions
294 lines (252 loc) · 11.2 KB
/
runner.py
File metadata and controls
294 lines (252 loc) · 11.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
import math
import os
import re
from typing import List, Optional
import torch
from loguru import logger
from torch.utils.data import Dataset
from llmcompressor.args import (
DatasetArguments,
ModelArguments,
RecipeArguments,
TrainingArguments,
)
from llmcompressor.core import active_session
from llmcompressor.pytorch.model_load.helpers import (
get_completed_stages,
get_session_model,
save_checkpoint,
save_completed_stages,
)
from llmcompressor.pytorch.utils import tensors_to_device
from llmcompressor.recipe import Recipe, StageRunType
from llmcompressor.transformers.finetune.data import TextGenerationDataset
from llmcompressor.transformers.finetune.data.data_helpers import (
format_calibration_data,
make_dataset_splits,
)
from llmcompressor.typing import Processor
from llmcompressor.utils.fsdp.helpers import is_fsdp_model
class StageRunner:
    """
    Launcher class for train, eval and one_shot flows. Manages data splits for each
    flow and configurations. In the future this class will also handle alternating
    between the different flows

    LifeCycle
        - populate_datasets()
        - assign ``trainer`` (done by the caller; this class never creates it)
        - train() / evaluate() / predict() / one_shot(), or
          run_sequential_stages() to drive every recipe stage in order

    :param data_args: Arguments pertaining to what data to use for different flows
    :param model_args: Arguments pertaining to model/config/processor
    :param training_args: Arguments pertaining to training loop configuration
    :param recipe_args: Arguments pertaining to the recipe; its ``recipe`` field is
        read by run_sequential_stages
    """

    def __init__(
        self,
        data_args: "DatasetArguments",
        model_args: "ModelArguments",
        training_args: "TrainingArguments",
        recipe_args: "RecipeArguments",
    ):
        self._data_args = data_args
        self._model_args = model_args
        self._training_args = training_args
        self._recipe_args = recipe_args

        # split name -> dataset, filled in by populate_datasets()
        self.datasets = {}
        # must be assigned by the caller before any flow is run
        self.trainer = None
        # processor handed to save_checkpoint between sequential stages
        self.processor = None

        # parent_output_dir keeps the user-supplied root; _output_dir is redirected
        # to a per-stage subdirectory by run_sequential_stages
        self.parent_output_dir = self._training_args.output_dir
        self._output_dir = self._training_args.output_dir

    def populate_datasets(self, processor: "Processor", add_labels: bool = True):
        """
        Loads datasets for each flow based on data_args, stores a Dataset for each
        enabled flow in self.datasets

        When ``data_args.dataset`` is None nothing is loaded and the processor from
        model_args is kept instead.

        :param processor: processor or tokenizer to use for dataset tokenization
        :param add_labels: if True, add labels column to dataset splits
        """
        if self._data_args.dataset is None:
            self.processor = self._model_args.processor
            logger.info(
                "Running oneshot without calibration data. This is expected for "
                "weight-only and dynamic quantization"
            )
            return

        # keep the tokenization processor so stage checkpoints can save it too,
        # unless the caller already provided one
        if self.processor is None:
            self.processor = processor

        def _get_split_name(inp_str):
            # strip out split name, for ex train[60%:] -> train
            match = re.match(r"(\w*)\[.*\]", inp_str)
            if match is not None:
                return match.group(1)
            return inp_str

        splits = self._data_args.splits
        if splits is None:
            splits = {"all": None}
        elif isinstance(splits, str):
            splits = {_get_split_name(splits): splits}
        elif isinstance(splits, (list, tuple)):
            splits = {_get_split_name(s): s for s in splits}

        dataset = self._data_args.dataset
        # default to custom dataset if dataset provided isn't a string
        registry_id = dataset if isinstance(dataset, str) else "custom"
        already_tokenized = (
            hasattr(dataset, "column_names") and "input_ids" in dataset.column_names
        )

        tokenized_datasets = {}
        for split_name, split_str in splits.items():
            if already_tokenized:
                # dataset is already tokenized; note the same object is reused for
                # every split and split_str is not applied
                tokenized_datasets[split_name] = dataset
            else:
                # dataset needs to be tokenized
                dataset_manager = TextGenerationDataset.load_from_registry(
                    registry_id,
                    data_args=self._data_args,
                    split=split_str,
                    processor=processor,
                )
                tokenized_datasets[split_name] = dataset_manager(add_labels=add_labels)

        self.datasets = make_dataset_splits(
            tokenized_datasets,
            do_train=self._training_args.do_train,
            do_eval=self._training_args.do_eval,
            do_predict=self._training_args.do_predict,
            do_oneshot=self._training_args.do_oneshot,
        )

    def get_dataset_split(self, split_name: str) -> Optional["Dataset"]:
        """
        Retrieve a dataset split by name

        :param split_name: name of dataset split to return
        :return: dataset split labeled by split_name, or None if it was not loaded
        """
        return self.datasets.get(split_name, None)

    def one_shot(self, stage: Optional[str] = None):
        """
        Run oneshot calibration on the active model

        :param stage: which stage of the recipe to run, or None to run whole recipe
        """
        logger.info("*** One Shot ***")

        calib_data = None
        calib_split = self.get_dataset_split("calibration")
        if calib_split is not None:
            calib_data = format_calibration_data(
                tokenized_dataset=calib_split,
                num_calibration_samples=self._data_args.num_calibration_samples,
                do_shuffle=self._data_args.shuffle_calibration_samples,
                collate_fn=self._data_args.data_collator,
                accelerator=self.trainer.accelerator,
            )

            # if we don't run a forward pass after initializing the FSDP model for the
            # first time, calls to summon_full_params will fail ¯\_(ツ)_/¯
            if is_fsdp_model(self.trainer.model):
                dummy_inp = dict(next(iter(calib_data)))
                model_device = next(self.trainer.model.parameters()).device
                dummy_inp = tensors_to_device(dummy_inp, model_device)
                with torch.no_grad():
                    self.trainer.model(**dummy_inp)

        self.trainer.accelerator.wait_for_everyone()

        self.trainer.one_shot(calibration_data=calib_data, stage=stage)

    def train(self, checkpoint: Optional[str], stage: Optional[str] = None):
        """
        Run trainer's training loop on train_dataset, saving the resulting model to
        output_dir

        :param checkpoint: Optional checkpoint to resume from
        :param stage: which stage of the recipe to run, or None to run whole recipe
        """
        logger.info("*** Train ***")
        train_result = self.trainer.train(
            resume_from_checkpoint=checkpoint, stage=stage
        )
        metrics = train_result.metrics
        metrics["train_samples"] = len(self.get_dataset_split("train"))
        try:
            metrics["perplexity"] = math.exp(metrics["train_loss"])
        except OverflowError:
            # a very large loss must not crash the run before metrics/model are saved
            metrics["perplexity"] = float("inf")
        self.trainer.log_metrics("train", metrics)
        self.trainer.save_metrics("train", metrics)

        # this includes saving the state, optimizer and scheduler
        self.trainer.save_model(output_dir=self._output_dir)

    def evaluate(self):
        """
        Run trainer's evaluation loop on eval_dataset, logging the desired metrics
        """
        logger.info("*** Evaluate ***")
        eval_dataset = self.get_dataset_split("validation")
        metrics = self.trainer.evaluate(eval_dataset)

        metrics["eval_samples"] = len(eval_dataset)
        self.trainer.log_metrics("eval", metrics)
        self.trainer.save_metrics("eval", metrics)

    def predict(self):
        """
        Run trainer's prediction loop on predict_dataset, logging the desired metrics
        """
        logger.info("*** Predict ***")
        # datasets live in self.datasets; there is no self.dataset attribute
        test_dataset = self.get_dataset_split("test")
        results = self.trainer.predict(test_dataset)
        metrics = results.metrics

        metrics["predict_samples"] = len(test_dataset)
        self.trainer.log_metrics("predict", metrics)
        self.trainer.save_metrics("predict", metrics)

    def run_sequential_stages(self, checkpoint: Optional[str] = None):
        """
        Run the recipe stage by stage, allowing for alternating between one-shot and
        finetuning flows. Optionally save the model output at the end of each stage

        Stages already listed as completed for the model directory are skipped (only
        their structure is re-initialized). Every other stage writes into its own
        ``stage_<name>`` directory under the original output_dir.

        :param checkpoint: optional checkpoint to pick up a stage from; it is only
            passed to the first stage that actually runs
        :raises ValueError: if a stage's run type cannot be inferred
        """
        recipe_obj = Recipe.create_instance(self._recipe_args.recipe)
        with self.trainer.accelerator.main_process_first():
            checkpoint_dir = self._model_args.model
            completed_stages = get_completed_stages(checkpoint_dir)

        self.trainer.accelerator.wait_for_everyone()

        for stage in recipe_obj.stages:
            # validate stage
            stage_name = stage.group
            run_type = stage.infer_run_type()
            if not run_type:
                raise ValueError(
                    f"a valid stage type ({[e.value for e in StageRunType]}) "
                    "must be provided in run_stages mode. Either add a run_type "
                    "attribute to each stage in the recipe or include it as part of "
                    "the stage name."
                )

            # just load structure if stage has already applied
            if stage_name in completed_stages:
                self.trainer.initialize_structure(stage=stage)
                self.trainer.accelerator.wait_for_everyone()
                continue

            # setup checkpoint dir, TODO: this should be optional
            self._output_dir = os.path.join(
                self.parent_output_dir, "stage_" + stage_name
            )
            with self.trainer.accelerator.main_process_first():
                # exist_ok avoids a check-then-create race between processes
                os.makedirs(self._output_dir, exist_ok=True)
                save_completed_stages(self._output_dir, completed_stages)
            self._training_args.output_dir = self._output_dir

            # run stage
            if run_type is StageRunType.ONESHOT:
                self.one_shot(stage=stage_name)
            elif run_type is StageRunType.TRAIN:
                self.train(checkpoint=checkpoint, stage=stage_name)
            # the resume checkpoint only applies to the first stage that runs
            checkpoint = None

            # save model between stages
            # NOTE(review): output_dir was set to the per-stage directory above, so
            # this comparison is only False if that path equals the dataclass
            # default; it may have been meant to test parent_output_dir — confirm.
            if (
                self._training_args.output_dir
                != TrainingArguments.__dataclass_fields__["output_dir"].default
                and self.trainer.accelerator.is_main_process
            ):
                save_checkpoint(
                    save_path=self._output_dir,
                    model=self.trainer.model,
                    processor=self.processor,
                    save_safetensors=self._training_args.save_safetensors,
                    save_compressed=self._model_args.save_compressed,
                )
            self.trainer.accelerator.wait_for_everyone()

            # record the stage on every process so all ranks hold the same list
            # (every rank re-writes it at the start of the next stage), but only
            # the main process writes the file here
            completed_stages.append(stage_name)
            if self.trainer.accelerator.is_main_process:
                save_completed_stages(self._output_dir, completed_stages)

            # setup for next stage
            session = active_session()
            session.reset_stage()

            # synchronize and clean up memory
            self.trainer.accelerator.wait_for_everyone()
            self.trainer.model = get_session_model()
            torch.cuda.empty_cache()
            self.trainer.accelerator.free_memory()
            self.trainer.accelerator.wait_for_everyone()