@@ -20,8 +20,6 @@
 from urllib.parse import urlparse
 import json
 import os
-import time
-from datetime import datetime
 
 from datasets import load_dataset, Dataset
 from datasets.distributed import split_dataset_by_node
@@ -34,95 +32,8 @@
     TrainingArguments,
     DataCollatorForLanguageModeling,
     Trainer,
-    TrainerCallback,
 )
-import torch
-from torch.utils.tensorboard import SummaryWriter
-import torch.distributed as dist
-
-
-class CustomTensorBoardCallback(TrainerCallback):
-    def __init__(self, log_dir=None, exclude_metrics=None):
-        self.exclude_metrics = exclude_metrics or []
-        self.writer = None
-        self.log_dir = log_dir
-        self.epoch_start_time = None
-        self.total_forward_time = 0
-        self.total_backward_time = 0
-        self.total_batches = 0
-
-    def on_train_begin(self, args, state, control, **kwargs):
-        # Initialize TensorBoard writer at the start of training, only for the main process.
-        if dist.get_rank() == 0:
-            if self.log_dir is None:
-                self.log_dir = args.logging_dir
-            self.writer = SummaryWriter(log_dir=self.log_dir)
-
-    def on_log(self, args, state, control, logs=None, **kwargs):
-        # Aggregate metrics across all ranks and log only from rank 0.
-        if logs is None:
-            return
-
-        aggregated_logs = {}
-
-        for key, value in logs.items():
-            if key not in self.exclude_metrics:  # Remove unwanted metrics
-                tensor_value = torch.tensor(value, dtype=torch.float32, device="cuda" if torch.cuda.is_available() else "cpu")
-
-                # Aggregate across all ranks
-                dist.all_reduce(tensor_value, op=dist.ReduceOp.SUM)
-                tensor_value /= dist.get_world_size()
-
-                aggregated_logs[key] = tensor_value.item()
-
-        if dist.get_rank() == 0:
-            for key, value in aggregated_logs.items():
-                self.writer.add_scalar(f"kfto-pytorch/{key}", value, state.global_step)
-
-            self.writer.flush()
-
-    def on_epoch_begin(self, args, state, control, **kwargs):
-        self.epoch_start_time = time.time()
-        print(f"Epoch {state.epoch + 1} starting...")
-        if dist.get_rank() == 0:
-            self.writer.add_scalar("kfto-pytorch/epoch", state.epoch + 1, state.global_step)
-
-    def on_step_begin(self, args, state, control, **kwargs):
-        self.start_forward_time = time.time()
-
-    def on_step_end(self, args, state, control, **kwargs):
-        forward_time = time.time() - self.start_forward_time
-        self.total_forward_time += forward_time
-
-        start_backward_time = time.time()
-        torch.cuda.synchronize()  # Wait for GPU operations to finish before timing
-
-        backward_time = time.time() - start_backward_time
-        self.total_backward_time += backward_time
-
-        self.total_batches += 1
-        avg_forward_time = self.total_forward_time / self.total_batches
-        avg_backward_time = self.total_backward_time / self.total_batches
-
-        gpu_memory_peak = torch.cuda.max_memory_allocated() / (1024 ** 2)  # convert to MB
-
-        if dist.get_rank() == 0:
-            self.writer.add_scalar("kfto-pytorch/forward_time", avg_forward_time, state.global_step)
-            self.writer.add_scalar("kfto-pytorch/backward_time", avg_backward_time, state.global_step)
-            self.writer.add_scalar("kfto-pytorch/gpu_memory_peak_mb", gpu_memory_peak, state.global_step)
-
-    def on_epoch_end(self, args, state, control, **kwargs):
-        epoch_end_time = time.time()
-        epoch_duration = epoch_end_time - self.epoch_start_time
-
-        if dist.get_rank() == 0:
-            self.writer.add_scalar("kfto-pytorch/epoch_duration", epoch_duration, state.global_step)
-
-    def on_evaluate(self, args, state, control, metrics, **kwargs):
-        for key, value in metrics.items():
-            if key not in self.exclude_metrics and dist.get_rank() == 0:
-                self.writer.add_scalar(f"kfto-pytorch/{key}", value, state.global_step)
-        self.writer.flush()
+
 
 # Configure logger.
 log_formatter = logging.Formatter(
@@ -156,9 +67,6 @@ def setup_model_and_tokenizer(model_uri, transformer_type, model_dir):
     # Freeze model parameters
     for param in model.parameters():
         param.requires_grad = False
-        # If running in a distributed setting, synchronize model parameters across workers.
-        if dist.is_initialized():
-            dist.broadcast(param.data, src=0)
 
     return model, tokenizer
 
@@ -234,17 +142,12 @@ def setup_peft_model(model, lora_config):
 
 
 def train_model(model, transformer_type, train_data, eval_data, tokenizer, train_args):
-    # Save each run in a new directory so that multiple runs show up in TensorBoard.
-    log_dir = f"/mnt/logs/{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
-    # Exclude unwanted default metrics.
-    exclude_metrics = ['grad_norm', 'total_flos', 'train_runtime', 'train_samples_per_second', 'train_steps_per_second']
     # Setup the Trainer.
     trainer = Trainer(
         model=model,
         train_dataset=train_data,
         eval_dataset=eval_data,
         args=train_args,
-        callbacks=[CustomTensorBoardCallback(log_dir=log_dir, exclude_metrics=exclude_metrics)]
     )
 
     # TODO (andreyvelich): Currently, data collator is supported only for causal LM Transformer.
@@ -259,11 +162,6 @@ def train_model(model, transformer_type, train_data, eval_data, tokenizer, train_args):
 
     # Train and save the model.
     trainer.train()
-
-    # Use trainer.evaluate() for the default eval metrics.
-    if eval_data is not None:
-        eval_results = trainer.evaluate(eval_dataset=eval_data)
-
     trainer.save_model()
     logger.info("parallel_mode: '{0}'".format(trainer.args.parallel_mode))
    logger.info("is_model_parallel: '{0}'".format(trainer.is_model_parallel))
@@ -283,7 +181,6 @@ def parse_arguments():
     parser.add_argument(
         "--training_parameters", help="hugging face training parameters"
     )
-    parser.add_argument("--log_dir", type=str, default="/mnt/logs", help="TensorBoard log directory")
 
     return parser.parse_args()
 
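A minimal sketch for context: with the custom callback removed, training metrics can still reach TensorBoard through the Trainer's built-in integration configured on TrainingArguments. The paths and logging interval below are assumptions for illustration, not values taken from this commit.

    from transformers import TrainingArguments

    # Hypothetical settings; report_to/logging_dir/logging_steps are the relevant knobs.
    train_args = TrainingArguments(
        output_dir="/mnt/output",      # assumed output location
        logging_dir="/mnt/logs",       # TensorBoard event files are written here
        report_to=["tensorboard"],     # enables the built-in TensorBoard callback
        logging_steps=10,              # assumed logging interval, in optimizer steps
    )

Passing these args to Trainer as before and running tensorboard --logdir /mnt/logs would then surface the default loss and learning-rate scalars without any custom callback code.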