Skip to content

Commit 03146d9

Browse files
committed
removed the VisualizeCallback change
1 parent 51913a9 commit 03146d9

File tree

1 file changed

+23
-50
lines changed

1 file changed

+23
-50
lines changed

basic_pitch/callbacks.py

Lines changed: 23 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import os
1919

20-
from typing import Any, Dict, Callable
20+
from typing import Any, Dict
2121

2222
import tensorflow as tf
2323

@@ -27,6 +27,7 @@
2727

2828

2929
class VisualizeCallback(tf.keras.callbacks.Callback):
30+
# TODO RACHEL make this WAY faster
3031
"""
3132
Callback to run during training to create tensorboard visualizations per epoch.
3233
@@ -45,60 +46,32 @@ def __init__(
4546
tensorboard_dir: str,
4647
sonify: bool,
4748
contours: bool,
48-
max_batches: int = 2,
49-
prefetch_batches: int = 2,
50-
use_tf_function: bool = True,
5149
):
5250
super().__init__()
53-
self.train_ds = train_ds.take(max_batches).prefetch(prefetch_batches)
54-
self.validation_ds = validation_ds.take(max_batches).prefetch(prefetch_batches)
51+
self.train_iter = iter(train_ds)
52+
self.validation_iter = iter(validation_ds)
5553
self.tensorboard_dir = os.path.join(tensorboard_dir, TENSORBOARD_LOGS_SUBDIR)
56-
self.file_writer = tf.summary.create_file_writer(self.tensorboard_dir)
54+
self.file_writer = tf.summary.create_file_writer(tensorboard_dir)
5755
self.sonify = sonify
5856
self.contours = contours
59-
self.use_tf_function = use_tf_function
60-
61-
self.train_iter = iter(self.train_ds)
62-
self.validation_iter = iter(self.validation_ds)
63-
64-
self._predict_fn = None
65-
66-
def set_model(self, model: tf.keras.Model) -> None:
67-
super().set_model(model)
68-
if self.use_tf_function:
69-
def fast_predict(inputs: tf.Tensor) -> Any:
70-
return model(inputs, training=False)
71-
self._predict_fn = tf.function(fast_predict)
72-
else:
73-
self._predict_fn = model.predict
74-
75-
def _predict(self, inputs: tf.Tensor) -> Any:
76-
if self._predict_fn is not None:
77-
outputs = self._predict_fn(inputs)
78-
if isinstance(outputs, dict):
79-
outputs = {k: v.numpy() if hasattr(v, "numpy") else v for k, v in outputs.items()}
80-
return outputs
81-
else:
82-
return self.model.predict(inputs)
8357

8458
def on_epoch_end(self, epoch: int, logs: Dict[Any, Any]) -> None:
85-
for stage, ds, loss_key in [
86-
("train", self.train_ds, "loss"),
87-
("validation", self.validation_ds, "val_loss"),
59+
# the first two outputs of generator needs to be the input and the targets
60+
train_inputs, train_targets = next(self.train_iter)[:2]
61+
validation_inputs, validation_targets = next(self.validation_iter)[:2]
62+
for stage, inputs, targets, loss in [
63+
("train", train_inputs, train_targets, logs["loss"]),
64+
("validation", validation_inputs, validation_targets, logs["val_loss"]),
8865
]:
89-
for batch in ds:
90-
inputs, targets = batch[:2]
91-
outputs = self._predict(inputs)
92-
loss_val = logs.get(loss_key)
93-
visualize.visualize_transcription(
94-
self.file_writer,
95-
stage,
96-
inputs,
97-
targets,
98-
outputs,
99-
float(loss_val) if loss_val is not None else 0.0,
100-
epoch,
101-
sonify=self.sonify,
102-
contours=self.contours,
103-
)
104-
break
66+
outputs = self.model.predict(inputs)
67+
visualize.visualize_transcription(
68+
self.file_writer,
69+
stage,
70+
inputs,
71+
targets,
72+
outputs,
73+
loss,
74+
epoch,
75+
sonify=self.sonify,
76+
contours=self.contours,
77+
)

0 commit comments

Comments
 (0)