Skip to content

Commit 611e055

Browse files
committed
latency script, wip
1 parent 8e7f015 commit 611e055

File tree

2 files changed

+384
-0
lines changed

2 files changed

+384
-0
lines changed

users/zeyer/experiments/exp2023_02_16_chunked_attention/scripts/__init__.py

Whitespace-only changes.
Lines changed: 384 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,384 @@
1+
"""
2+
Calc latency
3+
"""
4+
5+
from __future__ import annotations
6+
from dataclasses import dataclass
7+
from typing import Optional, List
8+
import argparse
9+
import subprocess
10+
import re
11+
import gzip
12+
from decimal import Decimal
13+
from xml.etree import ElementTree
14+
from collections import OrderedDict
15+
from returnn.datasets.hdf import HDFDataset
16+
from returnn.sprint.cache import WordBoundaries
17+
18+
19+
ET = ElementTree
20+
21+
22+
@dataclass
class Deps:
    """
    Bundle of the external resources the latency computation works with.

    The field order is part of the interface because the generated
    constructor also accepts the values positionally.
    """

    # Executable that is started as a subprocess to dump an alignment.
    sprint_archiver_bin: str
    # Passed to the archiver as its first positional argument.
    sprint_phone_alignment: str
    # Passed to the archiver through ``--allophone-file``.
    sprint_allophone_file: str

    # Parsed bliss lexicon.
    sprint_lexicon: Lexicon
    # presumably the label sequences including end-of-chunk symbols -- TODO confirm
    labels_with_eoc_hdf: HDFDataset
32+
33+
34+
def uopen(path: str, *args, **kwargs):
    """
    Open ``path``, transparently using gzip when the name ends with ``.gz``.

    :param path: file to open
    :param args: forwarded unchanged to ``gzip.open`` or the builtin ``open``
    :param kwargs: forwarded unchanged to ``gzip.open`` or the builtin ``open``
    :return: the file object created by the selected opener
    """
    opener = gzip.open if path.endswith(".gz") else open
    return opener(path, *args, **kwargs)
39+
40+
41+
class BlissItem:
    """
    One segment of a Bliss corpus together with its recording and transcription.
    """

    def __init__(self, segment_name, recording_filename, start_time, end_time, orth, speaker_name=None):
        """
        :param str segment_name: full name of the segment
        :param str recording_filename: audio file of the enclosing recording
        :param Decimal start_time: start of the segment
        :param Decimal end_time: end of the segment
        :param str orth: transcription of the segment
        :param str|None speaker_name: name of the speaker, if any
        """
        self.segment_name = segment_name
        self.recording_filename = recording_filename
        self.start_time = start_time
        self.end_time = end_time
        self.orth = orth
        self.speaker_name = speaker_name

    def __repr__(self):
        shown_attribs = ("segment_name", "recording_filename", "start_time", "end_time", "orth", "speaker_name")
        parts = []
        for name in shown_attribs:
            parts.append("%s=%r" % (name, getattr(self, name)))
        return "BlissItem(%s)" % ", ".join(parts)

    @property
    def delta_time(self):
        """
        :return: ``end_time - start_time``, i.e. the length of the segment
            in whatever numeric type the two times were given
        """
        duration = self.end_time - self.start_time
        return duration
