@@ -51,8 +51,8 @@ class TimedeltaFormatter:
def __init__(self, microseconds=False):
    """
    Configure the formatter.

    Args:
        microseconds (bool, optional): Whether formatted durations keep their microsecond
            component. Defaults to False.
    """
    self.microseconds = microseconds
5353
54- def __call__ (self , value : torch . Tensor ) -> str :
55- delta = timedelta (seconds = value . item () )
def __call__(self, seconds: float) -> str:
    """
    Format a duration given in seconds using the string form of ``timedelta``.

    Args:
        seconds (float): The duration in seconds.

    Returns:
        str: ``str(timedelta(seconds=seconds))``, with the microsecond component removed
        unless ``self.microseconds`` is truthy.
    """
    duration = timedelta(seconds=seconds)
    if self.microseconds:
        return str(duration)

    # Rebuild from the normalized day/second fields only, which leaves microseconds at zero.
    return str(timedelta(days=duration.days, seconds=duration.seconds))
@@ -240,14 +240,15 @@ def pre_epoch(self, stage: 'Stage'):
def post_epoch(self, stage: 'Stage'):
    """
    Record the end time of the epoch and log timing metrics for it.

    Logs ``misc/epoch_time`` (seconds since ``self.epoch_start_time``) and ``misc/total_time``
    (seconds since ``self.start_time``). If ``stage._run_epoch_overridden`` is truthy, it also
    logs ``misc/eta``: the average time per epoch so far multiplied by the number of epochs
    remaining until ``stage.max_epochs``.

    Args:
        stage (Stage): The stage whose epoch just finished.
    """
    now = datetime.now()
    self.epoch_end_time = now

    # Compute everything from the timestamp captured above. Reading it back through ``stage``
    # only works when the stage happens to mirror this callback's attribute.
    epoch_time = (now - self.epoch_start_time).total_seconds()
    total_time = (now - self.start_time).total_seconds()
    stage.log('misc/epoch_time', epoch_time, prefixed=False, log_step=False)
    stage.log('misc/total_time', total_time, prefixed=False, log_step=False)

    if stage._run_epoch_overridden:
        # ``current_epoch`` is zero-based here, hence the +1 / -1 adjustments.
        average_epoch_time = (now - self.start_time) / (stage.current_epoch + 1)
        eta = average_epoch_time * (stage.max_epochs - stage.current_epoch - 1)
        stage.log('misc/eta', eta.total_seconds(), prefixed=False, log_step=False)
