Skip to content

Commit 43dfdef

Browse files
authored
Add support for encoding and sparse data in RasrAlignmentDumpHDFJob (#434)
* add handling of encoding * Add support for sparse alignments * Add filter_list_keep
1 parent df08050 commit 43dfdef

File tree

2 files changed: +53 additions, −25 deletions

lib/rasr_cache.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import mmap
1414
import numpy
1515
import os
16-
import sys
1716
import typing
1817
import zlib
1918
from struct import pack, unpack
@@ -51,7 +50,8 @@ class FileArchive:
5150
start_recovery_tag = 0xAA55AA55
5251
end_recovery_tag = 0x55AA55AA
5352

54-
def __init__(self, filename, must_exists=False):
53+
def __init__(self, filename, must_exists=False, encoding="ascii"):
54+
self.encoding = encoding
5555

5656
self.ft = {} # type: typing.Dict[str,FileInfo]
5757
if os.path.exists(filename):
@@ -182,12 +182,12 @@ def read_v(self, typ, size):
182182
return res
183183

184184
# write routines
185-
def write_str(self, s):
185+
def write_str(self, s, enc="ascii"):
186186
"""
187187
:param str s:
188188
:rtype: int
189189
"""
190-
return self.f.write(pack("%ds" % len(s), s.encode("ascii")))
190+
return self.f.write(pack("%ds" % len(s.encode(enc)), s.encode(enc)))
191191

192192
def write_char(self, i):
193193
"""
@@ -256,7 +256,7 @@ def readFileInfoTable(self):
256256
return
257257
for i in range(count):
258258
str_len = self.read_u32()
259-
name = self.read_str(str_len)
259+
name = self.read_str(str_len, self.encoding)
260260
pos = self.read_u64()
261261
size = self.read_u32()
262262
comp = self.read_u32()
@@ -271,8 +271,8 @@ def writeFileInfoTable(self):
271271
self.write_u32(len(self.ft))
272272

273273
for fi in self.ft.values():
274-
self.write_u32(len(fi.name))
275-
self.write_str(fi.name)
274+
self.write_u32(len(fi.name.encode(self.encoding)))
275+
self.write_str(fi.name, self.encoding)
276276
self.write_u64(fi.pos)
277277
self.write_u32(fi.size)
278278
self.write_u32(fi.compressed)
@@ -293,7 +293,7 @@ def scanArchive(self):
293293
continue
294294

295295
fn_len = self.read_u32()
296-
name = self.read_str(fn_len)
296+
name = self.read_str(fn_len, self.encoding)
297297
pos = self.f.tell()
298298
size = self.read_u32()
299299
comp = self.read_u32()
@@ -322,7 +322,7 @@ def _raw_read(self, size, typ):
322322
"""
323323

324324
if typ == "str":
325-
return self.read_str(size)
325+
return self.read_str(size, self.encoding)
326326

327327
elif typ == "feat":
328328
type_len = self.read_U32()
@@ -496,8 +496,8 @@ def addFeatureCache(self, filename, features, times):
496496
:param times:
497497
"""
498498
self.write_U32(self.start_recovery_tag)
499-
self.write_u32(len(filename))
500-
self.write_str(filename)
499+
self.write_u32(len(filename.encode(self.encoding)))
500+
self.write_str(filename, self.encoding)
501501
pos = self.f.tell()
502502
if len(features) > 0:
503503
dim = len(features[0])
@@ -542,8 +542,8 @@ def addAttributes(self, filename, dim, duration):
542542
) % (dim, duration)
543543
self.write_U32(self.start_recovery_tag)
544544
filename = "%s.attribs" % filename
545-
self.write_u32(len(filename))
546-
self.write_str(filename)
545+
self.write_u32(len(filename.encode(self.encoding)))
546+
self.write_str(filename, self.encoding)
547547
pos = self.f.tell()
548548
size = len(data)
549549
self.write_u32(size)
@@ -559,17 +559,18 @@ class FileArchiveBundle:
559559
File archive bundle.
560560
"""
561561

562-
def __init__(self, filename):
562+
def __init__(self, filename, encoding="ascii"):
563563
"""
564564
:param str filename: .bundle file
565+
:param str encoding: encoding used in the files
565566
"""
566567
# filename -> FileArchive
567568
self.archives = {} # type: typing.Dict[str,FileArchive]
568569
# archive content file -> FileArchive
569570
self.files = {} # type: typing.Dict[str,FileArchive]
570571
self._short_seg_names = {}
571572
for line in open(filename).read().splitlines():
572-
self.archives[line] = a = FileArchive(line, must_exists=True)
573+
self.archives[line] = a = FileArchive(line, must_exists=True, encoding=encoding)
573574
for f in a.ft.keys():
574575
self.files[f] = a
575576
# noinspection PyProtectedMember
@@ -616,17 +617,18 @@ def setAllophones(self, filename):
616617
a.setAllophones(filename)
617618

618619

619-
def open_file_archive(archive_filename, must_exists=True):
620+
def open_file_archive(archive_filename, must_exists=True, encoding="ascii"):
620621
"""
621622
:param str archive_filename:
622623
:param bool must_exists:
624+
:param str encoding:
623625
:rtype: FileArchiveBundle|FileArchive
624626
"""
625627
if archive_filename.endswith(".bundle"):
626628
assert must_exists
627-
return FileArchiveBundle(archive_filename)
629+
return FileArchiveBundle(archive_filename, encoding=encoding)
628630
else:
629-
return FileArchive(archive_filename, must_exists=must_exists)
631+
return FileArchive(archive_filename, must_exists=must_exists, encoding=encoding)
630632

631633

632634
def is_rasr_cache_file(filename):

returnn/hdf.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -311,28 +311,41 @@ class RasrAlignmentDumpHDFJob(Job):
311311
This Job reads Rasr alignment caches and dump them in hdf files.
312312
"""
313313

314+
__sis_hash_exclude__ = {"encoding": "ascii", "filter_list_keep": None, "sparse": False}
315+
314316
def __init__(
315317
self,
316318
alignment_caches: List[tk.Path],
317319
allophone_file: tk.Path,
318320
state_tying_file: tk.Path,
319321
data_type: type = np.uint16,
320322
returnn_root: Optional[tk.Path] = None,
323+
encoding: str = "ascii",
324+
filter_list_keep: Optional[tk.Path] = None,
325+
sparse: bool = False,
321326
):
322327
"""
323328
:param alignment_caches: e.g. output of an AlignmentJob
324329
:param allophone_file: e.g. output of a StoreAllophonesJob
325330
:param state_tying_file: e.g. output of a DumpStateTyingJob
326331
:param data_type: type that is used to store the data
327332
:param returnn_root: file path to the RETURNN repository root folder
333+
:param encoding: encoding of the segment names in the cache
334+
:param filter_list_keep: list of segment names to dump
335+
:param sparse: writes the data to hdf in sparse format
328336
"""
329337
self.alignment_caches = alignment_caches
330338
self.allophone_file = allophone_file
331339
self.state_tying_file = state_tying_file
340+
self.data_type = data_type
341+
self.returnn_root = returnn_root
342+
self.encoding = encoding
343+
self.filter_list_keep = filter_list_keep
344+
self.sparse = sparse
345+
332346
self.out_hdf_files = [self.output_path(f"data.hdf.{d}") for d in range(len(alignment_caches))]
333347
self.out_excluded_segments = self.output_path(f"excluded.segments")
334-
self.returnn_root = returnn_root
335-
self.data_type = data_type
348+
336349
self.rqmt = {"cpu": 1, "mem": 8, "time": 0.5}
337350

338351
def tasks(self):
@@ -354,22 +367,35 @@ def run(self, task_id):
354367
state_tying = dict(
355368
(k, int(v)) for l in open(self.state_tying_file.get_path()) for k, v in [l.strip().split()[0:2]]
356369
)
370+
num_classes = max(state_tying.values()) + 1
357371

358-
alignment_cache = FileArchive(self.alignment_caches[task_id - 1].get_path())
372+
alignment_cache = FileArchive(self.alignment_caches[task_id - 1].get_path(), encoding=self.encoding)
359373
alignment_cache.setAllophones(self.allophone_file.get_path())
374+
if self.filter_list_keep is not None:
375+
keep_segments = set(open(self.filter_list_keep.get_path()).read().splitlines())
376+
else:
377+
keep_segments = None
360378

361379
returnn_root = None if self.returnn_root is None else self.returnn_root.get_path()
362380
SimpleHDFWriter = get_returnn_simple_hdf_writer(returnn_root)
363-
out_hdf = SimpleHDFWriter(filename=self.out_hdf_files[task_id - 1], dim=1)
381+
out_hdf = SimpleHDFWriter(
382+
filename=self.out_hdf_files[task_id - 1],
383+
dim=num_classes if self.sparse else 1,
384+
ndim=1 if self.sparse else 2,
385+
)
364386

365387
excluded_segments = []
366388

367389
for file in alignment_cache.ft:
368390
info = alignment_cache.ft[file]
369-
if info.name.endswith(".attribs"):
370-
continue
371391
seq_name = info.name
372392

393+
if seq_name.endswith(".attribs"):
394+
continue
395+
if keep_segments is not None and seq_name not in keep_segments:
396+
excluded_segments.append(seq_name)
397+
continue
398+
373399
# alignment
374400
targets = []
375401
alignment = alignment_cache.read(file, "align")
@@ -382,7 +408,7 @@ def run(self, task_id):
382408

383409
data = np.array(targets).astype(np.dtype(self.data_type))
384410
out_hdf.insert_batch(
385-
inputs=data.reshape(1, -1, 1),
411+
inputs=data.reshape(1, -1) if self.sparse else data.reshape(1, -1, 1),
386412
seq_len=[data.shape[0]],
387413
seq_tag=[seq_name],
388414
)

0 commit comments

Comments (0)