Commit 1b7458e

Merge pull request #55 from transformerlab/fix/update-trl-script-real-checkpoints
Update TRL script to have real progress updates and checkpoints
2 parents ea0260e + 18a12e6 commit 1b7458e
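
For background, the change follows the stock Hugging Face TrainerCallback pattern: subclass the callback, override the lifecycle hooks you care about, and register the instance on the trainer. A minimal, self-contained sketch of that pattern follows; report_progress here is a hypothetical stand-in for a host API such as lab.update_progress, not part of the script:

    from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

    def report_progress(pct: int) -> None:
        # Hypothetical stand-in for an external progress API
        print(f"progress: {pct}%")

    class ProgressCallback(TrainerCallback):
        def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
            # max_steps is populated by the Trainer once the schedule is known
            if state.max_steps:
                report_progress(int(state.global_step / state.max_steps * 100))

Any Trainer subclass accepts such a callback through its callbacks argument, which is exactly how this commit wires LabCallback into SFTTrainer below.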

File tree

1 file changed: +78 -38 lines

scripts/examples/trl_train_script.py

Lines changed: 78 additions & 38 deletions
@@ -7,6 +7,7 @@
 import os
 from datetime import datetime
 from time import sleep
+from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
 
 from lab import lab
 
@@ -15,6 +16,70 @@
 login(token=os.getenv("HF_TOKEN"))
 
 
+class LabCallback(TrainerCallback):
+    """Custom callback to update TransformerLab progress and save checkpoints"""
+
+    def __init__(self):
+        self.training_started = False
+        self.total_steps = None
+
+    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """Called when training begins"""
+        lab.log("🚀 Training started with HuggingFace Trainer")
+        self.training_started = True
+        if state.max_steps and state.max_steps > 0:
+            self.total_steps = state.max_steps
+        else:
+            # Estimate steps if not provided
+            self.total_steps = 1000
+
+    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """Called after each training step"""
+        if self.total_steps:
+            progress = int((state.global_step / self.total_steps) * 100)
+            progress = min(progress, 95)  # Keep some buffer for final operations
+            lab.update_progress(progress)
+
+        # Log training metrics if available
+        if state.log_history:
+            latest_log = state.log_history[-1]
+            if "loss" in latest_log:
+                lab.log(f"Step {state.global_step}: loss={latest_log['loss']:.4f}")
+
+    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """Called when a checkpoint is saved"""
+        lab.log(f"💾 Checkpoint saved at step {state.global_step}")
+
+        # Attempt to save the checkpoint using lab's checkpoint mechanism
+        if hasattr(args, 'output_dir'):
+            checkpoint_dir = None
+            # Find the most recent checkpoint
+            if os.path.exists(args.output_dir):
+                checkpoints = [d for d in os.listdir(args.output_dir) if d.startswith('checkpoint-')]
+                if checkpoints:
+                    # Sort by checkpoint number
+                    checkpoints.sort(key=lambda x: int(x.split('-')[1]))
+                    latest_checkpoint = checkpoints[-1]
+                    checkpoint_dir = os.path.join(args.output_dir, latest_checkpoint)
+
+            # Save checkpoint to TransformerLab
+            try:
+                saved_path = lab.save_checkpoint(checkpoint_dir, f"checkpoint-{state.global_step}")
+                lab.log(f"✅ Saved checkpoint to TransformerLab: {saved_path}")
+            except Exception as e:
+                lab.log(f"⚠️ Could not save checkpoint to TransformerLab: {e}")
+
+    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """Called at the end of each epoch"""
+        if state.epoch:
+            lab.log(f"📊 Completed epoch {int(state.epoch)} / {args.num_train_epochs}")
+
+    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """Called when training ends"""
+        lab.log("✅ Training completed successfully")
+        lab.update_progress(95)
+
+
 def train_with_trl(quick_test=True):
     """Training function using HuggingFace SFTTrainer with automatic wandb detection
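
To make the progress arithmetic above concrete: with the 1000-step fallback total, step 370 reports int(370 / 1000 * 100) = 37, and the min(..., 95) cap parks the bar at 95% so on_train_end and any final bookkeeping own the last stretch. A quick sanity check of the same formula, as plain arithmetic with no trainer involved:

    def progress_pct(global_step: int, total_steps: int, cap: int = 95) -> int:
        # Same formula as LabCallback.on_step_end: integer percent, capped
        return min(int((global_step / total_steps) * 100), cap)

    assert progress_pct(370, 1000) == 37
    assert progress_pct(990, 1000) == 95   # capped
    assert progress_pct(1000, 1000) == 95  # 100% is left for post-training steps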
@@ -49,8 +114,6 @@ def train_with_trl(quick_test=True):
         "max_steps": 3 if quick_test else -1,  # Limit steps for quick test
         "report_to": ["wandb"],  # Enable wandb reporting in SFTTrainer
         "dataloader_num_workers": 0,  # Avoid multiprocessing issues
-        "remove_unused_columns": False,
-        "push_to_hub": False,
     },
 }
 
@@ -126,7 +189,7 @@ def train_with_trl(quick_test=True):
     try:
         from trl import SFTTrainer, SFTConfig
 
-        # SFTConfig with wandb reporting
+        # SFTConfig with wandb reporting and automatic checkpoint saving
         training_args = SFTConfig(
             output_dir=training_config["output_dir"],
             num_train_epochs=training_config["_config"]["num_train_epochs"],
@@ -144,33 +207,27 @@ def train_with_trl(quick_test=True):
             remove_unused_columns=False,
             push_to_hub=False,
             dataset_text_field="text",  # Move dataset_text_field to SFTConfig
+            # Enable automatic checkpoint saving
+            save_total_limit=3,  # Keep only the last 3 checkpoints to save disk space
+            save_strategy="steps",  # Save checkpoints every save_steps
+            load_best_model_at_end=False,
         )
 
-        # Create SFTTrainer - this will initialize wandb if report_to includes "wandb"
+        # Create custom callback for TransformerLab integration
+        transformerlab_callback = LabCallback()
+
         trainer = SFTTrainer(
             model=model,
             args=training_args,
             train_dataset=dataset["train"],
             processing_class=tokenizer,
+            callbacks=[transformerlab_callback],  # Add our custom callback
         )
 
         lab.log("✅ SFTTrainer created - wandb should be initialized automatically!")
         lab.log("🔍 Checking for wandb URL detection...")
 
-    except ImportError:
-        lab.log("⚠️ TRL not available, using basic training simulation")
-        # Simulate wandb initialization for testing
-        try:
-            import wandb
-            if wandb.run is None:
-                wandb.init(
-                    project="transformerlab-trl-test",
-                    name=f"trl-sim-{lab.job.id}",
-                    config=training_config["_config"]
-                )
-            lab.log("✅ Simulated wandb initialization for testing")
-        except Exception:
-            pass
+
     except Exception as e:
         lab.log(f"Error setting up SFTTrainer: {e}")
         lab.finish("Training failed - trainer setup error")
@@ -182,37 +239,20 @@ def train_with_trl(quick_test=True):
     if quick_test:
         lab.log("🚀 Quick test mode: Initializing SFTTrainer and testing wandb detection...")
     else:
-        lab.log("Starting training with SFTTrainer...")
+        lab.log("Starting training...")
 
     try:
         if 'trainer' in locals():
             # Real training with SFTTrainer
             if quick_test:
-                lab.log("✅ SFTTrainer initialized successfully - testing wandb detection...")
+                lab.log("✅ SFTTrainer initialized successfully...")
                 # Just test that wandb is initialized, don't actually train
                 lab.log("Quick test: Skipping actual training, just testing wandb URL detection")
             else:
+                # Training will automatically save checkpoints via the callback
                 trainer.train()
                 lab.log("✅ Training completed with SFTTrainer")
 
-                # Save checkpoints and artifacts after full training
-                lab.log("Saving training checkpoints and artifacts...")
-
-                # Create 5 fake checkpoints to simulate training progression
-                for epoch in range(1, 6):
-                    checkpoint_file = os.path.join(training_config["output_dir"], f"checkpoint_epoch_{epoch}.txt")
-                    with open(checkpoint_file, "w") as f:
-                        f.write(f"Training checkpoint for epoch {epoch}\n")
-                        f.write(f"Model state: epoch_{epoch}\n")
-                        f.write(f"Loss: {0.5 - epoch * 0.08:.3f}\n")
-                        f.write(f"Accuracy: {0.6 + epoch * 0.08:.3f}\n")
-                        f.write(f"Timestamp: {datetime.now()}\n")
-                        f.write(f"Training step: {epoch * 100}\n")
-                        f.write(f"Learning rate: {training_config['_config']['lr']}\n")
-
-                    # Save checkpoint using lab facade
-                    saved_checkpoint_path = lab.save_checkpoint(checkpoint_file, f"epoch_{epoch}_checkpoint.txt")
-                    lab.log(f"Saved checkpoint: {saved_checkpoint_path}")
 
                 # Create 2 additional artifacts for full training
                 # Artifact 1: Training progress summary
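
The callback hooks can also be smoke-tested off-GPU by driving them with synthetic state objects. A sketch, assuming LabCallback is importable from the script (module name hypothetical) and a throwaway output_dir is acceptable:

    from transformers import TrainerControl, TrainerState, TrainingArguments

    state = TrainerState()
    state.max_steps = 1000
    state.global_step = 500
    state.log_history = [{"loss": 0.4213}]

    args = TrainingArguments(output_dir="/tmp/trl-smoke-test")
    control = TrainerControl()

    # Requires the TransformerLab runtime for the lab facade:
    # cb = LabCallback()                       # from trl_train_script.py
    # cb.on_train_begin(args, state, control)  # logs the start message, sets total_steps
    # cb.on_step_end(args, state, control)     # -> lab.update_progress(50), logs loss 0.4213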
