 import os
 from datetime import datetime
 from time import sleep
+from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

 from lab import lab

 login(token=os.getenv("HF_TOKEN"))


+class LabCallback(TrainerCallback):
+    """Custom callback to update TransformerLab progress and save checkpoints"""
+
+    def __init__(self):
+        self.training_started = False
+        self.total_steps = None
+
+    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """Called when training begins"""
+        lab.log("🚀 Training started with HuggingFace Trainer")
+        self.training_started = True
+        if state.max_steps and state.max_steps > 0:
+            self.total_steps = state.max_steps
+        else:
+            # Estimate steps if not provided
+            self.total_steps = 1000
+
+    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """Called after each training step"""
+        if self.total_steps:
+            progress = int((state.global_step / self.total_steps) * 100)
+            progress = min(progress, 95)  # Keep some buffer for final operations
+            lab.update_progress(progress)
+
+        # Log training metrics if available
+        if state.log_history:
+            latest_log = state.log_history[-1]
+            if "loss" in latest_log:
+                lab.log(f"Step {state.global_step}: loss={latest_log['loss']:.4f}")
+
+    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """Called when a checkpoint is saved"""
+        lab.log(f"💾 Checkpoint saved at step {state.global_step}")
+
+        # Attempt to save the checkpoint using lab's checkpoint mechanism
+        if hasattr(args, 'output_dir'):
+            checkpoint_dir = None
+            # Find the most recent checkpoint
+            if os.path.exists(args.output_dir):
+                checkpoints = [d for d in os.listdir(args.output_dir) if d.startswith('checkpoint-')]
+                if checkpoints:
+                    # Sort by checkpoint number
+                    checkpoints.sort(key=lambda x: int(x.split('-')[1]))
+                    latest_checkpoint = checkpoints[-1]
+                    checkpoint_dir = os.path.join(args.output_dir, latest_checkpoint)
+
+            # Save checkpoint to TransformerLab
+            try:
+                saved_path = lab.save_checkpoint(checkpoint_dir, f"checkpoint-{state.global_step}")
+                lab.log(f"✅ Saved checkpoint to TransformerLab: {saved_path}")
+            except Exception as e:
+                lab.log(f"⚠️ Could not save checkpoint to TransformerLab: {e}")
+
+    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """Called at the end of each epoch"""
+        if state.epoch:
+            lab.log(f"📊 Completed epoch {int(state.epoch)}/{args.num_train_epochs}")
+
+    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """Called when training ends"""
+        lab.log("✅ Training completed successfully")
+        lab.update_progress(95)
+
+
 def train_with_trl(quick_test=True):
     """Training function using HuggingFace SFTTrainer with automatic wandb detection

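The hooks above are driven by the HuggingFace Trainer with a TrainerState, so the progress logic can be sanity-checked in isolation. A minimal sketch (not part of this diff), assuming LabCallback and the lab facade above are importable; _FakeState is a hypothetical stand-in for TrainerState:

from dataclasses import dataclass, field

@dataclass
class _FakeState:
    max_steps: int = 200                      # hypothetical run length
    global_step: int = 0
    log_history: list = field(default_factory=list)
    epoch: float = 0.0

cb = LabCallback()
state = _FakeState()
cb.on_train_begin(args=None, state=state, control=None)   # total_steps taken from max_steps
for step in (50, 100, 200):
    state.global_step = step
    state.log_history.append({"loss": 1.0 / step})
    cb.on_step_end(args=None, state=state, control=None)   # logs loss; progress 25, 50, then capped at 95
cb.on_train_end(args=None, state=state, control=None)      # final lab.update_progress(95)
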
@@ -49,8 +114,6 @@ def train_with_trl(quick_test=True):
49114 "max_steps" : 3 if quick_test else - 1 , # Limit steps for quick test
50115 "report_to" : ["wandb" ], # Enable wandb reporting in SFTTrainer
51116 "dataloader_num_workers" : 0 , # Avoid multiprocessing issues
52- "remove_unused_columns" : False ,
53- "push_to_hub" : False ,
54117 },
55118 }
56119
@@ -126,7 +189,7 @@ def train_with_trl(quick_test=True):
     try:
         from trl import SFTTrainer, SFTConfig

-        # SFTConfig with wandb reporting
+        # SFTConfig with wandb reporting and automatic checkpoint saving
         training_args = SFTConfig(
             output_dir=training_config["output_dir"],
             num_train_epochs=training_config["_config"]["num_train_epochs"],
@@ -144,33 +207,27 @@ def train_with_trl(quick_test=True):
             remove_unused_columns=False,
             push_to_hub=False,
             dataset_text_field="text",  # Move dataset_text_field to SFTConfig
+            # Enable automatic checkpoint saving
+            save_total_limit=3,  # Keep only the last 3 checkpoints to save disk space
+            save_strategy="steps",  # Save checkpoints every save_steps
+            load_best_model_at_end=False,
         )

-        # Create SFTTrainer - this will initialize wandb if report_to includes "wandb"
+        # Create custom callback for TransformerLab integration
+        transformerlab_callback = LabCallback()
+
         trainer = SFTTrainer(
             model=model,
             args=training_args,
             train_dataset=dataset["train"],
             processing_class=tokenizer,
+            callbacks=[transformerlab_callback],  # Add our custom callback
         )

         lab.log("✅ SFTTrainer created - wandb should be initialized automatically!")
         lab.log("🔍 Checking for wandb URL detection...")

-    except ImportError:
-        lab.log("⚠️ TRL not available, using basic training simulation")
-        # Simulate wandb initialization for testing
-        try:
-            import wandb
-            if wandb.run is None:
-                wandb.init(
-                    project="transformerlab-trl-test",
-                    name=f"trl-sim-{lab.job.id}",
-                    config=training_config["_config"]
-                )
-            lab.log("✅ Simulated wandb initialization for testing")
-        except Exception:
-            pass
+
     except Exception as e:
         lab.log(f"Error setting up SFTTrainer: {e}")
         lab.finish("Training failed - trainer setup error")
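With save_strategy="steps" and save_total_limit=3, the Trainer writes rolling checkpoint-<step> directories under output_dir, which is also what triggers on_save above. A minimal resume sketch (not part of this diff), assuming the trainer and training_config built in train_with_trl; it uses the standard resume_from_checkpoint argument of Trainer.train():

# Resume from the newest checkpoint-<step> directory if one exists (sketch).
output_dir = training_config["output_dir"]
has_checkpoint = os.path.isdir(output_dir) and any(
    d.startswith("checkpoint-") for d in os.listdir(output_dir)
)
trainer.train(resume_from_checkpoint=True if has_checkpoint else None)
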
@@ -182,37 +239,20 @@ def train_with_trl(quick_test=True):
     if quick_test:
         lab.log("🚀 Quick test mode: Initializing SFTTrainer and testing wandb detection...")
     else:
-        lab.log("Starting training with SFTTrainer...")
+        lab.log("Starting training...")

     try:
         if 'trainer' in locals():
             # Real training with SFTTrainer
             if quick_test:
-                lab.log("✅ SFTTrainer initialized successfully - testing wandb detection...")
+                lab.log("✅ SFTTrainer initialized successfully...")
                 # Just test that wandb is initialized, don't actually train
                 lab.log("Quick test: Skipping actual training, just testing wandb URL detection")
             else:
+                # Training will automatically save checkpoints via the callback
                 trainer.train()
                 lab.log("✅ Training completed with SFTTrainer")

-                # Save checkpoints and artifacts after full training
-                lab.log("Saving training checkpoints and artifacts...")
-
-                # Create 5 fake checkpoints to simulate training progression
-                for epoch in range(1, 6):
-                    checkpoint_file = os.path.join(training_config["output_dir"], f"checkpoint_epoch_{epoch}.txt")
-                    with open(checkpoint_file, "w") as f:
-                        f.write(f"Training checkpoint for epoch {epoch}\n")
-                        f.write(f"Model state: epoch_{epoch}\n")
-                        f.write(f"Loss: {0.5 - epoch * 0.08:.3f}\n")
-                        f.write(f"Accuracy: {0.6 + epoch * 0.08:.3f}\n")
-                        f.write(f"Timestamp: {datetime.now()}\n")
-                        f.write(f"Training step: {epoch * 100}\n")
-                        f.write(f"Learning rate: {training_config['_config']['lr']}\n")
-
-                    # Save checkpoint using lab facade
-                    saved_checkpoint_path = lab.save_checkpoint(checkpoint_file, f"epoch_{epoch}_checkpoint.txt")
-                    lab.log(f"Saved checkpoint: {saved_checkpoint_path}")

                 # Create 2 additional artifacts for full training
                 # Artifact 1: Training progress summary