Skip to content

Commit e583b26

Browse files
committed
fix: missing last epoch in wandb, closes #44
1 parent e5e0d9b commit e583b26

File tree

2 files changed

+31
-12
lines changed

2 files changed

+31
-12
lines changed

dmlcloud/core/callbacks.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
'ReduceMetricsCallback',
3535
'CheckpointCallback',
3636
'CsvCallback',
37-
'WandbCallback',
37+
'WandbInitCallback',
38+
'WandbLoggerCallback',
3839
'TensorboardCallback',
3940
'CudaCallback',
4041
]
@@ -91,7 +92,7 @@ class CbPriority(IntEnum):
9192
Default priorities for callbacks used by the pipeline and stage classes.
9293
"""
9394

94-
WANDB = -200
95+
WANDB_INIT = -200
9596
CHECKPOINT = -190
9697
STAGE_TIMER = -180
9798
DIAGNOSTICS = -170
@@ -101,6 +102,7 @@ class CbPriority(IntEnum):
101102

102103
OBJECT_METHODS = 0
103104

105+
WANDB_LOGGER = 110
104106
CSV = 110
105107
TENSORBOARD = 110
106108
TABLE = 120
@@ -390,16 +392,17 @@ def post_epoch(self, stage: 'Stage'):
390392
writer.writerow(row)
391393

392394

393-
class WandbCallback(Callback):
395+
class WandbInitCallback(Callback):
394396
"""
395-
A callback that logs metrics to Weights & Biases.
397+
A callback that initializes Weights & Biases and closes it at the end.
398+
This is separated from the WandbLoggerCallback to ensure it is called right at the beginning of training.
396399
"""
397400

398401
def __init__(self, project, entity, group, tags, startup_timeout, **kwargs):
399402
try:
400403
import wandb
401404
except ImportError:
402-
raise ImportError('wandb is required for the WandbCallback')
405+
raise ImportError('wandb is required for the WandbInitCallback')
403406

404407
self.wandb = wandb
405408
self.project = project
@@ -421,15 +424,29 @@ def pre_run(self, pipe: 'Pipeline'):
421424
**self.kwargs,
422425
)
423426

424-
def post_epoch(self, stage: 'Stage'):
425-
metrics = stage.history.last()
426-
self.wandb.log(metrics)
427-
428427
def cleanup(self, pipe, exc_type, exc_value, traceback):
429428
if wandb_is_initialized():
430429
self.wandb.finish(exit_code=0 if exc_type is None else 1)
431430

432431

432+
class WandbLoggerCallback(Callback):
433+
"""
434+
A callback that logs metrics to Weights & Biases.
435+
"""
436+
437+
def __init__(self):
438+
try:
439+
import wandb
440+
except ImportError:
441+
raise ImportError('wandb is required for the WandbLoggerCallback')
442+
443+
self.wandb = wandb
444+
445+
def post_epoch(self, stage: 'Stage'):
446+
metrics = stage.history.last()
447+
self.wandb.log(metrics, commit=True, step=stage.current_epoch)
448+
449+
433450
class TensorboardCallback(Callback):
434451
"""
435452
A callback that logs metrics to Tensorboard.

dmlcloud/core/pipeline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
DiagnosticsCallback,
1919
GitDiffCallback,
2020
TensorboardCallback,
21-
WandbCallback,
21+
WandbInitCallback,
22+
WandbLoggerCallback,
2223
)
2324
from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
2425
from .distributed import broadcast_object, init, is_root, local_rank
@@ -197,15 +198,16 @@ def enable_wandb(
197198
import wandb # import now to avoid potential long import times later on # noqa
198199

199200
if is_root():
200-
callback = WandbCallback(
201+
init_callback = WandbInitCallback(
201202
project=project,
202203
entity=entity,
203204
group=group,
204205
tags=tags,
205206
startup_timeout=startup_timeout,
206207
**kwargs,
207208
)
208-
self.add_callback(callback, CbPriority.WANDB)
209+
self.add_callback(init_callback, CbPriority.WANDB_INIT)
210+
self.add_callback(WandbLoggerCallback(), CbPriority.WANDB_LOGGER)
209211

210212
self.wandb = True
211213

0 commit comments

Comments
 (0)