|
4 | 4 |
|
5 | 5 | from __future__ import annotations |
6 | 6 | from dataclasses import dataclass |
7 | | -from typing import Optional, List |
| 7 | +from typing import Optional, Union, List, Dict |
8 | 8 | import argparse |
9 | | -import subprocess |
10 | | -import re |
11 | 9 | import gzip |
12 | 10 | from decimal import Decimal |
13 | 11 | from xml.etree import ElementTree |
14 | 12 | from collections import OrderedDict |
15 | 13 | from returnn.datasets.hdf import HDFDataset |
16 | | -from returnn.sprint.cache import WordBoundaries |
| 14 | +from returnn.sprint.cache import open_file_archive, FileArchiveBundle, FileArchive |
| 15 | +from returnn.util import better_exchook |
17 | 16 |
|
18 | 17 |
|
19 | 18 | ET = ElementTree |
|
23 | 22 | class Deps: |
24 | 23 | """deps""" |
25 | 24 |
|
26 | | - sprint_archiver_bin: str |
27 | | - sprint_phone_alignment: str |
28 | | - sprint_allophone_file: str |
29 | | - |
| 25 | + sprint_phone_alignments: Union[FileArchiveBundle, FileArchive] |
30 | 26 | sprint_lexicon: Lexicon |
31 | 27 | labels_with_eoc_hdf: HDFDataset |
| 28 | + corpus: Dict[str, BlissItem] |
32 | 29 |
|
33 | 30 |
|
34 | 31 | def uopen(path: str, *args, **kwargs): |
@@ -292,93 +289,60 @@ def from_element(cls, e): |
292 | 289 | return Lemma(orth, phon, synt, eval, special) |
293 | 290 |
|
294 | 291 |
|
295 | | -def get_sprint_allophone_seq(deps: Deps, segment_name: str) -> List[str]: |
296 | | - """sprint""" |
297 | | - cmd = [ |
298 | | - deps.sprint_archiver_bin, |
299 | | - deps.sprint_phone_alignment, |
300 | | - segment_name, |
301 | | - "--mode", |
302 | | - "show", |
303 | | - "--type", |
304 | | - "align", |
305 | | - "--allophone-file", |
306 | | - deps.sprint_allophone_file, |
307 | | - ] |
308 | | - # output looks like: |
309 | | - """ |
310 | | - <?xml version="1.0" encoding="ISO-8859-1"?> |
311 | | - <sprint> |
312 | | - time= 0 emission= 115 allophone= [SILENCE]{#+#}@i@f index= 115 state= 0 |
313 | | - time= 1 emission= 115 allophone= [SILENCE]{#+#}@i@f index= 115 state= 0 |
314 | | - time= 2 emission= 115 allophone= [SILENCE]{#+#}@i@f index= 115 state= 0 |
315 | | - time= 3 emission= 115 allophone= [SILENCE]{#+#}@i@f index= 115 state= 0 |
316 | | - time= 4 emission= 115 allophone= [SILENCE]{#+#}@i@f index= 115 state= 0 |
317 | | - time= 5 emission= 18025 allophone= HH{#+W}@i index= 18025 state= 0 |
318 | | - time= 6 emission= 18025 allophone= HH{#+W}@i index= 18025 state= 0 |
319 | | - time= 7 emission= 67126889 allophone= HH{#+W}@i index= 18025 state= 1 |
320 | | - time= 8 emission= 67126889 allophone= HH{#+W}@i index= 18025 state= 1 |
321 | | - time= 9 emission= 67126889 allophone= HH{#+W}@i index= 18025 state= 1 |
322 | | - time= 10 emission= 134235753 allophone= HH{#+W}@i index= 18025 state= 2 |
323 | | - time= 11 emission= 134235753 allophone= HH{#+W}@i index= 18025 state= 2 |
324 | | - ... |
325 | | - """ |
326 | | - out = subprocess.check_output(cmd) |
327 | | - time_idx = 0 |
328 | | - res = [] |
329 | | - for line in out.splitlines(): |
330 | | - line = line.strip() |
331 | | - if not line.startswith(b"time="): |
332 | | - continue |
333 | | - line = line.decode("utf8") |
334 | | - m = re.match( |
335 | | - r"time=\s*([0-9]+)\s+" |
336 | | - r"emission=\s*([0-9]+)\s+" |
337 | | - r"allophone=\s*(\S+)\s+" |
338 | | - r"index=\s*([0-9]+)\s+" |
339 | | - r"state=\s*([0-9]*)", |
340 | | - line, |
341 | | - ) |
342 | | - assert m, f"failed to parse line: {line}" |
343 | | - t, emission, allophone, index, state = m.groups() |
344 | | - assert int(t) == time_idx |
345 | | - res += [allophone] |
346 | | - time_idx += 1 |
347 | | - return res |
348 | | - |
349 | | - |
350 | 292 | def get_sprint_word_ends(deps: Deps, segment_name: str) -> List[int]: |
351 | 293 | pass |
352 | 294 |
|
353 | 295 |
|
354 | 296 | def handle_segment(deps: Deps, segment_name: str): |
355 | 297 | """handle segment""" |
356 | | - pass |
| 298 | + f = deps.sprint_phone_alignments.read(segment_name, "align") |
| 299 | + allophones = deps.sprint_phone_alignments.get_allophones_list() |
| 300 | + for time, index, state, weight in f: |
| 301 | + # Keep similar format as Sprint archiver. |
| 302 | + items = [ |
| 303 | + f"time={time}", |
| 304 | + f"allophone={allophones[index]}", |
| 305 | + f"index={index}", |
| 306 | + f"state={state}", |
| 307 | + ] |
| 308 | + if weight != 1: |
| 309 | + items.append(f"weight={weight}") |
| 310 | + print("\t".join(items)) |
357 | 311 |
|
358 | 312 |
|
359 | 313 | def main(): |
360 | 314 | """main""" |
361 | 315 | arg_parser = argparse.ArgumentParser() |
362 | | - arg_parser.add_argument("--archiver-bin", required=True) |
363 | | - arg_parser.add_argument("--phone-alignment", required=True) |
| 316 | + arg_parser.add_argument("--phone-alignments", required=True) |
364 | 317 | arg_parser.add_argument("--allophone-file", required=True) |
365 | 318 | arg_parser.add_argument("--lexicon", required=True) |
366 | 319 | arg_parser.add_argument("--corpus", required=True) |
| 320 | + arg_parser.add_argument("--labels-with-eoc", required=True) |
| 321 | + arg_parser.add_argument("--segment", nargs="*") |
367 | 322 | args = arg_parser.parse_args() |
368 | 323 |
|
| 324 | + phone_alignments = open_file_archive(args.phone_alignments) |
| 325 | + phone_alignments.set_allophones(args.allophone_file) |
| 326 | + |
| 327 | + lexicon = Lexicon(args.lexicon) |
| 328 | + |
| 329 | + dataset = HDFDataset([args.labels_with_eoc]) |
| 330 | + dataset.initialize() |
| 331 | + dataset.init_seq_order(epoch=1) |
| 332 | + |
| 333 | + corpus = {} |
| 334 | + for item in iter_bliss(args.corpus): |
| 335 | + corpus[item.segment_name] = item |
| 336 | + |
369 | 337 | deps = Deps( |
370 | | - sprint_archiver_bin=args.archiver_bin, |
371 | | - sprint_phone_alignment=args.phone_alignment, |
372 | | - sprint_allophone_file=args.allophone_file, |
373 | | - sprint_lexicon=Lexicon(args.lexicon), |
374 | | - labels_with_eoc_hdf=HDFDataset([args.corpus]), |
| 338 | + sprint_phone_alignments=phone_alignments, sprint_lexicon=lexicon, labels_with_eoc_hdf=dataset, corpus=corpus |
375 | 339 | ) |
376 | 340 |
|
377 | | - for item in iter_bliss(args.corpus): |
378 | | - print(item) |
379 | | - print(get_sprint_allophone_seq(deps, item.segment_name)) |
380 | | - break |
| 341 | + for segment_name in args.segment or corpus: |
| 342 | + print(corpus[segment_name]) |
| 343 | + handle_segment(deps, segment_name) |
381 | 344 |
|
382 | 345 |
|
383 | 346 | if __name__ == "__main__": |
| 347 | + better_exchook.install() |
384 | 348 | main() |
0 commit comments