5757 AUDIO_N_SAMPLES ,
5858 ANNOTATIONS_FPS ,
5959 FFT_HOP ,
60+ AUDIO_WINDOW_LENGTH ,
6061)
6162from basic_pitch .commandline_printing import (
6263 generating_file_message ,
@@ -247,13 +248,15 @@ def unwrap_output(
247248 output : npt .NDArray [np .float32 ],
248249 audio_original_length : int ,
249250 n_overlapping_frames : int ,
251+ hop_size : int ,
250252) -> np .array :
251253 """Unwrap batched model predictions to a single matrix.
252254
253255 Args:
254256 output: array (n_batches, n_times_short, n_freqs)
255257 audio_original_length: length of original audio signal (in samples)
256258 n_overlapping_frames: number of overlapping frames in the output
259+ hop_size: size of the hop used when scanning the input audio (in samples)
257260
258261 Returns:
259262 array (n_times, n_freqs)
@@ -266,10 +269,14 @@ def unwrap_output(
266269 # remove half of the overlapping frames from beginning and end
267270 output = output [:, n_olap :- n_olap , :]
268271
272+ # Concatenate the frame outputs (overlapping frames removed) into a single dimension
269273 output_shape = output .shape
270- n_output_frames_original = int (np .floor (audio_original_length * (ANNOTATIONS_FPS / AUDIO_SAMPLE_RATE )))
271274 unwrapped_output = output .reshape (output_shape [0 ] * output_shape [1 ], output_shape [2 ])
272- return unwrapped_output [:n_output_frames_original , :] # trim to original audio length
275+
276+ # trim to number of expected windows in output
277+ n_expected_windows = audio_original_length / hop_size
278+ n_frames_per_window = (AUDIO_WINDOW_LENGTH * ANNOTATIONS_FPS ) - n_overlapping_frames
279+ return unwrapped_output [: int (n_expected_windows * n_frames_per_window ), :]
273280
274281
275282def run_inference (
@@ -303,7 +310,8 @@ def run_inference(
303310 output [k ].append (v )
304311
305312 unwrapped_output = {
306- k : unwrap_output (np .concatenate (output [k ]), audio_original_length , n_overlapping_frames ) for k in output
313+ k : unwrap_output (np .concatenate (output [k ]), audio_original_length , n_overlapping_frames , hop_size )
314+ for k in output
307315 }
308316
309317 if debug_file :
0 commit comments