Commit 9e9124b

latency script more, mostly finished
1 parent: fe244b9

File tree: 1 file changed, +99 -15 lines
  • users/zeyer/experiments/exp2023_02_16_chunked_attention/scripts


users/zeyer/experiments/exp2023_02_16_chunked_attention/scripts/latency.py

Lines changed: 99 additions & 15 deletions
@@ -26,11 +26,15 @@ class Deps:
     """deps"""
 
     phone_alignments: Union[FileArchiveBundle, FileArchive]
-    phone_alignment_ms_per_frame: float
+    phone_alignment_sec_per_frame: Decimal
     lexicon: Lexicon
-    labels_with_eoc_hdf: HDFDataset
     corpus: Dict[str, BlissItem]
-    bpe_vocab: Vocabulary
+    chunk_labels: HDFDataset
+    eoc_idx: int
+    chunk_bpe_vocab: Vocabulary
+    chunk_left_padding: Decimal
+    chunk_stride: Decimal
+    chunk_size: Decimal
 
 
 def uopen(path: str, *args, **kwargs):
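Aside from the renames, the substantive change in this hunk is the timing type: float milliseconds become Decimal seconds. Decimal("0.01") represents the 10 ms alignment frame shift exactly, so the frame-index-to-time products and the end-time subtractions later in the script carry no binary floating-point noise. A standalone sketch (not part of the commit) of the difference:

    from decimal import Decimal

    # Frame-to-time conversion as the script does it: time = frame_index * sec_per_frame.
    # With binary floats, even 10 ms frames pick up representation error:
    print(3 * 0.01)                                # 0.030000000000000002
    print(3 * 0.01 == 0.03)                        # False

    # The same arithmetic in Decimal is exact:
    print(3 * Decimal("0.01"))                     # 0.03
    print(3 * Decimal("0.01") == Decimal("0.03"))  # True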
@@ -310,11 +314,21 @@ def from_element(cls, e):
         return Lemma(orth, phon, synt, eval, special)
 
 
-def get_sprint_word_ends(deps: Deps, segment_name: str) -> List[int]:
-    pass
+def handle_segment(deps: Deps, segment_name: str) -> List[Decimal]:
+    """handle segment"""
+    corpus_entry = deps.corpus[segment_name]
+    words = corpus_entry.orth.split()
+    phone_align_ends = get_phone_alignment_word_ends(deps, segment_name)
+    chunk_ends = get_chunk_ends(deps, segment_name)
+    assert len(phone_align_ends) == len(chunk_ends) == len(words)
+    res = []
+    for word, phone_align_end, chunk_end in zip(words, phone_align_ends, chunk_ends):
+        print(f"{word}: {phone_align_end} vs {chunk_end}, latency: {chunk_end - phone_align_end}sec")
+        res.append(chunk_end - phone_align_end)
+    return res
 
 
-def handle_segment(deps: Deps, segment_name: str):
+def get_phone_alignment_word_ends(deps: Deps, segment_name: str) -> List[Decimal]:
     """handle segment"""
     phone_alignment = deps.phone_alignments.read(segment_name, "align")
     corpus_entry = deps.corpus[segment_name]
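With the old stub gone, handle_segment now just pairs the per-word end times from the phone alignment with those from the chunked labels and reports the difference as the word's emission latency. A self-contained sketch of that core computation; the helper name and all times below are invented for illustration:

    from decimal import Decimal
    from typing import List

    def word_latencies(phone_align_ends: List[Decimal], chunk_ends: List[Decimal]) -> List[Decimal]:
        """Latency per word, as in handle_segment: chunk end time minus acoustic word end time."""
        assert len(phone_align_ends) == len(chunk_ends)
        return [chunk_end - align_end for align_end, chunk_end in zip(phone_align_ends, chunk_ends)]

    # Hypothetical 3-word segment: words end at 0.53/1.01/1.64 sec in the phone alignment
    # and are emitted in chunks ending at 1.5/1.5/2.7 sec.
    print(word_latencies(
        [Decimal("0.53"), Decimal("1.01"), Decimal("1.64")],
        [Decimal("1.5"), Decimal("1.5"), Decimal("2.7")],
    ))  # [Decimal('0.97'), Decimal('0.49'), Decimal('1.06')]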
@@ -323,6 +337,7 @@ def handle_segment(deps: Deps, segment_name: str):
     next_time_idx = 0
     word_idx = 0
     cur_word_phones = []
+    res = []
     for time_idx, allophone_idx, state, weight in phone_alignment:
         assert next_time_idx == time_idx
         next_time_idx += 1
@@ -341,50 +356,119 @@ def handle_segment(deps: Deps, segment_name: str):
         if is_final:
             lemma = deps.lexicon.orth_to_lemma[words[word_idx]]
             phones_s = " ".join(cur_word_phones)
-            print(f"end time {time_idx * deps.phone_alignment_ms_per_frame / 1000.}sec:", lemma.orth[0], "/", phones_s)
+            print(f"end time {time_idx * deps.phone_alignment_sec_per_frame}sec:", lemma.orth[0], "/", phones_s)
             if phones_s not in lemma.phon:
                 raise Exception(f"Phones {phones_s} not in lemma {lemma}?")
+            res.append(time_idx * deps.phone_alignment_sec_per_frame)
 
             cur_word_phones.clear()
             word_idx += 1
     assert word_idx == len(words)
+    return res
+
+
+def get_chunk_ends(deps: Deps, segment_name: str) -> List[Decimal]:
+    """
+    Example:
+
+        chunk_size_dim = SpatialDim("chunk-size", 25)
+        input_chunk_size_dim = SpatialDim("input-chunk-size", 150)
+        sliced_chunk_size_dim = SpatialDim("sliced-chunk-size", 20)
+
+        "_input_chunked": {
+            "class": "window",
+            "from": "source",
+            "out_spatial_dim": chunked_time_dim,
+            "stride": 120,
+            "window_dim": input_chunk_size_dim,
+            "window_left": 0,
+        },
+
+        # audio_features is 16.000 Hz, i.e. 16.000 frames per sec, 16 frames per ms.
+        layer /'source': # 100 frames per sec, 0.01 sec per frame, 10 ms per frame
+          [B,T|'⌈(-199+time:var:extern_data:audio_features+-200)/160⌉'[B],F|F'mel_filterbank:feature-dense'(80)] float32
+        layer /'_input_chunked': # 1.2 sec per frame
+          [B,T|'⌈(-199+time:var:extern_data:audio_features+-200)/19200⌉'[B],
+           'input-chunk-size'(150),F|F'mel_filterbank:feature-dense'(80)] float32
+    """
+    corpus_entry = deps.corpus[segment_name]
+    words = corpus_entry.orth.split()
+    bpe_labels = deps.chunk_labels.get_data_by_seq_tag(segment_name, "data")
+    bpe_labels_s = deps.chunk_bpe_vocab.get_seq_labels(bpe_labels)
+    print(bpe_labels)
+    print(bpe_labels_s)
+    chunk_idx = 0
+    cur_chunk_end_pos = deps.chunk_left_padding + deps.chunk_size
+    cur_word = ""
+    word_idx = 0
+    res = []
+    for label_idx in bpe_labels:
+        if label_idx == deps.eoc_idx:
+            chunk_idx += 1
+            cur_chunk_end_pos += deps.chunk_stride
+            continue
+        assert word_idx < len(words), f"{bpe_labels_s!r} does not fit to {corpus_entry.orth!r}"
+        label = deps.chunk_bpe_vocab.id_to_label(label_idx)
+        if label.endswith("@@"):
+            cur_word += label[:-2]
+            continue
+        cur_word += label
+        assert (
+            cur_word == words[word_idx]
+        ), f"{cur_word!r} != {words[word_idx]!r} in {bpe_labels_s!r} != {corpus_entry.orth!r}"
+        print(f"end time {cur_chunk_end_pos}sec:", cur_word)
+        res.append(cur_chunk_end_pos)
+        word_idx += 1
+        cur_word = ""
+    assert word_idx == len(words) and not cur_word
+    return res
 
 
 def main():
     """main"""
     arg_parser = argparse.ArgumentParser()
     arg_parser.add_argument("--phone-alignments", required=True, help="From RASR")
-    arg_parser.add_argument("--phone-alignment-ms-per-frame", type=float, default=10.0)
+    arg_parser.add_argument("--phone-alignment-sec-per-frame", type=Decimal, default=Decimal("0.01"))
     arg_parser.add_argument("--allophone-file", required=True, help="From RASR")
     arg_parser.add_argument("--lexicon", required=True, help="XML")
     arg_parser.add_argument("--corpus", required=True, help="Bliss XML")
-    arg_parser.add_argument("--labels-with-eoc", required=True, help="HDF dataset")
+    arg_parser.add_argument("--chunk-labels", required=True, help="HDF dataset")
+    arg_parser.add_argument("--eoc-idx", default=0, type=int, help="End-of-chunk idx")
+    arg_parser.add_argument("--chunk-bpe-vocab", required=True, help="BPE vocab dict")
+    arg_parser.add_argument(
+        "--chunk-left-padding", type=Decimal, required=True, help="window_left in window layer, in sec"
+    )
+    arg_parser.add_argument("--chunk-stride", type=Decimal, required=True, help="stride in window layer, in sec")
+    arg_parser.add_argument("--chunk-size", type=Decimal, required=True, help="window_dim in window layer, in sec")
     arg_parser.add_argument("--segment", nargs="*")
-    arg_parser.add_argument("--bpe-vocab", required=True, help="BPE vocab dict")
     args = arg_parser.parse_args()
 
     phone_alignments = open_file_archive(args.phone_alignments)
     phone_alignments.set_allophones(args.allophone_file)
 
     lexicon = Lexicon(args.lexicon)
 
-    dataset = HDFDataset([args.labels_with_eoc])
+    dataset = HDFDataset([args.chunk_labels])
     dataset.initialize()
     dataset.init_seq_order(epoch=1)
 
     corpus = {}
     for item in iter_bliss(args.corpus):
         corpus[item.segment_name] = item
 
-    bpe_vocab = Vocabulary(args.bpe_vocab, unknown_label=None)
+    bpe_vocab = Vocabulary(args.chunk_bpe_vocab, unknown_label=None)
 
     deps = Deps(
         phone_alignments=phone_alignments,
-        phone_alignment_ms_per_frame=args.phone_alignment_ms_per_frame,
+        phone_alignment_sec_per_frame=args.phone_alignment_sec_per_frame,
         lexicon=lexicon,
-        labels_with_eoc_hdf=dataset,
         corpus=corpus,
-        bpe_vocab=bpe_vocab,
+        chunk_labels=dataset,
+        eoc_idx=args.eoc_idx,
+        chunk_bpe_vocab=bpe_vocab,
+        chunk_left_padding=args.chunk_left_padding,
+        chunk_stride=args.chunk_stride,
+        chunk_size=args.chunk_size,
     )
 
     for segment_name in args.segment or corpus:
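The chunk end positions that get_chunk_ends tracks follow directly from the window-layer geometry quoted in its docstring: with window_left=0, a 150-frame window and a 120-frame stride at 10 ms per feature frame, the first chunk ends at chunk_left_padding + chunk_size, and each end-of-chunk label advances the position by one stride; ordinary BPE labels (including "@@"-continued pieces) do not move it. A small sketch using those docstring values:

    from decimal import Decimal

    # Chunk geometry from the window-layer example: window_left=0, window_dim=150,
    # stride=120 feature frames, at 0.01 sec per feature frame.
    left_padding = Decimal("0")   # window_left * 0.01 sec
    chunk_size = Decimal("1.5")   # 150 frames * 0.01 sec
    stride = Decimal("1.2")       # 120 frames * 0.01 sec

    # Mirrors the loop in get_chunk_ends: each EOC label adds one stride.
    for chunk_idx in range(4):
        print(f"chunk {chunk_idx} ends at {left_padding + chunk_size + chunk_idx * stride}sec")
    # chunk 0 ends at 1.5sec, chunk 1 at 2.7sec, chunk 2 at 3.9sec, chunk 3 at 5.1sec

Given the argument definitions above, a hypothetical invocation (all file paths are placeholders) could look like:

    python latency.py \
        --phone-alignments alignment.cache.bundle --allophone-file allophones \
        --lexicon lexicon.xml.gz --corpus corpus.xml.gz \
        --chunk-labels labels_with_eoc.hdf --chunk-bpe-vocab bpe.vocab \
        --chunk-left-padding 0 --chunk-stride 1.2 --chunk-size 1.5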
