Skip to content

Commit e989e40

Browse files
authored
Fix intermediate frames trimming to make sure final notes output is not truncated (#179)
1 parent 945a4cd commit e989e40

File tree

4 files changed

+20
-5
lines changed

4 files changed

+20
-5
lines changed

basic_pitch/constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
SEMITONES_PER_OCTAVE = 12 # for frequency bin calculations
2424

2525
FFT_HOP = 256
26-
N_FFT = 8 * FFT_HOP
2726

2827
NOTES_BINS_PER_SEMITONE = 1
2928
CONTOURS_BINS_PER_SEMITONE = 3

basic_pitch/inference.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
AUDIO_N_SAMPLES,
5858
ANNOTATIONS_FPS,
5959
FFT_HOP,
60+
AUDIO_WINDOW_LENGTH,
6061
)
6162
from basic_pitch.commandline_printing import (
6263
generating_file_message,
@@ -247,13 +248,15 @@ def unwrap_output(
247248
output: npt.NDArray[np.float32],
248249
audio_original_length: int,
249250
n_overlapping_frames: int,
251+
hop_size: int,
250252
) -> np.array:
251253
"""Unwrap batched model predictions to a single matrix.
252254
253255
Args:
254256
output: array (n_batches, n_times_short, n_freqs)
255257
audio_original_length: length of original audio signal (in samples)
256258
n_overlapping_frames: number of overlapping frames in the output
259+
hop_size: size of the hop used when scanning the input audio
257260
258261
Returns:
259262
array (n_times, n_freqs)
@@ -266,10 +269,14 @@ def unwrap_output(
266269
# remove half of the overlapping frames from beginning and end
267270
output = output[:, n_olap:-n_olap, :]
268271

272+
# Concatenate the frames outputs (overlapping frames removed) into a single dimension
269273
output_shape = output.shape
270-
n_output_frames_original = int(np.floor(audio_original_length * (ANNOTATIONS_FPS / AUDIO_SAMPLE_RATE)))
271274
unwrapped_output = output.reshape(output_shape[0] * output_shape[1], output_shape[2])
272-
return unwrapped_output[:n_output_frames_original, :] # trim to original audio length
275+
276+
# trim to number of expected windows in output
277+
n_expected_windows = audio_original_length / hop_size
278+
n_frames_per_window = (AUDIO_WINDOW_LENGTH * ANNOTATIONS_FPS) - n_overlapping_frames
279+
return unwrapped_output[: int(n_expected_windows * n_frames_per_window), :]
273280

274281

275282
def run_inference(
@@ -303,7 +310,8 @@ def run_inference(
303310
output[k].append(v)
304311

305312
unwrapped_output = {
306-
k: unwrap_output(np.concatenate(output[k]), audio_original_length, n_overlapping_frames) for k in output
313+
k: unwrap_output(np.concatenate(output[k]), audio_original_length, n_overlapping_frames, hop_size)
314+
for k in output
307315
}
308316

309317
if debug_file:
8.59 KB
Binary file not shown.

tests/test_inference.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@
2626
import numpy as np
2727
import numpy.typing as npt
2828

29-
from basic_pitch import ICASSP_2022_MODEL_PATH, inference
29+
from basic_pitch import ICASSP_2022_MODEL_PATH, inference, note_creation
3030
from basic_pitch.constants import (
3131
AUDIO_SAMPLE_RATE,
3232
AUDIO_N_SAMPLES,
3333
ANNOTATIONS_N_SEMITONES,
3434
FFT_HOP,
35+
ANNOTATION_HOP,
3536
)
3637

3738
RESOURCES_PATH = pathlib.Path(__file__).parent / "resources"
@@ -55,6 +56,13 @@ def test_predict() -> None:
5556
assert all(note_pitch_max)
5657
assert isinstance(note_events, list)
5758

59+
# Check that model output has the expected length according to the last frame second computed downstream
60+
# (via model_frames_to_time) with to a few frames of tolerance
61+
audio_length_s = librosa.get_duration(filename=test_audio_path)
62+
n_model_output_frames = model_output["note"].shape[0]
63+
last_frame_s = note_creation.model_frames_to_time(n_model_output_frames)[-1]
64+
np.testing.assert_allclose(last_frame_s, audio_length_s, atol=2 * ANNOTATION_HOP)
65+
5866
expected_model_output = np.load(RESOURCES_PATH / "vocadito_10" / "model_output.npz", allow_pickle=True)[
5967
"arr_0"
6068
].item()

0 commit comments

Comments
 (0)