251252
252253
253254class TableCallback (Callback ):
@@ -345,12 +346,42 @@ class ReduceMetricsCallback(Callback):
345346 A callback that reduces the metrics at the end of each epoch and appends them to the history.
346347 """
347348
348- def post_epoch (self , stage : 'Stage' ):
def __init__(self, log_every_n_steps=50):
    """
    Initialize the callback.

    Args:
        log_every_n_steps (int, optional): Interval, in global steps, at which the step metrics
            are reduced. Must be at least 1. Defaults to 50.

    Raises:
        ValueError: If ``log_every_n_steps`` is smaller than 1.
    """
    # The interval is used as a modulus after every step; rejecting it here avoids a
    # ZeroDivisionError surfacing in the middle of a run.
    if log_every_n_steps < 1:
        raise ValueError(f'log_every_n_steps must be >= 1, got {log_every_n_steps}')
    self.log_every_n_steps = log_every_n_steps
351+
def _reduce_epoch_metrics(self, stage):
    """Reduce ``stage.metrics`` and append the reduced values to ``stage.history``."""
    reduced = stage.metrics.reduce()
    stage.history.append_metrics(**reduced)
351- stage .history .next_step ()
355+
def _reduce_step_metrics(self, stage):
    """Reduce ``stage.step_metrics`` and append the reduced values to ``stage.step_history``."""
    reduced = stage.step_metrics.reduce()
    stage.step_history.append_metrics(**reduced)
359+
def post_epoch(self, stage: 'Stage'):
    """
    Log the epoch index, reduce the epoch metrics into the history and reset ``stage.step``.

    Args:
        stage (Stage): The stage whose epoch just finished.
    """
    epoch = stage.current_epoch
    stage.log('misc/epoch', epoch, prefixed=False, reduction='max')
    self._reduce_epoch_metrics(stage)
    # The per-epoch step counter restarts at zero.
    stage.step = 0
353364
def post_step(self, stage: 'Stage'):
    """
    Log the global step, reduce the step metrics on every n-th step and advance the counters.

    The step metrics are reduced whenever ``stage.global_step`` is a multiple of
    ``self.log_every_n_steps``; this includes global step 0.

    Args:
        stage (Stage): The stage that just finished a step.
    """
    stage.log('misc/step', stage.global_step, prefixed=False, reduction='max')

    is_logging_step = stage.global_step % self.log_every_n_steps == 0
    if is_logging_step:
        self._reduce_step_metrics(stage)

    stage.step += 1
    stage.global_step += 1
373+
def post_stage(self, stage):
    """
    Reduce step metrics that were updated but not yet reduced when the stage ends.

    Args:
        stage (Stage): The stage that just finished.
    """
    pending = any(metric.update_called for metric in stage.step_metrics.metrics.values())

    # The global_step > 0 check avoids reducing when no step was ever finished.
    if pending and stage.global_step > 0:
        self._reduce_step_metrics(stage)
384+
354385
355386class CheckpointCallback (Callback ):
356387 """
@@ -391,60 +422,61 @@ class CsvCallback(Callback):
391422 Saves metrics to a CSV file at the end of each epoch.
392423 """
393424
394- def __init__ (self , path : Union [str , Path ], append_stage_name : bool = False ):
def __init__(self, directory: Union[str, Path]):
    """
    Initialize the callback with the output directory.

    Args:
        directory (Union[str, Path]): The path to the directory where the CSV files will be saved.
    """
    # Per stage: the step-history size at the time the step CSV was last written.
    self.last_steps = {}
    self.directory = Path(directory)
434+
def _build_name(self, stage: 'Stage', prefix: str):
    """
    Build the CSV file path for the given stage.

    When several stages of ``stage.pipe`` share the stage's name, the 1-based position of
    ``stage`` among them is appended to keep the file names distinct.

    Args:
        stage (Stage): The stage the file belongs to.
        prefix (str): The leading part of the file name.

    Returns:
        Path: A path inside ``self.directory``.
    """
    same_named = [other for other in stage.pipe.stages if other.name == stage.name]
    position = same_named.index(stage)
    if len(same_named) <= 1:
        return self.directory / f'{prefix}_{stage.name}.csv'
    return self.directory / f'{prefix}_{stage.name}_{position + 1}.csv'
411442
412- Args :
413- stage (Stage): The stage object containing the name to be appended.
def epoch_path(self, stage: 'Stage'):
    """Return the path of the CSV file holding the epoch metrics of ``stage``."""
    prefix = 'epoch_metrics'
    return self._build_name(stage, prefix)
414445
415- Returns:
416- Path: The complete path to the CSV file.
417- """
418-
419- if self .append_stage_name :
420- duplicate_stages = [s for s in stage .pipe .stages if s .name == stage .name ]
421- idx = duplicate_stages .index (stage )
422- if len (duplicate_stages ) > 1 :
423- return self .path / f'metrics_{ stage .name } _{ idx + 1 } .csv'
424- else :
425- return self .path / f'metrics_{ stage .name } .csv'
426- else :
427- return self .path
def step_path(self, stage: 'Stage'):
    """Return the path of the CSV file holding the step metrics of ``stage``."""
    prefix = 'step_metrics'
    return self._build_name(stage, prefix)
