Skip to content

Commit ea4e491

Browse files
michelwivieting
andauthored
Add rounding scheme rasr_compatible (#433)
* Add rounding scheme rasr_compatible * Update returnn/hdf.py Co-authored-by: vieting <[email protected]> --------- Co-authored-by: vieting <[email protected]>
1 parent 4741cc8 commit ea4e491

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

returnn/hdf.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
__all__ = ["ReturnnDumpHDFJob", "ReturnnRasrDumpHDFJob", "BlissToPcmHDFJob", "RasrAlignmentDumpHDFJob"]
22

33
from dataclasses import dataclass
4+
from enum import Enum, auto
45
import glob
6+
import math
57
import numpy as np
68
import os
79
import shutil
@@ -208,7 +210,11 @@ class PickNth(BaseStrategy):
208210
def __eq__(self, other):
209211
return super().__eq__(other) and other.channel == self.channel
210212

211-
__sis_hash_exclude__ = {"multi_channel_strategy": BaseStrategy()}
213+
class RoundingScheme(Enum):
214+
start_and_duration = auto()
215+
rasr_compatible = auto()
216+
217+
__sis_hash_exclude__ = {"multi_channel_strategy": BaseStrategy(), "rounding": RoundingScheme.start_and_duration}
212218

213219
def __init__(
214220
self,
@@ -217,6 +223,7 @@ def __init__(
217223
output_dtype: str = "int16",
218224
multi_channel_strategy: BaseStrategy = BaseStrategy(),
219225
returnn_root: Optional[tk.Path] = None,
226+
rounding: RoundingScheme = RoundingScheme.start_and_duration,
220227
):
221228
"""
222229
@@ -228,6 +235,9 @@ def __init__(
228235
BaseStrategy(): no handling, assume only one channel
229236
PickNth(n): Takes audio from n-th channel
230237
:param returnn_root: RETURNN repository
238+
:param rounding: defines how timestamps should be rounded if they do not exactly fall onto a sample:
239+
start_and_duration will round down the start time and the duration of the segment
240+
rasr_compatible will round up the start time and round down the end time
231241
"""
232242
self.set_vis_name("Dump audio to HDF")
233243
assert output_dtype in ["float64", "float32", "int32", "int16"]
@@ -237,10 +247,12 @@ def __init__(
237247
self.output_dtype = output_dtype
238248
self.multi_channel_strategy = multi_channel_strategy
239249
self.returnn_root = returnn_root
240-
self.rqmt = {}
250+
self.rounding = rounding
241251

242252
self.out_hdf = self.output_path("audio.hdf")
243253

254+
self.rqmt = {}
255+
244256
def tasks(self):
245257
yield Task("run", rqmt=self.rqmt)
246258

@@ -265,9 +277,17 @@ def run(self):
265277

266278
for segment in recording.segments:
267279
if (not segments_whitelist) or (segment.fullname() in segments_whitelist):
268-
audio.seek(int(segment.start * audio.samplerate))
280+
if self.rounding == self.RoundingScheme.start_and_duration:
281+
start = int(segment.start * audio.samplerate)
282+
duration = int((segment.end - segment.start) * audio.samplerate)
283+
elif self.rounding == self.RoundingScheme.rasr_compatible:
284+
start = math.ceil(segment.start * audio.samplerate)
285+
duration = math.floor(segment.end * audio.samplerate) - start
286+
else:
287+
raise NotImplementedError(f"RoundingScheme {self.rounding} not implemented.")
288+
audio.seek(start)
269289
data = audio.read(
270-
int((segment.end - segment.start) * audio.samplerate),
290+
duration,
271291
always_2d=True,
272292
dtype=self.output_dtype,
273293
)

0 commit comments

Comments
 (0)