
Commit 4adf659

latency script more, wip
1 parent d1b7ce4 commit 4adf659

File tree

1 file changed: +30 -16 lines changed
  • users/zeyer/experiments/exp2023_02_16_chunked_attention/scripts


users/zeyer/experiments/exp2023_02_16_chunked_attention/scripts/latency.py

Lines changed: 30 additions & 16 deletions
@@ -25,6 +25,7 @@ class Deps:
     """deps"""
 
     phone_alignments: Union[FileArchiveBundle, FileArchive]
+    phone_alignment_ms_per_frame: float
     lexicon: Lexicon
     labels_with_eoc_hdf: HDFDataset
     corpus: Dict[str, BlissItem]
@@ -316,35 +317,42 @@ def handle_segment(deps: Deps, segment_name: str):
     phone_alignment = deps.phone_alignments.read(segment_name, "align")
     corpus_entry = deps.corpus[segment_name]
     words = corpus_entry.orth.split()
-    for word in words:
-        lemma = deps.lexicon.orth_to_lemma[word]
-        print(lemma)
     allophones = deps.phone_alignments.get_allophones_list()
-    for time, index, state, weight in phone_alignment:
-        allophone = allophones[index]  # like: "[SILENCE]{#+#}@i@f" or "W{HH+AH}"
+    next_time_idx = 0
+    word_idx = 0
+    cur_word_phones = []
+    for time_idx, allophone_idx, state, weight in phone_alignment:
+        assert next_time_idx == time_idx
+        next_time_idx += 1
+        allophone = allophones[allophone_idx]  # like: "[SILENCE]{#+#}@i@f" or "W{HH+AH}"
         m = re.match(r"([a-zA-Z\[\]#]+){([a-zA-Z\[\]#]+)\+([a-zA-Z\[\]#]+)}(@i)?(@f)?", allophone)
         assert m
         center, left, right, is_initial, is_final = m.groups()
         if center in deps.lexicon.special_phones:
             lemma = deps.lexicon.special_phones[center]
             if "" in lemma.orth:  # e.g. silence
                 continue  # skip silence or similar
-        # Keep similar format as Sprint archiver.
-        items = [
-            f"time={time}",
-            f"allophone={allophone}",
-            f"index={index}",
-            f"state={state}",
-        ]
-        if weight != 1:
-            items.append(f"weight={weight}")
-        print("\t".join(items))
+        if time_idx + 1 >= len(phone_alignment) or phone_alignment[time_idx + 1][1] == allophone_idx:
+            continue  # skip to the last frame for this phoneme
+        cur_word_phones.append(center)
+
+        if is_final:
+            lemma = deps.lexicon.orth_to_lemma[words[word_idx]]
+            phones_s = " ".join(cur_word_phones)
+            print(f"end time {time_idx * deps.phone_alignment_ms_per_frame / 1000.}sec:", lemma.orth[0], "/", phones_s)
+            if phones_s not in lemma.phon:
+                print(f"WARNING: phones {phones_s} not in lemma {lemma}?")
+
+            cur_word_phones.clear()
+            word_idx += 1
+    assert word_idx == len(words)
 
 
 def main():
     """main"""
     arg_parser = argparse.ArgumentParser()
     arg_parser.add_argument("--phone-alignments", required=True)
+    arg_parser.add_argument("--phone-alignment-ms-per-frame", type=float, default=10.0)
     arg_parser.add_argument("--allophone-file", required=True)
     arg_parser.add_argument("--lexicon", required=True)
     arg_parser.add_argument("--corpus", required=True)
@@ -365,7 +373,13 @@ def main():
     for item in iter_bliss(args.corpus):
         corpus[item.segment_name] = item
 
-    deps = Deps(phone_alignments=phone_alignments, lexicon=lexicon, labels_with_eoc_hdf=dataset, corpus=corpus)
+    deps = Deps(
+        phone_alignments=phone_alignments,
+        phone_alignment_ms_per_frame=args.phone_alignment_ms_per_frame,
+        lexicon=lexicon,
+        labels_with_eoc_hdf=dataset,
+        corpus=corpus,
+    )
 
     for segment_name in args.segment or corpus:
        print(corpus[segment_name])
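
For reference, a small standalone sketch (not part of the commit) of how the allophone pattern used in handle_segment splits the two example strings from the code comment; the variable names follow the unpacking in the script:

import re

# Same pattern as in handle_segment; matches RASR-style allophones, e.g. "W{HH+AH}".
ALLOPHONE_RE = re.compile(r"([a-zA-Z\[\]#]+){([a-zA-Z\[\]#]+)\+([a-zA-Z\[\]#]+)}(@i)?(@f)?")

for allophone in ["[SILENCE]{#+#}@i@f", "W{HH+AH}"]:
    center, left, right, is_initial, is_final = ALLOPHONE_RE.match(allophone).groups()
    # "[SILENCE]{#+#}@i@f" -> center "[SILENCE]", context "#"/"#", both @i and @f set
    # "W{HH+AH}"           -> center "W", left context "HH", right context "AH", no @i/@f
    print(center, left, right, bool(is_initial), bool(is_final))

With the new --phone-alignment-ms-per-frame option (default 10.0), a frame index converts to seconds as time_idx * ms_per_frame / 1000, e.g. frame 250 at 10 ms per frame gives a 2.5 sec word end time.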

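And a minimal toy run of the new word-end-time logic (again not from the repo: the allophone table, alignment tuples, and word list below are made up for illustration, and the silence check is simplified to a direct string comparison instead of going through lexicon.special_phones):

import re

# Toy stand-ins for what the script reads from the alignment cache and allophone file.
allophones = ["[SILENCE]{#+#}@i@f", "HH{#+AH}@i", "AH{HH+#}@f", "Y{#+UW}@i", "UW{Y+#}@f"]
alignment = [  # (time_idx, allophone_idx, hmm_state, weight), as unpacked in handle_segment
    (0, 0, 0, 1.0),                  # leading silence
    (1, 1, 0, 1.0), (2, 1, 1, 1.0),  # "HH" over two frames
    (3, 2, 0, 1.0),                  # "AH", word-final -> end of "hi"
    (4, 3, 0, 1.0),                  # "Y"
    (5, 4, 0, 1.0),                  # "UW", word-final -> end of "you"
    (6, 0, 0, 1.0),                  # trailing silence
]
words = ["hi", "you"]
ms_per_frame = 10.0  # default of the new --phone-alignment-ms-per-frame option

cur_word_phones = []
word_idx = 0
for time_idx, allophone_idx, state, weight in alignment:
    center, left, right, is_initial, is_final = re.match(
        r"([a-zA-Z\[\]#]+){([a-zA-Z\[\]#]+)\+([a-zA-Z\[\]#]+)}(@i)?(@f)?", allophones[allophone_idx]
    ).groups()
    if center == "[SILENCE]":  # simplified; the script checks lexicon.special_phones
        continue
    if time_idx + 1 >= len(alignment) or alignment[time_idx + 1][1] == allophone_idx:
        continue  # same skip condition as the commit: only act on the last frame of a phoneme
    cur_word_phones.append(center)
    if is_final:  # "@f" marks the word-final phoneme, so the current word ends at this frame
        print(f"end time {time_idx * ms_per_frame / 1000.}sec:", words[word_idx], "/", " ".join(cur_word_phones))
        cur_word_phones.clear()
        word_idx += 1
assert word_idx == len(words)

This prints "end time 0.03sec: hi / HH AH" and "end time 0.05sec: you / Y UW", mirroring the per-word end times the script now reports from the phone alignment.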