428448
def pre_stage(self, stage: 'Stage'):
    """
    Create the (empty) epoch metrics file for the stage.

    Raises:
        FileExistsError: If the file already exists.
        OSError: If the file cannot be created.
    """
    target = self.epoch_path(stage)
    # Exclusive creation: if we cannot write the file or it already exists, fail early
    # rather than after the first epoch.
    with open(target, mode='x'):
        pass
433453
434- def post_epoch (self , stage : 'Stage' ):
435- with open (self .csv_path (stage ), 'a' ) as f :
436- writer = csv .writer (f )
def _write_history(self, file, history, step_metric, step_name):
    """
    Write the complete history as CSV to an open file.

    The first column holds the values of ``step_metric`` under the header ``step_name``;
    the remaining columns are the other metrics in the order given by ``history.keys()``.

    Args:
        file: A writable text file object.
        history: Object providing ``keys()`` (metric names) and ``rows()`` (one mapping of
            metric name to value per row).
        step_metric (str): Name of the metric used as the first column.
        step_name (str): Header label for the first column.

    Raises:
        ValueError: If ``step_metric`` is not among ``history.keys()``.
    """
    writer = csv.writer(file)

    columns = list(history.keys())
    columns.remove(step_metric)

    writer.writerow([step_name, *columns])  # Header
    writer.writerows(
        [row[step_metric], *(row[name] for name in columns)]
        for row in history.rows()
    )
439464
440- # Write the header if the file is empty
441- if f .tell () == 0 :
442- writer .writerow (['epoch' ] + list (metrics ))
def _maybe_write_step_metrics(self, stage: 'Stage'):
    """
    Rewrite the step metrics CSV of the stage if its step history has grown.

    The file is only written when ``stage.step_history.num_steps`` exceeds the value recorded
    at the previous write (0 if the stage was never written).

    Args:
        stage (Stage): The stage whose step history should be written.
    """
    num_steps = stage.step_history.num_steps
    if num_steps <= self.last_steps.get(stage, 0):
        return

    self.last_steps[stage] = num_steps
    # newline='' lets the csv writer control line endings; without it, platforms that
    # translate '\n' (e.g. Windows) end up with '\r\r\n' and thus blank rows.
    with open(self.step_path(stage), 'w', newline='') as f:
        self._write_history(f, stage.step_history, 'misc/step', 'step')
470+
def post_epoch(self, stage: 'Stage'):
    """
    Rewrite the epoch metrics CSV of the stage from its complete history.

    Args:
        stage (Stage): The stage whose epoch just finished.
    """
    # newline='' lets the csv writer control line endings; without it, platforms that
    # translate '\n' (e.g. Windows) end up with '\r\r\n' and thus blank rows.
    with open(self.epoch_path(stage), 'w', newline='') as f:
        self._write_history(f, stage.history, 'misc/epoch', 'epoch')
443474
444- row = [stage .current_epoch - 1 ] # epoch is already incremented
445- for value in metrics .values ():
446- row .append (value .item ())
447- writer .writerow (row )
def post_step(self, stage: 'Stage'):
    """Write the step metrics CSV after a step, if the step history has grown."""
    return self._maybe_write_step_metrics(stage)
477+
def post_stage(self, stage):
    """Write the step metrics CSV one final time when the stage ends."""
    # Edge case: step metrics from the last steps of training that were reduced
    # after the final per-step write.
    return self._maybe_write_step_metrics(stage)
448480
449481
450482class WandbInitCallback (Callback ):
@@ -523,7 +555,7 @@ def pre_run(self, pipe):
def post_epoch(self, stage: 'Stage'):
    """
    Pass every metric of the latest history entry to ``self.writer.add_scalar``.

    Each value is written under its metric name, with ``stage.current_epoch`` as the step.

    Args:
        stage (Stage): The stage whose epoch just finished.
    """
    latest = stage.history.last()
    for tag, scalar in latest.items():
        self.writer.add_scalar(tag, scalar, stage.current_epoch)
527559
528560 def cleanup (self , pipe , exc_type , exc_value , traceback ):
529561 if self .writer is not None :
0 commit comments