1717
1818import os
1919
20- from typing import Any , Dict , Callable
20+ from typing import Any , Dict
2121
2222import tensorflow as tf
2323
2727
2828
2929class VisualizeCallback (tf .keras .callbacks .Callback ):
30+ # TODO RACHEL make this WAY faster
3031 """
3132 Callback to run during training to create tensorboard visualizations per epoch.
3233
@@ -45,60 +46,32 @@ def __init__(
4546 tensorboard_dir : str ,
4647 sonify : bool ,
4748 contours : bool ,
48- max_batches : int = 2 ,
49- prefetch_batches : int = 2 ,
50- use_tf_function : bool = True ,
5149 ):
5250 super ().__init__ ()
53- self .train_ds = train_ds . take ( max_batches ). prefetch ( prefetch_batches )
54- self .validation_ds = validation_ds . take ( max_batches ). prefetch ( prefetch_batches )
51+ self .train_iter = iter ( train_ds )
52+ self .validation_iter = iter ( validation_ds )
5553 self .tensorboard_dir = os .path .join (tensorboard_dir , TENSORBOARD_LOGS_SUBDIR )
56- self .file_writer = tf .summary .create_file_writer (self . tensorboard_dir )
54+ self .file_writer = tf .summary .create_file_writer (tensorboard_dir )
5755 self .sonify = sonify
5856 self .contours = contours
59- self .use_tf_function = use_tf_function
60-
61- self .train_iter = iter (self .train_ds )
62- self .validation_iter = iter (self .validation_ds )
63-
64- self ._predict_fn = None
65-
66- def set_model (self , model : tf .keras .Model ) -> None :
67- super ().set_model (model )
68- if self .use_tf_function :
69- def fast_predict (inputs : tf .Tensor ) -> Any :
70- return model (inputs , training = False )
71- self ._predict_fn = tf .function (fast_predict )
72- else :
73- self ._predict_fn = model .predict
74-
75- def _predict (self , inputs : tf .Tensor ) -> Any :
76- if self ._predict_fn is not None :
77- outputs = self ._predict_fn (inputs )
78- if isinstance (outputs , dict ):
79- outputs = {k : v .numpy () if hasattr (v , "numpy" ) else v for k , v in outputs .items ()}
80- return outputs
81- else :
82- return self .model .predict (inputs )
8357
8458 def on_epoch_end (self , epoch : int , logs : Dict [Any , Any ]) -> None :
85- for stage , ds , loss_key in [
86- ("train" , self .train_ds , "loss" ),
87- ("validation" , self .validation_ds , "val_loss" ),
59+ # the first two outputs of generator needs to be the input and the targets
60+ train_inputs , train_targets = next (self .train_iter )[:2 ]
61+ validation_inputs , validation_targets = next (self .validation_iter )[:2 ]
62+ for stage , inputs , targets , loss in [
63+ ("train" , train_inputs , train_targets , logs ["loss" ]),
64+ ("validation" , validation_inputs , validation_targets , logs ["val_loss" ]),
8865 ]:
89- for batch in ds :
90- inputs , targets = batch [:2 ]
91- outputs = self ._predict (inputs )
92- loss_val = logs .get (loss_key )
93- visualize .visualize_transcription (
94- self .file_writer ,
95- stage ,
96- inputs ,
97- targets ,
98- outputs ,
99- float (loss_val ) if loss_val is not None else 0.0 ,
100- epoch ,
101- sonify = self .sonify ,
102- contours = self .contours ,
103- )
104- break
66+ outputs = self .model .predict (inputs )
67+ visualize .visualize_transcription (
68+ self .file_writer ,
69+ stage ,
70+ inputs ,
71+ targets ,
72+ outputs ,
73+ loss ,
74+ epoch ,
75+ sonify = self .sonify ,
76+ contours = self .contours ,
77+ )
0 commit comments