Commit 2c4117c

feat (experimental): log step-wise metrics
1 parent b560cc8 commit 2c4117c

6 files changed: +117 -108 lines changed


dmlcloud/core/callbacks.py

Lines changed: 80 additions & 48 deletions

@@ -51,8 +51,8 @@ class TimedeltaFormatter:
     def __init__(self, microseconds=False):
         self.microseconds = microseconds
 
-    def __call__(self, value: torch.Tensor) -> str:
-        delta = timedelta(seconds=value.item())
+    def __call__(self, seconds: float) -> str:
+        delta = timedelta(seconds=seconds)
         if not self.microseconds:
             delta -= timedelta(microseconds=delta.microseconds)
         return str(delta)

@@ -240,14 +240,15 @@ def pre_epoch(self, stage: 'Stage'):
     def post_epoch(self, stage: 'Stage'):
         self.epoch_end_time = datetime.now()
 
-        stage.log('misc/epoch', stage.current_epoch, prefixed=False)
-        stage.log('misc/epoch_time', (stage.epoch_end_time - self.epoch_start_time).total_seconds(), prefixed=False)
-        stage.log('misc/total_time', (stage.epoch_end_time - self.start_time).total_seconds(), prefixed=False)
+        epoch_time = (stage.epoch_end_time - self.epoch_start_time).total_seconds()
+        total_time = (stage.epoch_end_time - self.start_time).total_seconds()
+        stage.log('misc/epoch_time', epoch_time, prefixed=False, log_step=False)
+        stage.log('misc/total_time', total_time, prefixed=False, log_step=False)
 
         if stage._run_epoch_overridden:
             average_epoch_time = (stage.epoch_end_time - self.start_time) / (stage.current_epoch + 1)
             eta = average_epoch_time * (stage.max_epochs - stage.current_epoch - 1)
-            stage.log('misc/eta', eta.total_seconds(), prefixed=False)
+            stage.log('misc/eta', eta.total_seconds(), prefixed=False, log_step=False)
 
 
 class TableCallback(Callback):

@@ -345,12 +346,42 @@ class ReduceMetricsCallback(Callback):
     A callback that reduces the metrics at the end of each epoch and appends them to the history.
     """
 
-    def post_epoch(self, stage: 'Stage'):
+    def __init__(self, log_every_n_steps=50):
+        self.log_every_n_steps = log_every_n_steps
+
+    def _reduce_epoch_metrics(self, stage):
         metrics = stage.metrics.reduce()
         stage.history.append_metrics(**metrics)
-        stage.history.next_step()
+
+    def _reduce_step_metrics(self, stage):
+        metrics = stage.step_metrics.reduce()
+        stage.step_history.append_metrics(**metrics)
+
+    def post_epoch(self, stage: 'Stage'):
+        stage.log('misc/epoch', stage.current_epoch, prefixed=False, reduction='max')
+        self._reduce_epoch_metrics(stage)
         stage.step = 0 # Reset the step counter
 
+    def post_step(self, stage: 'Stage'):
+        stage.log('misc/step', stage.global_step, prefixed=False, reduction='max')
+
+        if stage.global_step % self.log_every_n_steps == 0:
+            self._reduce_step_metrics(stage)
+
+        stage.step += 1
+        stage.global_step += 1
+
+    def post_stage(self, stage):
+        has_unreduced_metrics = False
+        for metric in stage.step_metrics.metrics.values():
+            if metric.update_called:
+                has_unreduced_metrics = True
+                break
+
+        # need to check global_step > 0 to avoid reducing when finish_step() was never called once
+        if has_unreduced_metrics and stage.global_step > 0:
+            self._reduce_step_metrics(stage)
+
 
 class CheckpointCallback(Callback):
     """

@@ -391,60 +422,61 @@ class CsvCallback(Callback):
     Saves metrics to a CSV file at the end of each epoch.
     """
 
-    def __init__(self, path: Union[str, Path], append_stage_name: bool = False):
+    def __init__(self, directory: Union[str, Path]):
         """
         Initialize the callback with the given path.
 
         Args:
-            path (Union[str, Path]): The file path where the callback will operate.
-            append_stage_name (bool, optional): Whether to append the stage name to the path. Defaults to False.
-        """
-        self.path = Path(path)
-        self.append_stage_name = append_stage_name
-
-    def csv_path(self, stage: 'Stage'):
+            directory (Union[str, Path]): The path to the directory where the CSV files will be saved.
         """
-        Generate the CSV file path for the given stage.
-
-        If `append_stage_name` is True, the method appends the stage name to the file name.
-        Otherwise, it returns the base path.
+        self.directory = Path(directory)
+        self.last_steps = {}
+
+    def _build_name(self, stage: 'Stage', prefix: str):
+        duplicate_stages = [s for s in stage.pipe.stages if s.name == stage.name]
+        idx = duplicate_stages.index(stage)
+        if len(duplicate_stages) > 1:
+            return self.directory / f'{prefix}_{stage.name}_{idx + 1}.csv'
+        else:
+            return self.directory / f'{prefix}_{stage.name}.csv'
 
-        Args:
-            stage (Stage): The stage object containing the name to be appended.
+    def epoch_path(self, stage: 'Stage'):
+        return self._build_name(stage, 'epoch_metrics')
 
-        Returns:
-            Path: The complete path to the CSV file.
-        """
-
-        if self.append_stage_name:
-            duplicate_stages = [s for s in stage.pipe.stages if s.name == stage.name]
-            idx = duplicate_stages.index(stage)
-            if len(duplicate_stages) > 1:
-                return self.path / f'metrics_{stage.name}_{idx + 1}.csv'
-            else:
-                return self.path / f'metrics_{stage.name}.csv'
-        else:
-            return self.path
+    def step_path(self, stage: 'Stage'):
+        return self._build_name(stage, 'step_metrics')
 
     def pre_stage(self, stage: 'Stage'):
         # If for some reason we can't write to the file or it exists already, its better to fail early
-        with open(self.csv_path(stage), 'x'):
+        with open(self.epoch_path(stage), 'x'):
             pass
 
-    def post_epoch(self, stage: 'Stage'):
-        with open(self.csv_path(stage), 'a') as f:
-            writer = csv.writer(f)
+    def _write_history(self, file, history, step_metric, step_name):
+        writer = csv.writer(file)
+
+        metric_names = list(history.keys())
+        metric_names.remove(step_metric)
 
-            metrics = stage.history.last()
+        writer.writerow([step_name] + metric_names) # Header
+        for row in history.rows():
+            csv_row = [row[step_metric]] + [row[name] for name in metric_names]
+            writer.writerow(csv_row)
 
-            # Write the header if the file is empty
-            if f.tell() == 0:
-                writer.writerow(['epoch'] + list(metrics))
+    def _maybe_write_step_metrics(self, stage: 'Stage'):
+        if stage.step_history.num_steps > self.last_steps.get(stage, 0):
+            self.last_steps[stage] = stage.step_history.num_steps
+            with open(self.step_path(stage), 'w') as f:
+                self._write_history(f, stage.step_history, 'misc/step', 'step')
+
+    def post_epoch(self, stage: 'Stage'):
+        with open(self.epoch_path(stage), 'w') as f:
+            self._write_history(f, stage.history, 'misc/epoch', 'epoch')
 
-            row = [stage.current_epoch - 1] # epoch is already incremented
-            for value in metrics.values():
-                row.append(value.item())
-            writer.writerow(row)
+    def post_step(self, stage: 'Stage'):
+        self._maybe_write_step_metrics(stage)
+
+    def post_stage(self, stage):
+        self._maybe_write_step_metrics(stage) # edge case: last steps of training
 
 
 class WandbInitCallback(Callback):

@@ -523,7 +555,7 @@ def pre_run(self, pipe):
     def post_epoch(self, stage: 'Stage'):
         metrics = stage.history.last()
         for key, value in metrics.items():
-            self.writer.add_scalar(key, value.item(), stage.current_epoch)
+            self.writer.add_scalar(key, value, stage.current_epoch)
 
     def cleanup(self, pipe, exc_type, exc_value, traceback):
         if self.writer is not None:

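For reference, the reworked CsvCallback rewrites the whole history on every write, with the step column first and the remaining metric names as the header. Below is a self-contained sketch of the resulting file layout using only the csv module; the metric names and values are made up for illustration and are not part of this commit.

import csv
import io

# Hypothetical reduced step metrics, shaped like the rows yielded by TrainingHistory.rows()
history_rows = [
    {'misc/step': 0, 'train/loss': 2.31, 'train/accuracy': 0.11},
    {'misc/step': 50, 'train/loss': 1.07, 'train/accuracy': 0.64},
]

buf = io.StringIO()
writer = csv.writer(buf, lineterminator='\n')
metric_names = [name for name in history_rows[0] if name != 'misc/step']
writer.writerow(['step'] + metric_names)  # header row, mirroring _write_history
for row in history_rows:
    writer.writerow([row['misc/step']] + [row[name] for name in metric_names])

print(buf.getvalue())
# step,train/loss,train/accuracy
# 0,2.31,0.11
# 50,1.07,0.64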
dmlcloud/core/metrics.py

Lines changed: 20 additions & 40 deletions

@@ -34,7 +34,6 @@ class TrainingHistory:
 
     def __init__(self):
         self.num_steps = 0
-        self._current_values = {}
         self._metrics = {}
         self._dtypes = {}
 

@@ -65,6 +64,10 @@ def values(self):
     def items(self):
         return [(name, self[name]) for name in self._metrics]
 
+    def rows(self):
+        for i in range(self.num_steps):
+            yield {name: self._metrics[name][i] for name in self._metrics}
+
     def append_metric(self, name: str, value: Union[ArrayLike, Any]):
         """
         Adds a value for a metric at the current step.

@@ -76,14 +79,6 @@ def append_metric(self, name: str, value: Union[ArrayLike, Any]):
         if name in self._current_values:
             raise ValueError(f'Metric {name} already has a value for step {self.num_steps}')
 
-        if name not in self._metrics and self.num_steps > 0:
-            raise ValueError(f'Cannot add metric {name} after the first step')
-
-        if isinstance(value, torch.Tensor):
-            value = value.detach().to('cpu', non_blocking=True)
-
-        self._current_values[name] = value
-
     def append_metrics(self, **metrics):
         """
         Adds multiple metrics at the current step.

@@ -92,28 +87,16 @@ def append_metrics(self, **metrics):
             **metrics: The metrics to add.
         """
         for name, value in metrics.items():
-            self.append_metric(name, value)
-
-    def next_step(self):
-        """
-        Advances the step counter.
-        """
-
-        for name in self._metrics:
-            if name not in self._current_values:
-                raise ValueError(f'Metric {name} does not have a value for step {self.num_steps}')
-
-        for name, value in self._current_values.items():
-            if type(value) == ArrayLike: # noqa
-                value = np.as_array(value)
+            dtype = value.dtype if type(value) == ArrayLike else object # noqa
+            if isinstance(value, torch.Tensor) or isinstance(value, np.ndarray):
+                value = value.item()
 
             if name not in self._metrics:
-                self._metrics[name] = [value]
-                self._dtypes[name] = value.dtype if type(value) == ArrayLike else object # noqa
+                self._metrics[name] = ([None] * self.num_steps) + [value]
+                self._dtypes[name] = dtype
             else:
                 self._metrics[name].append(value)
 
-        self._current_values = {}
         self.num_steps += 1
 
     def last(self) -> dict[str, Any]:

@@ -126,16 +109,6 @@ def last(self) -> dict[str, Any]:
 
         return {name: values[-1] for name, values in self._metrics.items()}
 
-    def current(self) -> dict[str, Any]:
-        """
-        Returns the current, but not yet saved, value for each metric.
-
-        Returns:
-            dict[str, Any]: The current value for each metric.
-        """
-
-        return {name: self._current_values[name] for name in self._current_values}
-
     def min(self) -> dict[str, min_return_type]:
         """
         Returns a namedtuple (value, step) containing the minimum value and the corresponding step for each metric across all steps.

@@ -180,10 +153,12 @@ def log(self, name: str, value: Any, reduction: str = 'mean', **kwargs):
         if not torch.is_tensor(value):
             value = torch.tensor(value)
         value = value.cpu()
+        dtype = value.dtype
 
         if name not in self.metrics:
             if reduction == 'mean':
                 metric = torchmetrics.MeanMetric(**kwargs)
+                dtype = torch.float32
             elif reduction == 'sum':
                 metric = torchmetrics.SumMetric(**kwargs)
             elif reduction == 'min':

@@ -192,15 +167,20 @@ def log(self, name: str, value: Any, reduction: str = 'mean', **kwargs):
                 metric = torchmetrics.MaxMetric(**kwargs)
             elif reduction == 'cat':
                 metric = torchmetrics.CatMetric(**kwargs)
-            self.add_metric(name, metric.cpu())
+            metric = metric.cpu().set_dtype(dtype)
+            self.add_metric(name, metric)
 
         self.metrics[name].update(value)
 
-    def reduce(self):
+    def reduce(self, reset: bool = True):
         values = {}
         for name, metric in self.metrics.items():
-            values[name] = metric.compute()
-            metric.reset()
+            if metric.update_called:
+                values[name] = metric.compute()
+                if reset:
+                    metric.reset()
+            else:
+                values[name] = None
         return values
 
     def clear(self):

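A short sketch of how the reworked TrainingHistory behaves after this change: append_metrics() now records values and advances the step in one call (next_step() is gone), a metric that first appears at a later step is back-filled with None, and rows() yields one dict per step. The import path is taken from this commit; the metric names and values below are purely illustrative.

from dmlcloud.core.metrics import TrainingHistory

history = TrainingHistory()
history.append_metrics(**{'misc/step': 0, 'train/loss': 2.31})
history.append_metrics(**{'misc/step': 50, 'train/loss': 1.07, 'train/accuracy': 0.64})

for row in history.rows():
    print(row)
# {'misc/step': 0, 'train/loss': 2.31, 'train/accuracy': None}
# {'misc/step': 50, 'train/loss': 1.07, 'train/accuracy': 0.64}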
dmlcloud/core/pipeline.py

Lines changed: 1 addition & 1 deletion

@@ -229,7 +229,7 @@ def enable_checkpointing(
 
         if is_root():
            self.add_callback(CheckpointCallback(self.run_dir), CbPriority.CHECKPOINT)
-            self.add_callback(CsvCallback(self.run_dir, append_stage_name=True), CbPriority.CSV)
+            self.add_callback(CsvCallback(self.run_dir), CbPriority.CSV)
            self.add_callback(TensorboardCallback(self.run_dir), CbPriority.TENSORBOARD)
 
     def enable_wandb(

dmlcloud/core/stage.py

Lines changed: 5 additions & 3 deletions

@@ -60,7 +60,9 @@ def __init__(self, name: str = None, epochs: int | None = 1):
         self.pipe = None # set by the pipeline
 
         self.history = TrainingHistory()
+        self.step_history = TrainingHistory()
         self.metrics = Tracker()
+        self.step_metrics = Tracker()
 
         self.step = 0
         self.global_step = 0

@@ -165,10 +167,12 @@ def add_callback(self, callback: 'Callback', priority: int = 1):
         """
         self.callbacks.append(callback, priority)
 
-    def log(self, name: str, value: Any, reduction: str = 'mean', prefixed: bool = True):
+    def log(self, name: str, value: Any, reduction: str = 'mean', prefixed: bool = True, log_step: bool = True):
         if prefixed and self.metric_prefix:
             name = f'{self.metric_prefix}/{name}'
         self.metrics.log(name, value, reduction)
+        if log_step:
+            self.step_metrics.log(name, value, reduction)
 
     def add_metric(self, name, metric):
         metric = metric.to(self.device)

@@ -299,8 +303,6 @@ def next_epoch(self):
 
     def finish_step(self):
         self._post_step()
-        self.step += 1
-        self.global_step += 1
 
     def run_epoch(self):
         """

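With the new log_step flag, a stage can keep values out of the per-step history while still logging them per epoch, and finish_step() now only fires the post-step callbacks (the step counters are advanced by ReduceMetricsCallback). A minimal sketch of a custom stage, assuming the dml.Stage subclassing pattern used in examples/mnist.py; the class name, loop, and values are placeholders, not part of this commit.

import dmlcloud as dml


class ToyStage(dml.Stage):  # subclassing pattern assumed from examples/mnist.py
    def run_epoch(self):
        for loss in [0.9, 0.7, 0.5]:  # stand-in for a real training loop
            self.log('loss', loss)                # recorded per epoch and per step
            self.log('lr', 1e-3, log_step=False)  # epoch-level only, skips step_metrics
            self.finish_step()                    # fires post_step callbacks (step-wise reduction)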
examples/mnist.py

Lines changed: 4 additions & 8 deletions

@@ -2,7 +2,6 @@
 
 import dmlcloud as dml
 import torch
-import torchmetrics
 from torch import nn
 from torch.utils.data import DataLoader
 from torchvision import datasets, transforms

@@ -55,9 +54,6 @@ def pre_stage(self):
         self.add_column('[Val] Loss', 'val/loss', color='cyan')
         self.add_column('[Val] Acc.', 'val/accuracy', formatter=lambda acc: f'{100 * acc:.2f}%', color='cyan')
 
-        self.train_acc = self.add_metric('train/accuracy', torchmetrics.Accuracy('multiclass', num_classes=10))
-        self.val_acc = self.add_metric('val/accuracy', torchmetrics.Accuracy('multiclass', num_classes=10))
-
     # The run_epoch method is called once per epoch
     def run_epoch(self):
         self._train_epoch()

@@ -78,8 +74,9 @@ def _train_epoch(self):
             self.optimizer.step()
 
             self.log('loss', loss)
-            # self.log('accuracy', (output.argmax(1) == target).float().mean())
-            self.train_acc(output, target)
+            self.log('accuracy', (output.argmax(1) == target).float().mean())
+
+            self.finish_step() # optional, but useful to get step-wise metrics
 
     @torch.no_grad()
     def _val_epoch(self):

@@ -93,8 +90,7 @@ def _val_epoch(self):
             loss = self.loss(output, target)
 
             self.log('loss', loss)
-            # self.log('accuracy', (output.argmax(1) == target).float().mean())
-            self.val_acc(output, target)
+            self.log('accuracy', (output.argmax(1) == target).float().mean())
 
 
 def main():

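When checkpointing is enabled, the pipeline wires CsvCallback to the run directory, so the per-step metrics land next to the per-epoch ones. A sketch of reading them back after a run: the run directory, the stage name 'train', and the metric column name are assumptions, while the 'step_metrics_<stage>.csv' file pattern and the leading 'step' column come from CsvCallback above.

import csv
from pathlib import Path

run_dir = Path('runs/my_run')  # hypothetical run directory
with open(run_dir / 'step_metrics_train.csv') as f:  # name built by CsvCallback._build_name
    rows = list(csv.DictReader(f))

steps = [int(row['step']) for row in rows]
losses = [float(row['train/loss']) for row in rows]  # column name depends on the stage's metric prefix
print(steps[:3], losses[:3])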