72+
73+
74+
def iter_bliss(filename):
    """
    Incrementally parse a Bliss corpus XML file and yield its segments.

    The segment name is built by joining the ``name`` attributes of all
    enclosing elements with ``/``. An element without a ``name`` attribute
    is named by its 1-based position among the children of its parent.

    :param str filename: corpus file; a name ending in ``.gz`` is read through gzip
    :return: yields one BlissItem per ``<segment>`` element, in document order
    :rtype: typing.Iterator[BlissItem]
    """
    opener = gzip.open if filename.endswith(".gz") else open
    # The ``with`` block releases the file handle as soon as the generator is
    # exhausted, closed or garbage collected (it was never closed before).
    with opener(filename, "rb") as corpus_file:
        parser = ElementTree.XMLParser(target=ElementTree.TreeBuilder(), encoding="utf-8")
        context = iter(ElementTree.iterparse(corpus_file, parser=parser, events=("start", "end")))
        _, root = next(context)  # get root element
        name_tree = [root.attrib["name"]]  # names of the currently open elements
        elem_tree = [root]  # the currently open elements themselves
        count_tree = [0]  # per open element: number of children seen so far
        recording_filename = None
        for event, elem in context:
            if elem.tag == "recording":
                # Only valid while we are inside the <recording> element.
                recording_filename = elem.attrib["audio"] if event == "start" else None
            if event == "end" and elem.tag == "segment":
                elem_orth = elem.find("orth")
                # A segment without <orth>, or with an empty one, has an empty transcription.
                orth_raw = (elem_orth.text or "") if elem_orth is not None else ""
                orth = " ".join(orth_raw.split())  # normalize all whitespace
                elem_speaker = elem.find("speaker")
                speaker_name = elem_speaker.attrib["name"] if elem_speaker is not None else None
                yield BlissItem(
                    segment_name="/".join(name_tree),
                    recording_filename=recording_filename,
                    start_time=Decimal(elem.attrib["start"]),
                    end_time=Decimal(elem.attrib["end"]),
                    orth=orth,
                    speaker_name=speaker_name,
                )
                root.clear()  # free memory
            if event == "start":
                count_tree[-1] += 1
                count_tree.append(0)
                elem_tree.append(elem)
                elem_name = elem.attrib.get("name", None)
                if elem_name is None:
                    # Unnamed element: fall back to its index below the parent.
                    elem_name = str(count_tree[-2])
                assert isinstance(elem_name, str)
                name_tree.append(elem_name)
            elif event == "end":
                assert elem_tree[-1] is elem
                elem_tree.pop()
                name_tree.pop()
                count_tree.pop()
128+
129+
130+
class Lexicon:
    """
    In-memory bliss lexicon: a phoneme inventory plus a list of lemmata.

    It can be filled from an .xml / .xml.gz file and converted back to XML.
    """

    def __init__(self, file: Optional[str] = None):
        """
        :param file: if given (and non-empty), this lexicon file is loaded right away
        """
        self.phonemes = OrderedDict()  # type: OrderedDict[str, str]  # symbol => variation
        self.lemmata = []  # type: List[Lemma]
        if file:
            self.load(file)

    def add_phoneme(self, symbol, variation="context"):
        """
        Register a phoneme, replacing the variation of an already known symbol.

        :param str symbol: representation of one phoneme
        :param str variation: possible values: "context" or "none".
            Use none for context independent phonemes like silence and noise.
        """
        self.phonemes[symbol] = variation

    def remove_phoneme(self, symbol):
        """
        Drop a phoneme from the inventory.

        :param str symbol: phoneme to remove; a KeyError is raised if it is unknown
        """
        del self.phonemes[symbol]

    def add_lemma(self, lemma):
        """
        Append a lemma to the lexicon.

        :param Lemma lemma:
        """
        assert isinstance(lemma, Lemma)
        self.lemmata.append(lemma)

    def load(self, path):
        """
        Read phonemes and lemmata from a file and add them to this lexicon.
        Content that is already present is kept.

        :param str path: bliss lexicon .xml or .xml.gz file
        """
        # NOTE(review): the file is read in text mode, so it is decoded with the
        # platform default encoding before the XML parser sees it -- confirm
        # this is intended for lexica that declare another encoding.
        with uopen(path, "rt") as f:
            tree = ET.parse(f)

        for phoneme_el in tree.findall(".//phoneme-inventory/phoneme"):
            symbol = phoneme_el.find(".//symbol").text.strip()
            variation_el = phoneme_el.find(".//variation")
            # A phoneme without an explicit <variation> is context dependent.
            variation = "context" if variation_el is None else variation_el.text.strip()
            self.add_phoneme(symbol, variation)

        for lemma_el in tree.findall(".//lemma"):
            self.add_lemma(Lemma.from_element(lemma_el))

    def to_xml(self):
        """
        :return: ``<lexicon>`` element holding the phoneme inventory followed by all lemmata
        :rtype: ET.Element
        """
        lexicon_el = ET.Element("lexicon")

        inventory_el = ET.SubElement(lexicon_el, "phoneme-inventory")
        for symbol, variation in self.phonemes.items():
            phoneme_el = ET.SubElement(inventory_el, "phoneme")
            ET.SubElement(phoneme_el, "symbol").text = symbol
            ET.SubElement(phoneme_el, "variation").text = variation

        for lemma in self.lemmata:
            lexicon_el.append(lemma.to_xml())

        return lexicon_el
