22from datetime import datetime
33from time import sleep
44
5-
65from lab import lab
76
87
@@ -14,9 +13,9 @@ def train():
1413 "experiment_name" : "alpha" ,
1514 "model_name" : "HuggingFaceTB/SmolLM-135M-Instruct" ,
1615 "dataset" : "Trelis/touch-rugby-rules" ,
17- "template_name" : "full -demo" ,
16+ "template_name" : "wandb -demo" ,
1817 "output_dir" : "./output" ,
19- "log_to_wandb" : False ,
18+ "log_to_wandb" : True , # Enable wandb logging for demo
2019 "_config" : {
2120 "dataset_name" : "Trelis/touch-rugby-rules" ,
2221 "lr" : 2e-5 ,
@@ -47,21 +46,71 @@ def train():
4746 lab .log ("Loaded dataset" )
4847
4948 # Report initial progress
50- lab .job . update_progress (10 )
49+ lab .update_progress (10 )
5150
5251 # Train the model
5352 lab .log ("Starting training..." )
5453 print ("Starting training" )
5554 for i in range (8 ):
5655 sleep (1 )
5756 lab .log (f"Iteration { i + 1 } /8" )
58- lab .job . update_progress (10 + (i + 1 ) * 10 )
57+ lab .update_progress (10 + (i + 1 ) * 10 )
5958 print (f"Iteration { i + 1 } /8" )
59+
60+ # Method 3: Initialize wandb during training (common pattern)
61+ if i == 3 : # Initialize wandb halfway through training
62+ try :
63+ import wandb
64+ if wandb .run is None :
65+ lab .log ("🚀 Initializing wandb during training..." )
66+ wandb .init (
67+ project = "transformerlab-test" ,
68+ name = f"test-run-{ lab .job .id } " ,
69+ config = training_config ["_config" ],
70+ )
71+ lab .log ("✅ Wandb initialized - URL should be auto-detected on next progress update!" )
72+ except ImportError :
73+ lab .log ("⚠️ Wandb not available" )
74+ except Exception as e :
75+ lab .log (f"⚠️ Error with wandb initialization: { e } " )
76+
77+ # Log metrics to wandb if available
78+ try :
79+ import wandb
80+ if wandb .run is not None :
81+ # Simulate training metrics
82+ fake_loss = 0.5 - (i + 1 ) * 0.05
83+ fake_accuracy = 0.6 + (i + 1 ) * 0.04
84+
85+ wandb .log ({
86+ "train/loss" : fake_loss ,
87+ "train/accuracy" : fake_accuracy ,
88+ "epoch" : i + 1
89+ })
90+
91+ lab .log (f"📈 Logged metrics to wandb: loss={ fake_loss :.3f} , accuracy={ fake_accuracy :.3f} " )
92+ except Exception :
93+ pass
6094
6195 # Calculate training time
6296 end_time = datetime .now ()
6397 training_duration = end_time - start_time
6498 lab .log (f"Training completed in { training_duration } " )
99+
100+ # Get the captured wandb URL from job data for reporting
101+ job_data = lab .job .get_job_data ()
102+ captured_wandb_url = job_data .get ("wandb_run_url" , "None" )
103+ lab .log (f"📋 Final wandb URL stored in job data: { captured_wandb_url } " )
104+
105+ # Finish wandb run if it was initialized
106+ try :
107+ import wandb
108+ if wandb .run is not None :
109+ wandb .finish ()
110+ lab .log ("✅ Wandb run finished" )
111+ except Exception :
112+ pass
113+
65114 print ("Complete" )
66115
67116 # Complete the job in TransformerLab via facade
@@ -74,6 +123,7 @@ def train():
74123 "output_dir" : os .path .join (
75124 training_config ["output_dir" ], f"final_model_{ lab .job .id } "
76125 ),
126+ "wandb_url" : captured_wandb_url ,
77127 }
78128
79129 except KeyboardInterrupt :
0 commit comments