
Commit bfd564c

Merge remote-tracking branch 'upstream/main'
2 parents: 2f26096 + 8e4d561

File tree: 1 file changed, +1 −104 lines changed

1 file changed

+1
-104
lines changed

tests/kfto/resources/hf_llm_training.py

Lines changed: 1 addition & 104 deletions
@@ -20,8 +20,6 @@
 from urllib.parse import urlparse
 import json
 import os
-import time
-from datetime import datetime

 from datasets import load_dataset, Dataset
 from datasets.distributed import split_dataset_by_node
@@ -34,95 +32,8 @@
     TrainingArguments,
     DataCollatorForLanguageModeling,
     Trainer,
-    TrainerCallback,
 )
-import torch
-from torch.utils.tensorboard import SummaryWriter
-import torch.distributed as dist
-
-
-class CustomTensorBoardCallback(TrainerCallback):
-    def __init__(self, log_dir=None, exclude_metrics=None):
-        self.exclude_metrics = exclude_metrics or []
-        self.writer = None
-        self.log_dir = log_dir
-        self.epoch_start_time = None
-        self.total_forward_time = 0
-        self.total_backward_time = 0
-        self.total_batches = 0
-
-    def on_train_begin(self, args, state, control, **kwargs):
-        # Initialize TensorBoard writer at the start of training, only for the main process.
-        if dist.get_rank() == 0:
-            if self.log_dir is None:
-                self.log_dir = args.logging_dir
-            self.writer = SummaryWriter(log_dir=self.log_dir)
-
-    def on_log(self, args, state, control, logs=None, **kwargs):
-        # Aggregate metrics across all ranks and log only from rank 0.
-        if logs is None:
-            return
-
-        aggregated_logs = {}
-
-        for key, value in logs.items():
-            if key not in self.exclude_metrics: # Remove unwanted metrics
-                tensor_value = torch.tensor(value, dtype=torch.float32, device="cuda" if torch.cuda.is_available() else "cpu")
-
-                # Aggregate across all ranks
-                dist.all_reduce(tensor_value, op=dist.ReduceOp.SUM)
-                tensor_value /= dist.get_world_size()
-
-                aggregated_logs[key] = tensor_value.item()
-
-        if dist.get_rank() == 0:
-            for key, value in aggregated_logs.items():
-                self.writer.add_scalar(f"kfto-pytorch/{key}", value, state.global_step)
-
-            self.writer.flush()
-
-    def on_epoch_begin(self, args, state, control, **kwargs):
-        self.epoch_start_time = time.time()
-        print(f"Epoch {state.epoch +1} starting...")
-        if dist.get_rank() == 0:
-            self.writer.add_scalar("kfto-pytorch/epoch", state.epoch +1, state.global_step)
-
-    def on_step_begin(self, args, state, control, **kwargs):
-        self.start_forward_time = time.time()
-
-    def on_step_end(self, args, state, control, **kwargs):
-        forward_time = time.time() - self.start_forward_time
-        self.total_forward_time += forward_time
-
-        start_backward_time = time.time()
-        torch.cuda.synchronize() # Wait for GPU operations to finish before timing
-
-        backward_time = time.time() - start_backward_time
-        self.total_backward_time += backward_time
-
-        self.total_batches += 1
-        avg_forward_time = self.total_forward_time / self.total_batches
-        avg_backward_time = self.total_backward_time / self.total_batches
-
-        gpu_memory_peak = torch.cuda.max_memory_allocated() / (1024 ** 2) # convert to mb
-
-        if dist.get_rank() == 0:
-            self.writer.add_scalar("kfto-pytorch/forward_time", avg_forward_time, state.global_step)
-            self.writer.add_scalar("kfto-pytorch/backward_time", avg_backward_time, state.global_step)
-            self.writer.add_scalar("kfto-pytorch/gpu_memory_peak_mb", gpu_memory_peak, state.global_step)
-
-    def on_epoch_end(self, args, state, control, **kwargs):
-        epoch_end_time = time.time()
-        epoch_duration = epoch_end_time - self.epoch_start_time
-
-        if dist.get_rank() == 0:
-            self.writer.add_scalar("kfto-pytorch/epoch_duration", epoch_duration, state.global_step)
-
-    def on_evaluate(self, args, state, control, metrics, **kwargs):
-        for key, value in metrics.items():
-            if key not in self.exclude_metrics and dist.get_rank() == 0:
-                self.writer.add_scalar(f"kfto-pytorch/{key}", value, state.global_step)
-        self.writer.flush()
+

 # Configure logger.
 log_formatter = logging.Formatter(
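The removed CustomTensorBoardCallback handled TensorBoard scalar logging by hand. If basic TensorBoard output is still wanted, Transformers ships a built-in TensorBoard integration that can be enabled from TrainingArguments alone; a minimal sketch follows, where the paths and logging_steps value are illustrative assumptions rather than values taken from this repository:

from transformers import TrainingArguments

# Minimal sketch: enable the stock TensorBoard callback instead of the
# deleted custom one. output_dir, logging_dir, and logging_steps below
# are illustrative assumptions.
train_args = TrainingArguments(
    output_dir="/mnt/output",
    report_to=["tensorboard"],  # use the built-in TensorBoard integration
    logging_dir="/mnt/logs",    # where event files are written
    logging_steps=10,           # log scalars every 10 optimizer steps
)

Note that the stock integration does not reproduce the per-step forward/backward timing or the cross-rank all_reduce averaging that the deleted callback collected.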
@@ -156,9 +67,6 @@ def setup_model_and_tokenizer(model_uri, transformer_type, model_dir):
     # Freeze model parameters
     for param in model.parameters():
         param.requires_grad = False
-        # If running in a distributed setting, synchronize model parameters across workers.
-        if dist.is_initialized():
-            dist.broadcast(param.data, src=0)

     return model, tokenizer

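For reference, the broadcast removed here is the standard way to force every worker's copy of the weights to match rank 0. A standalone sketch, assuming the process group has already been initialized (for example by torchrun or the Kubeflow training operator):

import torch.distributed as dist

def sync_parameters_from_rank0(model):
    # Copy rank 0's parameter values into every other rank so all workers
    # start from identical weights. Assumes dist.init_process_group() has
    # already run; the function name is ours, not the script's.
    if dist.is_initialized():
        for param in model.parameters():
            dist.broadcast(param.data, src=0)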
@@ -234,17 +142,12 @@ def setup_peft_model(model, lora_config):


 def train_model(model, transformer_type, train_data, eval_data, tokenizer, train_args):
-    # Allow for each run to be saved in a new directory to allow multiple runs to show on tensorboard
-    log_dir = f"/mnt/logs/{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
-    # Exclude unwanted default metrics
-    exclude_metrics = ['grad_norm', 'total_flos', 'train_runtime', 'train_samples_per_second', 'train_steps_per_second']
     # Setup the Trainer.
     trainer = Trainer(
         model=model,
         train_dataset=train_data,
         eval_dataset=eval_data,
         args=train_args,
-        callbacks=[CustomTensorBoardCallback(log_dir=log_dir, exclude_metrics=exclude_metrics)]
     )

     # TODO (andreyvelich): Currently, data collator is supported only for casual LM Transformer.
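The TODO above concerns the causal-LM data collator, presumably attached after the Trainer is constructed. As a hedged sketch of the usual pattern (not necessarily this script's exact call), a causal-LM collator is built with mlm=False so labels come from next-token prediction rather than masked tokens:

from transformers import DataCollatorForLanguageModeling

# Sketch of the typical causal-LM collator setup; tokenizer is the one
# passed into train_model above.
trainer.data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)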
@@ -259,11 +162,6 @@ def train_model(model, transformer_type, train_data, eval_data, tokenizer, train

     # Train and save the model.
     trainer.train()
-
-    # Using trainer.evaluate() for default eval metrics
-    if eval_data is not None:
-        eval_results = trainer.evaluate(eval_dataset=eval_data)
-
     trainer.save_model()
     logger.info("parallel_mode: '{0}'".format(trainer.args.parallel_mode))
     logger.info("is_model_parallel: '{0}'".format(trainer.is_model_parallel))
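The deleted block called trainer.evaluate() but never used eval_results. If default evaluation metrics are needed again, a minimal sketch that also consumes the result:

# Sketch only: trainer.evaluate() returns a dict of metrics such as
# "eval_loss"; logging it avoids the unused variable removed above.
if eval_data is not None:
    metrics = trainer.evaluate(eval_dataset=eval_data)
    logger.info("eval metrics: {0}".format(metrics))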
@@ -283,7 +181,6 @@ def parse_arguments():
     parser.add_argument(
         "--training_parameters", help="hugging face training parameters"
     )
-    parser.add_argument("--log_dir", type=str, default="/mnt/logs", help="TensorBoard log directory")

     return parser.parse_args()
