11__all__ = ["ReturnnDumpHDFJob" , "ReturnnRasrDumpHDFJob" , "BlissToPcmHDFJob" , "RasrAlignmentDumpHDFJob" ]
22
33from dataclasses import dataclass
4+ from enum import Enum , auto
45import glob
6+ import math
57import numpy as np
68import os
79import shutil
@@ -208,7 +210,11 @@ class PickNth(BaseStrategy):
208210 def __eq__ (self , other ):
209211 return super ().__eq__ (other ) and other .channel == self .channel
210212
211- __sis_hash_exclude__ = {"multi_channel_strategy" : BaseStrategy ()}
213+ class RoundingScheme (Enum ):
214+ start_and_duration = auto ()
215+ rasr_compatible = auto ()
216+
217+ __sis_hash_exclude__ = {"multi_channel_strategy" : BaseStrategy (), "rounding" : RoundingScheme .start_and_duration }
212218
213219 def __init__ (
214220 self ,
@@ -217,6 +223,7 @@ def __init__(
217223 output_dtype : str = "int16" ,
218224 multi_channel_strategy : BaseStrategy = BaseStrategy (),
219225 returnn_root : Optional [tk .Path ] = None ,
226+ rounding : RoundingScheme = RoundingScheme .start_and_duration ,
220227 ):
221228 """
222229
@@ -228,6 +235,9 @@ def __init__(
228235 BaseStrategy(): no handling, assume only one channel
229236 PickNth(n): Takes audio from n-th channel
230237 :param returnn_root: RETURNN repository
238+ :param rounding: defines how timestamps should be rounded if they do not exactly fall onto a sample:
239+ start_and_duration will round down the start time and the duration of the segment
240+ rasr_compatible will round up the start time and round down the end time
231241 """
232242 self .set_vis_name ("Dump audio to HDF" )
233243 assert output_dtype in ["float64" , "float32" , "int32" , "int16" ]
@@ -237,10 +247,12 @@ def __init__(
237247 self .output_dtype = output_dtype
238248 self .multi_channel_strategy = multi_channel_strategy
239249 self .returnn_root = returnn_root
240- self .rqmt = {}
250+ self .rounding = rounding
241251
242252 self .out_hdf = self .output_path ("audio.hdf" )
243253
254+ self .rqmt = {}
255+
244256 def tasks (self ):
245257 yield Task ("run" , rqmt = self .rqmt )
246258
@@ -265,9 +277,17 @@ def run(self):
265277
266278 for segment in recording .segments :
267279 if (not segments_whitelist ) or (segment .fullname () in segments_whitelist ):
268- audio .seek (int (segment .start * audio .samplerate ))
280+ if self .rounding == self .RoundingScheme .start_and_duration :
281+ start = int (segment .start * audio .samplerate )
282+ duration = int ((segment .end - segment .start ) * audio .samplerate )
283+ elif self .rounding == self .RoundingScheme .rasr_compatible :
284+ start = math .ceil (segment .start * audio .samplerate )
285+ duration = math .floor (segment .end * audio .samplerate ) - start
286+ else :
287+ raise NotImplementedError (f"RoundingScheme { self .rounding } not implemented." )
288+ audio .seek (start )
269289 data = audio .read (
270- int (( segment . end - segment . start ) * audio . samplerate ) ,
290+ duration ,
271291 always_2d = True ,
272292 dtype = self .output_dtype ,
273293 )
0 commit comments