@@ -181,6 +181,14 @@ def predict(self, x: npt.NDArray[np.float32]) -> Dict[str, npt.NDArray[np.float3
181181 }
182182
183183
184+ DEFAULT_ONSET_THRESHOLD = 0.5  # default for the onset_threshold parameter of the module-level predict() and predict_and_save()
185+ DEFAULT_FRAME_THRESHOLD = 0.3  # default for the frame_threshold parameter of the module-level predict() and predict_and_save()
186+ DEFAULT_MINIMUM_NOTE_LENGTH_MS = 127.7  # default for the minimum_note_length parameter; the name indicates milliseconds
187+ DEFAULT_MINIMUM_MIDI_TEMPO = 120  # default for the midi_tempo parameter; NOTE(review): used as the default tempo, not as a lower bound -- "MINIMUM" in the name looks misleading, confirm
188+ DEFAULT_SONIFICATION_SAMPLERATE = 44100  # default for sonification_samplerate in predict_and_save(); presumably Hz -- TODO confirm
189+ DEFAULT_OVERLAPPING_FRAMES = 30  # n_overlapping_frames in run_inference(); multiplied by FFT_HOP to get the overlap length
190+ DEFAULT_MIDI_VELOCITY_SCALE = 127  # factor applied to amplitude before rounding to an int velocity in save_note_events()
191+
184192def window_audio_file (
185193 audio_original : npt .NDArray [np .float32 ], hop_size : int
186194) -> Iterable [Tuple [npt .NDArray [np .float32 ], Dict [str , float ]]]:
@@ -284,7 +292,7 @@ def run_inference(
284292 model = Model (model_or_model_path )
285293
286294 # overlap 30 frames
287- n_overlapping_frames = 30
295+ n_overlapping_frames = DEFAULT_OVERLAPPING_FRAMES
288296 overlap_len = n_overlapping_frames * FFT_HOP
289297 hop_size = AUDIO_N_SAMPLES - overlap_len
290298
@@ -405,7 +413,7 @@ def save_note_events(
405413 writer = csv .writer (fhandle , delimiter = "," )
406414 writer .writerow (["start_time_s" , "end_time_s" , "pitch_midi" , "velocity" , "pitch_bend" ])
407415 for start_time , end_time , note_number , amplitude , pitch_bend in note_events :
408- row = [start_time , end_time , note_number , int (np .round (127 * amplitude ))]
416+ row = [start_time , end_time , note_number , int (np .round (DEFAULT_MIDI_VELOCITY_SCALE * amplitude ))]
409417 if pitch_bend :
410418 row .extend (pitch_bend )
411419 writer .writerow (row )
@@ -414,15 +422,15 @@ def save_note_events(
414422def predict (
415423 audio_path : Union [pathlib .Path , str ],
416424 model_or_model_path : Union [Model , pathlib .Path , str ] = ICASSP_2022_MODEL_PATH ,
417- onset_threshold : float = 0.5 ,
418- frame_threshold : float = 0.3 ,
419- minimum_note_length : float = 127.70 ,
425+ onset_threshold : float = DEFAULT_ONSET_THRESHOLD ,
426+ frame_threshold : float = DEFAULT_FRAME_THRESHOLD ,
427+ minimum_note_length : float = DEFAULT_MINIMUM_NOTE_LENGTH_MS ,
420428 minimum_frequency : Optional [float ] = None ,
421429 maximum_frequency : Optional [float ] = None ,
422430 multiple_pitch_bends : bool = False ,
423431 melodia_trick : bool = True ,
424432 debug_file : Optional [pathlib .Path ] = None ,
425- midi_tempo : float = 120 ,
433+ midi_tempo : float = DEFAULT_MINIMUM_MIDI_TEMPO ,
426434) -> Tuple [
427435 Dict [str , np .array ],
428436 pretty_midi .PrettyMIDI ,
@@ -497,16 +505,16 @@ def predict_and_save(
497505 save_model_outputs : bool ,
498506 save_notes : bool ,
499507 model_or_model_path : Union [Model , str , pathlib .Path ],
500- onset_threshold : float = 0.5 ,
501- frame_threshold : float = 0.3 ,
502- minimum_note_length : float = 127.70 ,
508+ onset_threshold : float = DEFAULT_ONSET_THRESHOLD ,
509+ frame_threshold : float = DEFAULT_FRAME_THRESHOLD ,
510+ minimum_note_length : float = DEFAULT_MINIMUM_NOTE_LENGTH_MS ,
503511 minimum_frequency : Optional [float ] = None ,
504512 maximum_frequency : Optional [float ] = None ,
505513 multiple_pitch_bends : bool = False ,
506514 melodia_trick : bool = True ,
507515 debug_file : Optional [pathlib .Path ] = None ,
508- sonification_samplerate : int = 44100 ,
509- midi_tempo : float = 120 ,
516+ sonification_samplerate : int = DEFAULT_SONIFICATION_SAMPLERATE ,
517+ midi_tempo : float = DEFAULT_MINIMUM_MIDI_TEMPO ,
510518) -> None :
511519 """Make a prediction and save the results to file.
512520
0 commit comments