1717
1818import os
1919
20- from typing import Any , Dict
20+ from typing import Any , Dict , Callable
2121
2222import tensorflow as tf
2323
2424from basic_pitch import visualize
2525
2626TENSORBOARD_LOGS_SUBDIR = "tensorboard_logs"
2727
28+
2829class VisualizeCallback (tf .keras .callbacks .Callback ):
2930 """
3031 Callback to run during training to create tensorboard visualizations per epoch.
@@ -49,10 +50,10 @@ def __init__(
4950 use_tf_function : bool = True ,
5051 ):
5152 super ().__init__ ()
52- self .train_ds = train_ds .take (max_batches ).prefatch (prefetch_batches )
53+ self .train_ds = train_ds .take (max_batches ).prefetch (prefetch_batches )
5354 self .validation_ds = validation_ds .take (max_batches ).prefetch (prefetch_batches )
5455 self .tensorboard_dir = os .path .join (tensorboard_dir , TENSORBOARD_LOGS_SUBDIR )
55- self .file_writer = tf .summary .create_file_writer (tensorboard_dir )
56+ self .file_writer = tf .summary .create_file_writer (self . tensorboard_dir )
5657 self .sonify = sonify
5758 self .contours = contours
5859 self .use_tf_function = use_tf_function
@@ -62,20 +63,18 @@ def __init__(
6263
6364 self ._predict_fn = None
6465
65- def set_module (self , model ) :
66+ def set_model (self , model : tf . keras . Model ) -> None :
6667 super ().set_model (model )
6768 if self .use_tf_function :
68- @tf_function
69- def fast_predict (inputs ):
69+ def fast_predict (inputs : tf .Tensor ) -> Any :
7070 return model (inputs , training = False )
71- self ._predict_fn = fast_predict
71+ self ._predict_fn = tf . function ( fast_predict )
7272 else :
7373 self ._predict_fn = model .predict
7474
75- def _predict (self , inputs ) :
75+ def _predict (self , inputs : tf . Tensor ) -> Any :
7676 if self ._predict_fn is not None :
7777 outputs = self ._predict_fn (inputs )
78- # tf.functions might output as a dict of tensors, convert to numpy:
7978 if isinstance (outputs , dict ):
8079 outputs = {k : v .numpy () if hasattr (v , "numpy" ) else v for k , v in outputs .items ()}
8180 return outputs
@@ -90,15 +89,16 @@ def on_epoch_end(self, epoch: int, logs: Dict[Any, Any]) -> None:
9089 for batch in ds :
9190 inputs , targets = batch [:2 ]
9291 outputs = self ._predict (inputs )
92+ loss_val = logs .get (loss_key )
9393 visualize .visualize_transcription (
9494 self .file_writer ,
9595 stage ,
9696 inputs ,
9797 targets ,
9898 outputs ,
99- logs . get ( loss_key ) ,
99+ float ( loss_val ) if loss_val is not None else 0.0 ,
100100 epoch ,
101101 sonify = self .sonify ,
102- contours = self .contours
102+ contours = self .contours ,
103103 )
104- break
104+ break
0 commit comments