3434 'ReduceMetricsCallback' ,
3535 'CheckpointCallback' ,
3636 'CsvCallback' ,
37- 'WandbCallback' ,
37+ 'WandbInitCallback' ,
38+ 'WandbLoggerCallback' ,
3839 'TensorboardCallback' ,
3940 'CudaCallback' ,
4041]
@@ -91,7 +92,7 @@ class CbPriority(IntEnum):
9192 Default priorities for callbacks used by the pipeline and stage classes.
9293 """
9394
94- WANDB = - 200
95+ WANDB_INIT = - 200
9596 CHECKPOINT = - 190
9697 STAGE_TIMER = - 180
9798 DIAGNOSTICS = - 170
@@ -101,6 +102,7 @@ class CbPriority(IntEnum):
101102
102103 OBJECT_METHODS = 0
103104
105+ WANDB_LOGGER = 110
104106 CSV = 110
105107 TENSORBOARD = 110
106108 TABLE = 120
@@ -390,16 +392,17 @@ def post_epoch(self, stage: 'Stage'):
390392 writer .writerow (row )
391393
392394
393- class WandbCallback (Callback ):
395+ class WandbInitCallback (Callback ):
394396 """
395- A callback that logs metrics to Weights & Biases.
397+ A callback that initializes Weights & Biases and closes it at the end.
398+ This is separated from the WandbLoggerCallback to ensure it is called right at the beginning of training.
396399 """
397400
398401 def __init__ (self , project , entity , group , tags , startup_timeout , ** kwargs ):
399402 try :
400403 import wandb
401404 except ImportError :
402- raise ImportError ('wandb is required for the WandbCallback ' )
405+ raise ImportError ('wandb is required for the WandbInitCallback ' )
403406
404407 self .wandb = wandb
405408 self .project = project
@@ -421,15 +424,29 @@ def pre_run(self, pipe: 'Pipeline'):
421424 ** self .kwargs ,
422425 )
423426
424- def post_epoch (self , stage : 'Stage' ):
425- metrics = stage .history .last ()
426- self .wandb .log (metrics )
427-
428427 def cleanup (self , pipe , exc_type , exc_value , traceback ):
429428 if wandb_is_initialized ():
430429 self .wandb .finish (exit_code = 0 if exc_type is None else 1 )
431430
432431
432+ class WandbLoggerCallback (Callback ):
433+ """
434+ A callback that logs metrics to Weights & Biases.
435+ """
436+
437+ def __init__ (self ):
438+ try :
439+ import wandb
440+ except ImportError :
441+ raise ImportError ('wandb is required for the WandbLoggerCallback' )
442+
443+ self .wandb = wandb
444+
445+ def post_epoch (self , stage : 'Stage' ):
446+ metrics = stage .history .last ()
447+ self .wandb .log (metrics , commit = True , step = stage .current_epoch )
448+
449+
433450class TensorboardCallback (Callback ):
434451 """
435452 A callback that logs metrics to Tensorboard.
0 commit comments