Skip to content

Commit 51913a9

Browse files
committed
small formatting fixes
1 parent 0a551bf commit 51913a9

File tree

7 files changed

+19
-12
lines changed

7 files changed

+19
-12
lines changed

basic_pitch/callbacks.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717

1818
import os
1919

20-
from typing import Any, Dict
20+
from typing import Any, Dict, Callable
2121

2222
import tensorflow as tf
2323

2424
from basic_pitch import visualize
2525

2626
TENSORBOARD_LOGS_SUBDIR = "tensorboard_logs"
2727

28+
2829
class VisualizeCallback(tf.keras.callbacks.Callback):
2930
"""
3031
Callback to run during training to create tensorboard visualizations per epoch.
@@ -49,10 +50,10 @@ def __init__(
4950
use_tf_function: bool = True,
5051
):
5152
super().__init__()
52-
self.train_ds = train_ds.take(max_batches).prefatch(prefetch_batches)
53+
self.train_ds = train_ds.take(max_batches).prefetch(prefetch_batches)
5354
self.validation_ds = validation_ds.take(max_batches).prefetch(prefetch_batches)
5455
self.tensorboard_dir = os.path.join(tensorboard_dir, TENSORBOARD_LOGS_SUBDIR)
55-
self.file_writer = tf.summary.create_file_writer(tensorboard_dir)
56+
self.file_writer = tf.summary.create_file_writer(self.tensorboard_dir)
5657
self.sonify = sonify
5758
self.contours = contours
5859
self.use_tf_function = use_tf_function
@@ -62,20 +63,18 @@ def __init__(
6263

6364
self._predict_fn = None
6465

65-
def set_module(self, model):
66+
def set_model(self, model: tf.keras.Model) -> None:
6667
super().set_model(model)
6768
if self.use_tf_function:
68-
@tf_function
69-
def fast_predict(inputs):
69+
def fast_predict(inputs: tf.Tensor) -> Any:
7070
return model(inputs, training=False)
71-
self._predict_fn = fast_predict
71+
self._predict_fn = tf.function(fast_predict)
7272
else:
7373
self._predict_fn = model.predict
7474

75-
def _predict(self, inputs):
75+
def _predict(self, inputs: tf.Tensor) -> Any:
7676
if self._predict_fn is not None:
7777
outputs = self._predict_fn(inputs)
78-
# tf.functions might output as a dict of tensors, convert to numpy:
7978
if isinstance(outputs, dict):
8079
outputs = {k: v.numpy() if hasattr(v, "numpy") else v for k, v in outputs.items()}
8180
return outputs
@@ -90,15 +89,16 @@ def on_epoch_end(self, epoch: int, logs: Dict[Any, Any]) -> None:
9089
for batch in ds:
9190
inputs, targets = batch[:2]
9291
outputs = self._predict(inputs)
92+
loss_val = logs.get(loss_key)
9393
visualize.visualize_transcription(
9494
self.file_writer,
9595
stage,
9696
inputs,
9797
targets,
9898
outputs,
99-
logs.get(loss_key),
99+
float(loss_val) if loss_val is not None else 0.0,
100100
epoch,
101101
sonify=self.sonify,
102-
contours=self.contours
102+
contours=self.contours,
103103
)
104-
break
104+
break

basic_pitch/commandline_printing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
TF_LOG_LEVEL_KEY = "TF_CPP_MIN_LOG_LEVEL"
2525
TF_LOG_LEVEL_NO_WARNINGS_VALUE = "3"
26+
DEFAULT_PRINT_INDENT = " "
27+
2628
s_print_lock = threading.Lock()
2729
OUTPUT_EMOJIS = {
2830
"MIDI": "💅",

basic_pitch/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from enum import Enum
2121

22+
2223
SEMITONES_PER_OCTAVE = 12 # for frequency bin calculations
2324

2425
FFT_HOP = 256

basic_pitch/layers/nnaudio.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
DEFAULT_CQT_OUTPUT_FORMAT = "Magnitude"
4242
DEFAULT_LOW_PASS_TRANSITION_BANDWIDTH = 0.001
4343

44+
4445
def create_lowpass_filter(
4546
band_center: float = DEFAULT_BAND_CENTER,
4647
kernel_length: int = DEFAULT_KERNEL_LENGTH,

basic_pitch/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
ONSET_STRIDES_1 = (1, 3)
5252
ONSET_KERNEL_SIZE_2 = (3, 3)
5353

54+
5455
def transcription_loss(y_true: tf.Tensor, y_pred: tf.Tensor, label_smoothing: float) -> tf.Tensor:
5556
"""Really a binary cross entropy loss. Used to calculate the loss between the predicted
5657
posteriorgrams and the ground truth matrices.

basic_pitch/nn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
SEMITONES_PER_OCTAVE = 12
2626

27+
2728
class HarmonicStacking(tf.keras.layers.Layer):
2829
"""Harmonic stacking layer
2930

basic_pitch/note_creation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
MIDI_VELOCITY_SCALE = 127
4949
PITCH_BEND_SCALE = 4096
5050

51+
5152
def model_output_to_notes(
5253
output: Dict[str, np.array],
5354
onset_thresh: float,

0 commit comments

Comments
 (0)