Skip to content

Commit eb824e3

Browse files
committed
latency script, wip
1 parent 611e055 commit eb824e3

File tree

1 file changed

+39
-75
lines changed
  • users/zeyer/experiments/exp2023_02_16_chunked_attention/scripts

1 file changed

+39
-75
lines changed

users/zeyer/experiments/exp2023_02_16_chunked_attention/scripts/latency.py

Lines changed: 39 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@
44

55
from __future__ import annotations
66
from dataclasses import dataclass
7-
from typing import Optional, List
7+
from typing import Optional, Union, List, Dict
88
import argparse
9-
import subprocess
10-
import re
119
import gzip
1210
from decimal import Decimal
1311
from xml.etree import ElementTree
1412
from collections import OrderedDict
1513
from returnn.datasets.hdf import HDFDataset
16-
from returnn.sprint.cache import WordBoundaries
14+
from returnn.sprint.cache import open_file_archive, FileArchiveBundle, FileArchive
15+
from returnn.util import better_exchook
1716

1817

1918
ET = ElementTree
@@ -23,12 +22,10 @@
2322
class Deps:
2423
"""deps"""
2524

26-
sprint_archiver_bin: str
27-
sprint_phone_alignment: str
28-
sprint_allophone_file: str
29-
25+
sprint_phone_alignments: Union[FileArchiveBundle, FileArchive]
3026
sprint_lexicon: Lexicon
3127
labels_with_eoc_hdf: HDFDataset
28+
corpus: Dict[str, BlissItem]
3229

3330

3431
def uopen(path: str, *args, **kwargs):
@@ -292,93 +289,60 @@ def from_element(cls, e):
292289
return Lemma(orth, phon, synt, eval, special)
293290

294291

295-
def get_sprint_allophone_seq(deps: Deps, segment_name: str) -> List[str]:
296-
"""sprint"""
297-
cmd = [
298-
deps.sprint_archiver_bin,
299-
deps.sprint_phone_alignment,
300-
segment_name,
301-
"--mode",
302-
"show",
303-
"--type",
304-
"align",
305-
"--allophone-file",
306-
deps.sprint_allophone_file,
307-
]
308-
# output looks like:
309-
"""
310-
<?xml version="1.0" encoding="ISO-8859-1"?>
311-
<sprint>
312-
time= 0 emission= 115 allophone= [SILENCE]{#+#}@i@f index= 115 state= 0
313-
time= 1 emission= 115 allophone= [SILENCE]{#+#}@i@f index= 115 state= 0
314-
time= 2 emission= 115 allophone= [SILENCE]{#+#}@i@f index= 115 state= 0
315-
time= 3 emission= 115 allophone= [SILENCE]{#+#}@i@f index= 115 state= 0
316-
time= 4 emission= 115 allophone= [SILENCE]{#+#}@i@f index= 115 state= 0
317-
time= 5 emission= 18025 allophone= HH{#+W}@i index= 18025 state= 0
318-
time= 6 emission= 18025 allophone= HH{#+W}@i index= 18025 state= 0
319-
time= 7 emission= 67126889 allophone= HH{#+W}@i index= 18025 state= 1
320-
time= 8 emission= 67126889 allophone= HH{#+W}@i index= 18025 state= 1
321-
time= 9 emission= 67126889 allophone= HH{#+W}@i index= 18025 state= 1
322-
time= 10 emission= 134235753 allophone= HH{#+W}@i index= 18025 state= 2
323-
time= 11 emission= 134235753 allophone= HH{#+W}@i index= 18025 state= 2
324-
...
325-
"""
326-
out = subprocess.check_output(cmd)
327-
time_idx = 0
328-
res = []
329-
for line in out.splitlines():
330-
line = line.strip()
331-
if not line.startswith(b"time="):
332-
continue
333-
line = line.decode("utf8")
334-
m = re.match(
335-
r"time=\s*([0-9]+)\s+"
336-
r"emission=\s*([0-9]+)\s+"
337-
r"allophone=\s*(\S+)\s+"
338-
r"index=\s*([0-9]+)\s+"
339-
r"state=\s*([0-9]*)",
340-
line,
341-
)
342-
assert m, f"failed to parse line: {line}"
343-
t, emission, allophone, index, state = m.groups()
344-
assert int(t) == time_idx
345-
res += [allophone]
346-
time_idx += 1
347-
return res
348-
349-
350292
def get_sprint_word_ends(deps: Deps, segment_name: str) -> List[int]:
351293
pass
352294

353295

354296
def handle_segment(deps: Deps, segment_name: str):
355297
"""handle segment"""
356-
pass
298+
f = deps.sprint_phone_alignments.read(segment_name, "align")
299+
allophones = deps.sprint_phone_alignments.get_allophones_list()
300+
for time, index, state, weight in f:
301+
# Keep similar format as Sprint archiver.
302+
items = [
303+
f"time={time}",
304+
f"allophone={allophones[index]}",
305+
f"index={index}",
306+
f"state={state}",
307+
]
308+
if weight != 1:
309+
items.append(f"weight={weight}")
310+
print("\t".join(items))
357311

358312

359313
def main():
360314
"""main"""
361315
arg_parser = argparse.ArgumentParser()
362-
arg_parser.add_argument("--archiver-bin", required=True)
363-
arg_parser.add_argument("--phone-alignment", required=True)
316+
arg_parser.add_argument("--phone-alignments", required=True)
364317
arg_parser.add_argument("--allophone-file", required=True)
365318
arg_parser.add_argument("--lexicon", required=True)
366319
arg_parser.add_argument("--corpus", required=True)
320+
arg_parser.add_argument("--labels-with-eoc", required=True)
321+
arg_parser.add_argument("--segment", nargs="*")
367322
args = arg_parser.parse_args()
368323

324+
phone_alignments = open_file_archive(args.phone_alignments)
325+
phone_alignments.set_allophones(args.allophone_file)
326+
327+
lexicon = Lexicon(args.lexicon)
328+
329+
dataset = HDFDataset([args.labels_with_eoc])
330+
dataset.initialize()
331+
dataset.init_seq_order(epoch=1)
332+
333+
corpus = {}
334+
for item in iter_bliss(args.corpus):
335+
corpus[item.segment_name] = item
336+
369337
deps = Deps(
370-
sprint_archiver_bin=args.archiver_bin,
371-
sprint_phone_alignment=args.phone_alignment,
372-
sprint_allophone_file=args.allophone_file,
373-
sprint_lexicon=Lexicon(args.lexicon),
374-
labels_with_eoc_hdf=HDFDataset([args.corpus]),
338+
sprint_phone_alignments=phone_alignments, sprint_lexicon=lexicon, labels_with_eoc_hdf=dataset, corpus=corpus
375339
)
376340

377-
for item in iter_bliss(args.corpus):
378-
print(item)
379-
print(get_sprint_allophone_seq(deps, item.segment_name))
380-
break
341+
for segment_name in args.segment or corpus:
342+
print(corpus[segment_name])
343+
handle_segment(deps, segment_name)
381344

382345

383346
if __name__ == "__main__":
347+
better_exchook.install()
384348
main()

0 commit comments

Comments
 (0)