200+
201+
202+
class Lemma:
    """
    One entry of a lexicon: spellings, pronunciations and token sequences.
    """

    def __init__(
        self,
        orth: Optional[List[str]] = None,
        phon: Optional[List[str]] = None,
        synt: Optional[List[str]] = None,
        eval: Optional[List[List[str]]] = None,
        special: Optional[str] = None,
    ):
        """
        :param orth: list of spellings used in the training data
        :param phon: list of pronunciation variants. Each str should
            contain a space separated string of phonemes from the phoneme-inventory.
        :param synt: list of LM tokens that form a single token sequence.
            This sequence is used as the language model representation.
        :param eval: list of output representations. Each
            sublist should contain one possible transcription (token sequence) of this lemma
            that is scored against the reference transcription.
        :param special: assigns special property to a lemma.
            Supported values: "silence", "unknown", "sentence-boundary",
            or "sentence-begin" / "sentence-end"
        """
        self.orth = orth if orth is not None else []
        self.phon = phon if phon is not None else []
        self.synt = synt  # stays None when no token sequence was given
        self.eval = eval if eval is not None else []
        self.special = special
        if isinstance(synt, list):
            # Reject the old nested form, where synt was a list of token lists.
            assert len(synt) == 0 or not isinstance(synt[0], list), (
                "providing list of list is no longer supported for the 'synt' parameter "
                "and can be safely changed into a single list"
            )

    def to_xml(self):
        """
        :return: ``<lemma>`` element with orth, phon, synt and eval children (in that order)
        :rtype: ET.Element
        """
        attrib = {} if self.special is None else {"special": self.special}
        lemma_el = ET.Element("lemma", attrib=attrib)
        for tag, texts in (("orth", self.orth), ("phon", self.phon)):
            for text in texts:
                ET.SubElement(lemma_el, tag).text = text
        if self.synt is not None:
            synt_el = ET.SubElement(lemma_el, "synt")
            for token in self.synt:
                ET.SubElement(synt_el, "tok").text = token
        for tokens in self.eval:
            eval_el = ET.SubElement(lemma_el, "eval")
            for token in tokens:
                ET.SubElement(eval_el, "tok").text = token
        return lemma_el

    @classmethod
    def from_element(cls, e):
        """
        Build a lemma from its XML representation.
        If several ``<synt>`` elements are present, only the first one is used.

        :param ET.Element e:
        :rtype: Lemma
        """

        def text_of(node):
            # Missing text counts as an empty string.
            return node.text.strip() if node.text is not None else ""

        def tokens_of(node):
            return [text_of(tok) for tok in node.findall(".//tok")]

        synt_variants = [tokens_of(el) for el in e.findall(".//synt")]
        return Lemma(
            [text_of(el) for el in e.findall(".//orth")],
            [text_of(el) for el in e.findall(".//phon")],
            synt_variants[0] if synt_variants else None,
            [tokens_of(el) for el in e.findall(".//eval")],
            e.attrib.get("special"),
        )
