Skip to content

Commit d1b7ce4

Browse files
committed
latency script more, wip
1 parent eb824e3 commit d1b7ce4

File tree

1 file changed

+38
-9
lines changed
  • users/zeyer/experiments/exp2023_02_16_chunked_attention/scripts

1 file changed

+38
-9
lines changed

users/zeyer/experiments/exp2023_02_16_chunked_attention/scripts/latency.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
"""
44

55
from __future__ import annotations
6+
67
from dataclasses import dataclass
78
from typing import Optional, Union, List, Dict
89
import argparse
910
import gzip
11+
import re
1012
from decimal import Decimal
1113
from xml.etree import ElementTree
1214
from collections import OrderedDict
@@ -22,8 +24,8 @@
2224
class Deps:
2325
"""deps"""
2426

25-
sprint_phone_alignments: Union[FileArchiveBundle, FileArchive]
26-
sprint_lexicon: Lexicon
27+
phone_alignments: Union[FileArchiveBundle, FileArchive]
28+
lexicon: Lexicon
2729
labels_with_eoc_hdf: HDFDataset
2830
corpus: Dict[str, BlissItem]
2931

@@ -132,6 +134,8 @@ class Lexicon:
132134
def __init__(self, file: Optional[str] = None):
133135
self.phonemes = OrderedDict() # type: OrderedDict[str, str] # symbol => variation
134136
self.lemmata = [] # type: List[Lemma]
137+
self.orth_to_lemma = {} # type: Dict[str, Lemma]
138+
self.special_phones = {} # type: Dict[str, Lemma]
135139
if file:
136140
self.load(file)
137141

@@ -155,6 +159,11 @@ def add_lemma(self, lemma):
155159
"""
156160
assert isinstance(lemma, Lemma)
157161
self.lemmata.append(lemma)
162+
for orth in lemma.orth:
163+
self.orth_to_lemma[orth] = lemma
164+
if lemma.special:
165+
for phon in lemma.phon:
166+
self.special_phones[phon] = lemma
158167

159168
def load(self, path):
160169
"""
@@ -233,6 +242,15 @@ def __init__(
233242
"and can be safely changed into a single list"
234243
)
235244

245+
def __repr__(self):
246+
return "Lemma(orth=%r, phon=%r, synt=%r, eval=%r, special=%r)" % (
247+
self.orth,
248+
self.phon,
249+
self.synt,
250+
self.eval,
251+
self.special,
252+
)
253+
236254
def to_xml(self):
237255
"""
238256
:return: xml representation
@@ -295,13 +313,26 @@ def get_sprint_word_ends(deps: Deps, segment_name: str) -> List[int]:
295313

296314
def handle_segment(deps: Deps, segment_name: str):
297315
"""handle segment"""
298-
f = deps.sprint_phone_alignments.read(segment_name, "align")
299-
allophones = deps.sprint_phone_alignments.get_allophones_list()
300-
for time, index, state, weight in f:
316+
phone_alignment = deps.phone_alignments.read(segment_name, "align")
317+
corpus_entry = deps.corpus[segment_name]
318+
words = corpus_entry.orth.split()
319+
for word in words:
320+
lemma = deps.lexicon.orth_to_lemma[word]
321+
print(lemma)
322+
allophones = deps.phone_alignments.get_allophones_list()
323+
for time, index, state, weight in phone_alignment:
324+
allophone = allophones[index] # like: "[SILENCE]{#+#}@i@f" or "W{HH+AH}"
325+
m = re.match(r"([a-zA-Z\[\]#]+){([a-zA-Z\[\]#]+)\+([a-zA-Z\[\]#]+)}(@i)?(@f)?", allophone)
326+
assert m
327+
center, left, right, is_initial, is_final = m.groups()
328+
if center in deps.lexicon.special_phones:
329+
lemma = deps.lexicon.special_phones[center]
330+
if "" in lemma.orth: # e.g. silence
331+
continue # skip silence or similar
301332
# Keep similar format as Sprint archiver.
302333
items = [
303334
f"time={time}",
304-
f"allophone={allophones[index]}",
335+
f"allophone={allophone}",
305336
f"index={index}",
306337
f"state={state}",
307338
]
@@ -334,9 +365,7 @@ def main():
334365
for item in iter_bliss(args.corpus):
335366
corpus[item.segment_name] = item
336367

337-
deps = Deps(
338-
sprint_phone_alignments=phone_alignments, sprint_lexicon=lexicon, labels_with_eoc_hdf=dataset, corpus=corpus
339-
)
368+
deps = Deps(phone_alignments=phone_alignments, lexicon=lexicon, labels_with_eoc_hdf=dataset, corpus=corpus)
340369

341370
for segment_name in args.segment or corpus:
342371
print(corpus[segment_name])

0 commit comments

Comments (0)