diff --git a/lexicon/bpe.py b/lexicon/bpe.py
index 563f0020..d786ecc3 100644
--- a/lexicon/bpe.py
+++ b/lexicon/bpe.py
@@ -1,6 +1,7 @@
 __all__ = ["CreateBPELexiconJob"]
 
 import subprocess as sp
+import logging
 import os
 import sys
 from typing import List, Optional, Set, Union
@@ -19,6 +20,8 @@ class CreateBPELexiconJob(Job):
     This job is still in experimental state, and only tested with Flashlight BPE decoding
     """
 
+    __sis_hash_exclude__ = {"skip_unk_lemmas": False, "add_all_bpe_phonemes": True, "additional_words": None}
+
     def __init__(
         self,
         base_lexicon_path: tk.Path,
@@ -28,6 +31,9 @@ def __init__(
         unk_label: str = "UNK",
         vocab_blacklist: Optional[Union[List[str], Set[str]]] = None,
         keep_special_lemmas: bool = True,
+        skip_unk_lemmas: bool = False,
+        add_all_bpe_phonemes: bool = True,
+        additional_words: Optional[tk.Path] = None,
     ):
         """
         :param base_lexicon_path: base lexicon (can be phoneme based) to take the lemmas from
@@ -41,6 +47,12 @@
             usually yes for RASR search and no for Flashlight search.
             The phonemes of the special lemmas will also be kept,
             therefore make sure there is no overlap with the BPE vocab.
+        :param skip_unk_lemmas: whether to skip lemmas whose BPE pronunciation would contain the unknown token,
+            i.e. lemmas not fully covered by the BPE vocabulary. Useful if you set vocab_blacklist.
+        :param add_all_bpe_phonemes: if True, all BPE tokens are added to the lexicon as phonemes; otherwise,
+            only the BPE tokens that actually occur in the converted pronunciations are added to the output lexicon.
+        :param additional_words: text file with additional words (one per line) to convert besides the vocabulary of
+            the base lexicon, e.g. words that a G2P model failed to convert in the case of a G2P-augmented lexicon
         """
         self.base_lexicon_path = base_lexicon_path
         self.bpe_codes = bpe_codes
@@ -53,6 +65,9 @@
             # convert list to set for faster "in" check
             self.vocab_blacklist = set(vocab_blacklist)
         self.keep_special_lemmas = keep_special_lemmas
+        self.skip_unk_lemmas = skip_unk_lemmas
+        self.add_all_bpe_phonemes = add_all_bpe_phonemes
+        self.additional_words = additional_words
 
         self.out_lexicon = self.output_path("lexicon.xml.gz", cached=True)
 
@@ -85,14 +100,26 @@ def _fill_vocab_and_lexicon(self):
                 vocab_file.write(symbol + " -1\n")
                 symbol = symbol.replace(".", "_")
                 vocab.add(symbol)
-                lexicon.add_phoneme(symbol.replace(".", "_"))
+                if self.add_all_bpe_phonemes:
+                    lexicon.add_phoneme(symbol.replace(".", "_"))
 
         return vocab, lexicon
 
+    def _fill_additional_words(self):
+        if self.additional_words is not None:
+            with util.uopen(self.additional_words.get_path(), "rt") as f:
+                res = {line.strip() for line in f}
+        else:
+            res = set()
+        return sorted(res)
+
     def run(self):
         base_lexicon = Lexicon()
         base_lexicon.load(self.base_lexicon_path)
 
+        additional_words_list = self._fill_additional_words()
+        for w in additional_words_list:
+            base_lexicon.add_lemma(Lemma([w], None))  # add empty lemmata with only an orth for the additional words
+
         lm_tokens, special_lemmas = self._fill_lm_tokens(base_lexicon)
 
         with util.uopen("words", "wt") as f:
@@ -129,13 +156,22 @@
         w2b = {w: b for w, b in zip(lm_tokens, bpe_tokens)}
 
+        used_vocab = set()  # BPE tokens that actually occur in a pronunciation
         for lemma in base_lexicon.lemmata:
             if lemma.special:
                 continue
             for orth in lemma.orth:
                 bpe_pron = " ".join([token if token in vocab else self.unk_label for token in w2b[orth].split()])
+                if self.skip_unk_lemmas and self.unk_label in bpe_pron.split():
+                    logging.info(f"Lemma {orth} is skipped because its BPE pronunciation contains the unknown token.")
+                    continue
+                used_vocab.update(bpe_pron.split())
                 lexicon.add_lemma(Lemma([orth], [bpe_pron.replace(".", "_")], lemma.synt, lemma.eval))
 
+        if not self.add_all_bpe_phonemes:
+            for symbol in sorted(used_vocab):
+                lexicon.add_phoneme(symbol.replace(".", "_"))
+
         elem = lexicon.to_xml()
         tree = ET.ElementTree(elem)
         util.write_xml(self.out_lexicon.get_path(), tree)
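For reference, a minimal usage sketch of the new options follows; it is not part of the patch. The import path and the constructor parameters bpe_vocab and subword_nmt_repo do not appear in the hunks above and are assumed from the surrounding job; every path below is a placeholder.

    from sisyphus import tk
    from i6_core.lexicon.bpe import CreateBPELexiconJob  # assumed import path

    # Hypothetical usage sketch of the new parameters.
    job = CreateBPELexiconJob(
        base_lexicon_path=tk.Path("/path/to/base_lexicon.xml.gz"),
        bpe_codes=tk.Path("/path/to/bpe.codes"),
        bpe_vocab=tk.Path("/path/to/bpe.vocab"),           # assumed parameter name
        subword_nmt_repo=tk.Path("/path/to/subword-nmt"),  # assumed parameter name
        unk_label="<unk>",
        skip_unk_lemmas=True,        # drop lemmas whose pronunciation would contain <unk>
        add_all_bpe_phonemes=False,  # add only the BPE tokens that end up in a pronunciation
        additional_words=tk.Path("/path/to/extra_words.txt"),  # one word per line
    )
    tk.register_output("lexicon/bpe_lexicon.xml.gz", job.out_lexicon)

With add_all_bpe_phonemes=False, the phoneme inventory contains exactly the tokens collected in used_vocab, so blacklisted or otherwise unused BPE tokens are never declared as phonemes.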