293+
294+
295+
def get_sprint_allophone_seq(deps: Deps, segment_name: str) -> List[str]:
    """
    Run the Sprint archiver on the phone alignment of one segment and
    collect the allophone of every printed time frame.

    :param deps: provides the archiver binary, the alignment and the allophone file
    :param segment_name: segment whose alignment is shown
    :return: one allophone string per ``time=`` line, in output order
    """
    # The archiver output looks like:
    #
    #   <?xml version="1.0" encoding="ISO-8859-1"?>
    #   <sprint>
    #   time= 0 emission= 115 allophone= [SILENCE]{#+#}@i@f index= 115 state= 0
    #   time= 1 emission= 115 allophone= [SILENCE]{#+#}@i@f index= 115 state= 0
    #   ...
    #   time= 5 emission= 18025 allophone= HH{#+W}@i index= 18025 state= 0
    #   time= 6 emission= 18025 allophone= HH{#+W}@i index= 18025 state= 0
    #   time= 7 emission= 67126889 allophone= HH{#+W}@i index= 18025 state= 1
    #   ...
    archiver_output = subprocess.check_output(
        [
            deps.sprint_archiver_bin,
            deps.sprint_phone_alignment,
            segment_name,
            "--mode",
            "show",
            "--type",
            "align",
            "--allophone-file",
            deps.sprint_allophone_file,
        ]
    )
    frame_pattern = (
        r"time=\s*([0-9]+)\s+"
        r"emission=\s*([0-9]+)\s+"
        r"allophone=\s*(\S+)\s+"
        r"index=\s*([0-9]+)\s+"
        r"state=\s*([0-9]*)"
    )
    allophones = []
    for raw_line in archiver_output.splitlines():
        raw_line = raw_line.strip()
        if raw_line.startswith(b"time="):
            line = raw_line.decode("utf8")
            m = re.match(frame_pattern, line)
            assert m, f"failed to parse line: {line}"
            frame_time, _emission, allophone, _index, _state = m.groups()
            # Frames must be numbered consecutively starting at 0, so the
            # expected time index equals the number of frames collected so far.
            assert int(frame_time) == len(allophones)
            allophones.append(allophone)
    return allophones
348+
349+
350+
def get_sprint_word_ends(deps: Deps, segment_name: str) -> List[int]:
    """
    Placeholder, not implemented yet: does nothing and returns None.

    :param deps: currently unused
    :param segment_name: currently unused
    """
    # TODO: implement; the annotated return type is not honoured yet.
    return None
352+
353+
354+
def handle_segment(deps: Deps, segment_name: str):
    """
    Placeholder for the per-segment processing, not implemented yet:
    does nothing and returns None.

    :param deps: currently unused
    :param segment_name: currently unused
    """
    # TODO: implement.
    return None
357+
358+
359+
def main():
    """
    Command-line entry point (work in progress): parse the arguments, set up
    the dependencies and print the first corpus segment together with its
    allophone sequence.
    """
    parser = argparse.ArgumentParser()
    for flag in ("--archiver-bin", "--phone-alignment", "--allophone-file", "--lexicon", "--corpus"):
        parser.add_argument(flag, required=True)
    args = parser.parse_args()

    lexicon = Lexicon(args.lexicon)
    # NOTE(review): the corpus path is also given to HDFDataset here although
    # it is iterated as a Bliss corpus below -- confirm which file is meant.
    labels_hdf = HDFDataset([args.corpus])
    deps = Deps(
        sprint_archiver_bin=args.archiver_bin,
        sprint_phone_alignment=args.phone_alignment,
        sprint_allophone_file=args.allophone_file,
        sprint_lexicon=lexicon,
        labels_with_eoc_hdf=labels_hdf,
    )

    # Only the very first segment is handled for now.
    first_item = next(iter_bliss(args.corpus), None)
    if first_item is not None:
        print(first_item)
        print(get_sprint_allophone_seq(deps, first_item.segment_name))
381+
382+
383+
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)