__all__ = ["CreateBPELexiconJob"]

import os
import subprocess as sp
import sys
from typing import List, Optional, Set, Union
import xml.etree.ElementTree as ET

from sisyphus import Job, Task, tk

from i6_core.lib.lexicon import Lexicon, Lemma
import i6_core.util as util


class CreateBPELexiconJob(Job):
    """
    In a Bliss lexicon, replace the phonetic representation of each lemma with the
    BPE decomposition of the word; the result can be used e.g. for lexicon-constrained BPE search.

    This job is still in an experimental state and was only tested with Flashlight BPE decoding.
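
    Minimal usage sketch (all paths are placeholders, not part of this module)::

        job = CreateBPELexiconJob(
            base_lexicon_path=tk.Path("/path/to/base_lexicon.xml.gz"),
            bpe_codes=tk.Path("/path/to/bpe.codes"),
            bpe_vocab=tk.Path("/path/to/bpe.vocab"),
            subword_nmt_repo=tk.Path("/path/to/subword-nmt"),
        )
        tk.register_output("bpe_lexicon.xml.gz", job.out_lexicon)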
    """

    def __init__(
        self,
        base_lexicon_path: tk.Path,
        bpe_codes: tk.Path,
        bpe_vocab: tk.Path,
        subword_nmt_repo: tk.Path,
        unk_label: str = "UNK",
        vocab_blacklist: Optional[Union[List[str], Set[str]]] = None,
        keep_special_lemmas: bool = True,
    ):
        """
        :param base_lexicon_path: base lexicon (can be phoneme-based) to take the lemmas from
        :param bpe_codes: BPE codes, e.g. from the ReturnnTrainBPEJob
        :param bpe_vocab: vocab file used to limit which BPE splits can be created
        :param subword_nmt_repo: cloned subword-nmt repository
        :param unk_label: unknown label, used in case a BPE token is created that is not in the vocab
        :param vocab_blacklist: bpe_vocab entries that should not be loaded into the "phoneme"/BPE-token
            inventory, e.g. "<s>" and "</s>"
        :param keep_special_lemmas: whether the special lemmas should be kept;
            usually yes for RASR search and no for Flashlight search.
            The phonemes of the special lemmas are kept as well, so
            make sure there is no overlap with the BPE vocab.
        """
        self.base_lexicon_path = base_lexicon_path
        self.bpe_codes = bpe_codes
        self.bpe_vocab = bpe_vocab
        self.subword_nmt_repo = subword_nmt_repo
        self.unk_label = unk_label
        if vocab_blacklist is None:
            self.vocab_blacklist = set()
        else:
            # convert list to set for faster "in" checks
            self.vocab_blacklist = set(vocab_blacklist)
        self.keep_special_lemmas = keep_special_lemmas

        self.out_lexicon = self.output_path("lexicon.xml.gz", cached=True)

    def tasks(self):
        yield Task("run", resume="run", mini_task=True)

    def _fill_lm_tokens(self, base_lexicon: Lexicon):
        """
        Collect the orthographies of all regular lemmas and set the special lemmas aside.
        """
        lm_tokens = set()
        special_lemmas = []
        for lemma in base_lexicon.lemmata:
            if lemma.special is None:
                lm_tokens.update(lemma.orth)
            else:
                special_lemmas.append(lemma)

        return sorted(lm_tokens), special_lemmas

    def _fill_vocab_and_lexicon(self):
        """
        Parse the BPE vocab, write a dummy count vocab for apply_bpe.py, and
        create a new lexicon that uses the BPE tokens as phonemes.
        """
        lexicon = Lexicon()
        vocab = set()
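        # the bpe_vocab file is assumed to be in the JSON-like format written by the RETURNN BPE jobs, e.g.:
        #   {
        #   "<s>": 0,
        #   "</s>": 1,
        #   "wor@@": 2,
        #   ...
        #   }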
        with util.uopen(self.bpe_vocab.get_path(), "rt") as f, util.uopen("fake_count_vocab.txt", "wt") as vocab_file:
            for line in f:
                line = line.strip()
                if line == "{" or line == "}":
                    continue
                # a line is e.g. '"phon": 0,' and we want to extract 'phon' only
                symbol = line.split(":")[0][1:-1]
                if symbol not in self.vocab_blacklist:
                    # fill the fake count vocab with -1 so that all possible merges are performed
                    vocab_file.write(symbol + " -1\n")
                    # map "." to "_" so that the symbol can be used as a phoneme
                    symbol = symbol.replace(".", "_")
                    vocab.add(symbol)
                    lexicon.add_phoneme(symbol)

        return vocab, lexicon

    def run(self):
        base_lexicon = Lexicon()
        base_lexicon.load(self.base_lexicon_path)

        lm_tokens, special_lemmas = self._fill_lm_tokens(base_lexicon)

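        # write one word per line as plain-text input for apply_bpe.py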
        with util.uopen("words", "wt") as f:
            for t in lm_tokens:
                f.write(f"{t}\n")

        vocab, lexicon = self._fill_vocab_and_lexicon()

        # add the special lemmas and their phonemes back to the lexicon
        if self.keep_special_lemmas:
            for special_lemma in special_lemmas:
                for pronunciation_variant in special_lemma.phon:
                    for phoneme in pronunciation_variant.split():
                        lexicon.add_phoneme(phoneme, variation=base_lexicon.phonemes[phoneme])
                lexicon.add_lemma(special_lemma)

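        # run subword-nmt's apply_bpe.py on the word list; roughly equivalent to the shell call
        #   python <subword_nmt_repo>/apply_bpe.py --input words --codes <bpe_codes> \
        #       --vocabulary fake_count_vocab.txt --output bpes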
        apply_binary = os.path.join(self.subword_nmt_repo.get_path(), "apply_bpe.py")
        args = [
            sys.executable,
            apply_binary,
            "--input",
            "words",
            "--codes",
            self.bpe_codes.get_path(),
            "--vocabulary",
            "fake_count_vocab.txt",
            "--output",
            "bpes",
        ]
        sp.run(args, check=True)

        with util.uopen("bpes", "rt") as bpe_file:
            bpe_tokens = [line.strip() for line in bpe_file]

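        # map each word to its BPE split, e.g. {"housing": "hous@@ ing"}
        # (the actual splits depend on the trained codes)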
        w2b = dict(zip(lm_tokens, bpe_tokens))

        for lemma in base_lexicon.lemmata:
            if lemma.special:
                continue
            for orth in lemma.orth:
                # apply the same "." -> "_" mapping as in the phoneme inventory,
                # then map tokens that are not in the vocab to the unknown label
                bpe_pron = " ".join(
                    token if token in vocab else self.unk_label for token in w2b[orth].replace(".", "_").split()
                )
                lexicon.add_lemma(Lemma([orth], [bpe_pron], lemma.synt, lemma.eval))

        elem = lexicon.to_xml()
        tree = ET.ElementTree(elem)
        util.write_xml(self.out_lexicon.get_path(), tree)