Skip to content

Commit 96e15a7

Browse files
authored
Merge pull request #172 from saltytine/main
Improve maintainability by removing magic numbers and implementing a few fixes
2 parents f423902 + c29b01b commit 96e15a7

File tree

8 files changed

+126
-59
lines changed

8 files changed

+126
-59
lines changed

basic_pitch/callbacks.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
from basic_pitch import visualize
2525

26+
TENSORBOARD_LOGS_SUBDIR = "tensorboard_logs"
27+
2628

2729
class VisualizeCallback(tf.keras.callbacks.Callback):
2830
# TODO RACHEL make this WAY faster
@@ -48,7 +50,7 @@ def __init__(
4850
super().__init__()
4951
self.train_iter = iter(train_ds)
5052
self.validation_iter = iter(validation_ds)
51-
self.tensorboard_dir = os.path.join(tensorboard_dir, "tensorboard_logs")
53+
self.tensorboard_dir = os.path.join(tensorboard_dir, TENSORBOARD_LOGS_SUBDIR)
5254
self.file_writer = tf.summary.create_file_writer(tensorboard_dir)
5355
self.sonify = sonify
5456
self.contours = contours

basic_pitch/commandline_printing.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
TF_LOG_LEVEL_KEY = "TF_CPP_MIN_LOG_LEVEL"
2525
TF_LOG_LEVEL_NO_WARNINGS_VALUE = "3"
26+
DEFAULT_PRINT_INDENT = " "
27+
2628
s_print_lock = threading.Lock()
2729
OUTPUT_EMOJIS = {
2830
"MIDI": "💅",
@@ -39,8 +41,7 @@ def generating_file_message(output_type: str) -> None:
3941
output_type: string indicating which kind of file is being generated
4042
4143
"""
42-
print(f"\n\n Creating {output_type.replace('_', ' ').lower()}...")
43-
44+
print(f"\n\n{DEFAULT_PRINT_INDENT}Creating {output_type.replace('_', ' ').lower()}...")
4445

4546
def file_saved_confirmation(output_type: str, save_path: Union[pathlib.Path, str]) -> None:
4647
"""Print a confirmation that the file was saved succesfully
@@ -50,8 +51,8 @@ def file_saved_confirmation(output_type: str, save_path: Union[pathlib.Path, str
5051
save_path: The path to output file.
5152
5253
"""
53-
print(f" {OUTPUT_EMOJIS[output_type]} Saved to {save_path}")
54-
54+
emoji = OUTPUT_EMOJIS.get(output_type, "")
55+
print(f"{DEFAULT_PRINT_INDENT}{emoji} Saved to {save_path}")
5556

5657
def failed_to_save(output_type: str, save_path: Union[pathlib.Path, str]) -> None:
5758
"""Print a failure to save message
@@ -63,13 +64,14 @@ def failed_to_save(output_type: str, save_path: Union[pathlib.Path, str]) -> Non
6364
"""
6465
print(f"\n🚨 Failed to save {output_type.replace('_', ' ').lower()} to {save_path} \n")
6566

66-
6767
@contextmanager
6868
def no_tf_warnings() -> Iterator[None]:
6969
"""
7070
Supress tensorflow warnings in this context
7171
"""
7272
tf_logging_level = os.environ.get(TF_LOG_LEVEL_KEY, TF_LOG_LEVEL_NO_WARNINGS_VALUE)
7373
os.environ[TF_LOG_LEVEL_KEY] = TF_LOG_LEVEL_NO_WARNINGS_VALUE
74-
yield
75-
os.environ[TF_LOG_LEVEL_KEY] = tf_logging_level
74+
try:
75+
yield
76+
finally:
77+
os.environ[TF_LOG_LEVEL_KEY] = tf_logging_level

basic_pitch/constants.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
from enum import Enum
2121

22+
23+
SEMITONES_PER_OCTAVE = 12 # for frequency bin calculations
24+
2225
FFT_HOP = 256
2326
N_FFT = 8 * FFT_HOP
2427

@@ -54,7 +57,7 @@
5457

5558

5659
def _freq_bins(bins_per_semitone: int, base_frequency: float, n_semitones: int) -> np.array:
57-
d = 2.0 ** (1.0 / (12 * bins_per_semitone))
60+
d = 2.0 ** (1.0 / (SEMITONES_PER_OCTAVE * bins_per_semitone))
5861
bin_freqs = base_frequency * d ** np.arange(bins_per_semitone * n_semitones)
5962
return bin_freqs
6063

basic_pitch/inference.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,14 @@ def predict(self, x: npt.NDArray[np.float32]) -> Dict[str, npt.NDArray[np.float3
181181
}
182182

183183

184+
DEFAULT_ONSET_THRESHOLD = 0.5
185+
DEFAULT_FRAME_THRESHOLD = 0.3
186+
DEFAULT_MINIMUM_NOTE_LENGTH_MS = 127.7
187+
DEFAULT_MINIMUM_MIDI_TEMPO = 120
188+
DEFAULT_SONIFICATION_SAMPLERATE = 44100
189+
DEFAULT_OVERLAPPING_FRAMES = 30
190+
DEFAULT_MIDI_VELOCITY_SCALE = 127
191+
184192
def window_audio_file(
185193
audio_original: npt.NDArray[np.float32], hop_size: int
186194
) -> Iterable[Tuple[npt.NDArray[np.float32], Dict[str, float]]]:
@@ -284,7 +292,7 @@ def run_inference(
284292
model = Model(model_or_model_path)
285293

286294
# overlap 30 frames
287-
n_overlapping_frames = 30
295+
n_overlapping_frames = DEFAULT_OVERLAPPING_FRAMES
288296
overlap_len = n_overlapping_frames * FFT_HOP
289297
hop_size = AUDIO_N_SAMPLES - overlap_len
290298

@@ -405,7 +413,7 @@ def save_note_events(
405413
writer = csv.writer(fhandle, delimiter=",")
406414
writer.writerow(["start_time_s", "end_time_s", "pitch_midi", "velocity", "pitch_bend"])
407415
for start_time, end_time, note_number, amplitude, pitch_bend in note_events:
408-
row = [start_time, end_time, note_number, int(np.round(127 * amplitude))]
416+
row = [start_time, end_time, note_number, int(np.round(DEFAULT_MIDI_VELOCITY_SCALE * amplitude))]
409417
if pitch_bend:
410418
row.extend(pitch_bend)
411419
writer.writerow(row)
@@ -414,15 +422,15 @@ def save_note_events(
414422
def predict(
415423
audio_path: Union[pathlib.Path, str],
416424
model_or_model_path: Union[Model, pathlib.Path, str] = ICASSP_2022_MODEL_PATH,
417-
onset_threshold: float = 0.5,
418-
frame_threshold: float = 0.3,
419-
minimum_note_length: float = 127.70,
425+
onset_threshold: float = DEFAULT_ONSET_THRESHOLD,
426+
frame_threshold: float = DEFAULT_FRAME_THRESHOLD,
427+
minimum_note_length: float = DEFAULT_MINIMUM_NOTE_LENGTH_MS,
420428
minimum_frequency: Optional[float] = None,
421429
maximum_frequency: Optional[float] = None,
422430
multiple_pitch_bends: bool = False,
423431
melodia_trick: bool = True,
424432
debug_file: Optional[pathlib.Path] = None,
425-
midi_tempo: float = 120,
433+
midi_tempo: float = DEFAULT_MINIMUM_MIDI_TEMPO,
426434
) -> Tuple[
427435
Dict[str, np.array],
428436
pretty_midi.PrettyMIDI,
@@ -497,16 +505,16 @@ def predict_and_save(
497505
save_model_outputs: bool,
498506
save_notes: bool,
499507
model_or_model_path: Union[Model, str, pathlib.Path],
500-
onset_threshold: float = 0.5,
501-
frame_threshold: float = 0.3,
502-
minimum_note_length: float = 127.70,
508+
onset_threshold: float = DEFAULT_ONSET_THRESHOLD,
509+
frame_threshold: float = DEFAULT_FRAME_THRESHOLD,
510+
minimum_note_length: float = DEFAULT_MINIMUM_NOTE_LENGTH_MS,
503511
minimum_frequency: Optional[float] = None,
504512
maximum_frequency: Optional[float] = None,
505513
multiple_pitch_bends: bool = False,
506514
melodia_trick: bool = True,
507515
debug_file: Optional[pathlib.Path] = None,
508-
sonification_samplerate: int = 44100,
509-
midi_tempo: float = 120,
516+
sonification_samplerate: int = DEFAULT_SONIFICATION_SAMPLERATE,
517+
midi_tempo: float = DEFAULT_MINIMUM_MIDI_TEMPO,
510518
) -> None:
511519
"""Make a prediction and save the results to file.
512520

basic_pitch/layers/nnaudio.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,27 @@
2626

2727
import scipy.signal
2828

29+
DEFAULT_BAND_CENTER = 0.5
30+
DEFAULT_KERNEL_LENGTH = 256
31+
DEFAULT_TRANSITION_BANDWIDTH = 0.03
32+
DEFAULT_DTYPE = tf.float32
33+
DEFAULT_WINDOW_BANDWIDTH = 1.5
34+
DEFAULT_CQT_HOP_LENGTH = 512
35+
DEFAULT_CQT_FMIN = 32.70
36+
DEFAULT_CQT_N_BINS = 84
37+
DEFAULT_CQT_BINS_PER_OCTAVE = 12
38+
DEFAULT_CQT_BASIS_NORM = 1
39+
DEFAULT_CQT_WINDOW = "hann"
40+
DEFAULT_CQT_PAD_MODE = "reflect"
41+
DEFAULT_CQT_OUTPUT_FORMAT = "Magnitude"
42+
DEFAULT_LOW_PASS_TRANSITION_BANDWIDTH = 0.001
43+
2944

3045
def create_lowpass_filter(
31-
band_center: float = 0.5,
32-
kernel_length: int = 256,
33-
transition_bandwidth: float = 0.03,
34-
dtype: tf.dtypes.DType = tf.float32,
46+
band_center: float = DEFAULT_BAND_CENTER,
47+
kernel_length: int = DEFAULT_KERNEL_LENGTH,
48+
transition_bandwidth: float = DEFAULT_TRANSITION_BANDWIDTH,
49+
dtype: tf.dtypes.DType = DEFAULT_DTYPE,
3550
) -> np.ndarray:
3651
"""
3752
Calculate the highest frequency we need to preserve and the lowest frequency we allow
@@ -106,15 +121,15 @@ def get_early_downsample_params(
106121
) -> Tuple[Union[float, int], int, float, np.array, bool]:
107122
"""Compute downsampling parameters used for early downsampling"""
108123

109-
window_bandwidth = 1.5 # for hann window
124+
window_bandwidth = DEFAULT_WINDOW_BANDWIDTH # for hann window
110125
filter_cutoff = fmax_t * (1 + 0.5 * window_bandwidth / Q)
111126
sr, hop_length, downsample_factor = early_downsample(sr, hop_length, n_octaves, sr // 2, filter_cutoff)
112127
if downsample_factor != 1:
113128
earlydownsample = True
114129
early_downsample_filter = create_lowpass_filter(
115130
band_center=1 / downsample_factor,
116-
kernel_length=256,
117-
transition_bandwidth=0.03,
131+
kernel_length=DEFAULT_KERNEL_LENGTH,
132+
transition_bandwidth=DEFAULT_TRANSITION_BANDWIDTH,
118133
dtype=dtype,
119134
)
120135
else:
@@ -455,19 +470,19 @@ class CQT2010v2(tf.keras.layers.Layer):
455470
def __init__(
456471
self,
457472
sr: int = 22050,
458-
hop_length: int = 512,
459-
fmin: float = 32.70,
473+
hop_length: int = DEFAULT_CQT_HOP_LENGTH,
474+
fmin: float = DEFAULT_CQT_FMIN,
460475
fmax: Optional[float] = None,
461-
n_bins: int = 84,
476+
n_bins: int = DEFAULT_CQT_N_BINS,
462477
filter_scale: int = 1,
463-
bins_per_octave: int = 12,
478+
bins_per_octave: int = DEFAULT_CQT_BINS_PER_OCTAVE,
464479
norm: bool = True,
465-
basis_norm: int = 1,
466-
window: str = "hann",
467-
pad_mode: str = "reflect",
480+
basis_norm: int = DEFAULT_CQT_BASIS_NORM,
481+
window: str = DEFAULT_CQT_WINDOW,
482+
pad_mode: str = DEFAULT_CQT_PAD_MODE,
468483
earlydownsample: bool = True,
469484
trainable: bool = False,
470-
output_format: str = "Magnitude",
485+
output_format: str = DEFAULT_CQT_OUTPUT_FORMAT,
471486
match_torch_exactly: bool = True,
472487
):
473488
super().__init__()
@@ -516,7 +531,11 @@ def build(self, input_shape: tf.TensorShape) -> None:
516531
# This will be used to calculate filter_cutoff and creating CQT kernels
517532
Q = float(self.filter_scale) / (2 ** (1 / self.bins_per_octave) - 1)
518533

519-
self.lowpass_filter = create_lowpass_filter(band_center=0.5, kernel_length=256, transition_bandwidth=0.001)
534+
self.lowpass_filter = create_lowpass_filter(
535+
band_center=DEFAULT_BAND_CENTER,
536+
kernel_length=DEFAULT_KERNEL_LENGTH,
537+
transition_bandwidth=DEFAULT_LOW_PASS_TRANSITION_BANDWIDTH,
538+
)
520539

521540
# Calculate num of filter requires for the kernel
522541
# n_octaves determines how many resampling requires for the CQT

basic_pitch/models.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,22 @@
3535

3636
MAX_N_SEMITONES = int(np.floor(12.0 * np.log2(0.5 * AUDIO_SAMPLE_RATE / ANNOTATIONS_BASE_FREQUENCY)))
3737

38+
DEFAULT_LABEL_SMOOTHING = 0.2
39+
DEFAULT_POSITIVE_WEIGHT = 0.5
40+
41+
CONTOUR_KERNEL_SIZE_1 = (5, 5)
42+
CONTOUR_KERNEL_SIZE_2 = (3, 39) # 3*13
43+
CONTOUR_KERNEL_SIZE_3 = (5, 5)
44+
CONTOUR_FILTERS_2 = 8
45+
46+
NOTES_KERNEL_SIZE_1 = (7, 7)
47+
NOTES_STRIDES_1 = (1, 3)
48+
NOTES_KERNEL_SIZE_2 = (7, 3)
49+
50+
ONSET_KERNEL_SIZE_1 = (5, 5)
51+
ONSET_STRIDES_1 = (1, 3)
52+
ONSET_KERNEL_SIZE_2 = (3, 3)
53+
3854

3955
def transcription_loss(y_true: tf.Tensor, y_pred: tf.Tensor, label_smoothing: float) -> tf.Tensor:
4056
"""Really a binary cross entropy loss. Used to calculate the loss between the predicted
@@ -103,7 +119,7 @@ def onset_loss(
103119
return lambda x, y: transcription_loss(x, y, label_smoothing=label_smoothing)
104120

105121

106-
def loss(label_smoothing: float = 0.2, weighted: bool = False, positive_weight: float = 0.5) -> Dict[str, Any]:
122+
def loss(label_smoothing: float = DEFAULT_LABEL_SMOOTHING, weighted: bool = False, positive_weight: float = DEFAULT_POSITIVE_WEIGHT) -> Dict[str, Any]:
107123
"""Creates a keras-compatible dictionary of loss functions to calculate
108124
the loss for the contour, note and onset posteriorgrams.
109125
@@ -206,7 +222,7 @@ def model(
206222
# contour layers - fully convolutional
207223
x_contours = tfkl.Conv2D(
208224
n_filters_contour,
209-
(5, 5),
225+
CONTOUR_KERNEL_SIZE_1,
210226
padding="same",
211227
kernel_initializer=_initializer(),
212228
kernel_constraint=_kernel_constraint(),
@@ -216,8 +232,8 @@ def model(
216232
x_contours = tfkl.ReLU()(x_contours)
217233

218234
x_contours = tfkl.Conv2D(
219-
8,
220-
(3, 3 * 13),
235+
CONTOUR_FILTERS_2,
236+
CONTOUR_KERNEL_SIZE_2,
221237
padding="same",
222238
kernel_initializer=_initializer(),
223239
kernel_constraint=_kernel_constraint(),
@@ -230,7 +246,7 @@ def model(
230246
contour_name = "contour"
231247
x_contours = tfkl.Conv2D(
232248
1,
233-
(5, 5),
249+
CONTOUR_KERNEL_SIZE_3,
234250
padding="same",
235251
activation="sigmoid",
236252
kernel_initializer=_initializer(),
@@ -246,9 +262,9 @@ def model(
246262

247263
x_contours_reduced = tfkl.Conv2D(
248264
n_filters_notes,
249-
(7, 7),
265+
NOTES_KERNEL_SIZE_1,
250266
padding="same",
251-
strides=(1, 3),
267+
strides=NOTES_STRIDES_1,
252268
kernel_initializer=_initializer(),
253269
kernel_constraint=_kernel_constraint(),
254270
)(x_contours_reduced)
@@ -258,7 +274,7 @@ def model(
258274
note_name = "note"
259275
x_notes_pre = tfkl.Conv2D(
260276
1,
261-
(7, 3),
277+
NOTES_KERNEL_SIZE_2,
262278
padding="same",
263279
kernel_initializer=_initializer(),
264280
kernel_constraint=_kernel_constraint(),
@@ -271,9 +287,9 @@ def model(
271287
# onsets - fully convolutional
272288
x_onset = tfkl.Conv2D(
273289
n_filters_onsets,
274-
(5, 5),
290+
ONSET_KERNEL_SIZE_1,
275291
padding="same",
276-
strides=(1, 3),
292+
strides=ONSET_STRIDES_1,
277293
kernel_initializer=_initializer(),
278294
kernel_constraint=_kernel_constraint(),
279295
)(x)
@@ -282,7 +298,7 @@ def model(
282298
x_onset = tfkl.Concatenate(axis=3, name="concat")([x_notes_pre, x_onset])
283299
x_onset = tfkl.Conv2D(
284300
1,
285-
(3, 3),
301+
ONSET_KERNEL_SIZE_2,
286302
padding="same",
287303
activation="sigmoid",
288304
kernel_initializer=_initializer(),

basic_pitch/nn.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
from basic_pitch.layers.math import log_base_b
2424

25+
SEMITONES_PER_OCTAVE = 12
26+
2527

2628
class HarmonicStacking(tf.keras.layers.Layer):
2729
"""Harmonic stacking layer
@@ -47,7 +49,7 @@ def __init__(
4749
self.bins_per_semitone = bins_per_semitone
4850
self.harmonics = harmonics
4951
self.shifts = [
50-
int(tf.math.round(12.0 * self.bins_per_semitone * log_base_b(float(h), 2))) for h in self.harmonics
52+
int(tf.math.round(SEMITONES_PER_OCTAVE * self.bins_per_semitone * log_base_b(float(h), 2))) for h in self.harmonics
5153
]
5254
self.n_output_freqs = n_output_freqs
5355

@@ -96,7 +98,7 @@ def call(self, x: tf.Tensor) -> tf.Tensor:
9698
"""x: (batch, time, ch)"""
9799
shapes = K.int_shape(x)
98100
tf.assert_equal(shapes[2], 1)
99-
return tf.keras.layers.Reshape([shapes[1]])(x) # ignore batch size
101+
return tf.squeeze(x, axis=2)
100102

101103

102104
class FlattenFreqCh(tf.keras.layers.Layer):
@@ -109,4 +111,8 @@ class FlattenFreqCh(tf.keras.layers.Layer):
109111

110112
def call(self, x: tf.Tensor) -> tf.Tensor:
111113
shapes = K.int_shape(x)
112-
return tf.keras.layers.Reshape([shapes[1], shapes[2] * shapes[3]])(x) # ignore batch size
114+
batch_size = tf.shape(x)[0]
115+
time_dim = shapes[1]
116+
freq_dim = shapes[2]
117+
ch_dim = shapes[3]
118+
return tf.reshape(x, [batch_size, time_dim, freq_dim * ch_dim])

0 commit comments

Comments
 (0)