|
1 | 1 | import logging |
| 2 | +import math |
2 | 3 | import os |
| 4 | +import shutil |
3 | 5 | import subprocess |
4 | 6 | import sys |
5 | 7 | import tempfile |
6 | 8 |
|
7 | 9 | from paths import get_binary |
8 | | -from generate_wp import language_model_from_word_sequence |
| 10 | +from metasentence import MetaSentence |
9 | 11 |
|
# Path to the `mkgraph` executable, resolved through the project's
# paths.get_binary helper (exact lookup semantics live in paths.py).
MKGRAPH_PATH = get_binary("mkgraph")
|
def make_bigram_lm_fst(word_sequence):
    '''
    Use the given token sequence to make a bigram language model
    in OpenFST plain text format.
    '''
    # Pad with [oov]: two leading tokens seed the entry transitions, one
    # trailing token gives the last real word somewhere to go.
    padded = ['[oov]', '[oov]'] + word_sequence + ['[oov]']

    # successors[w] = set of words observed immediately after w.
    successors = {}
    for prev_tok, cur_tok in zip(padded, padded[1:]):
        successors.setdefault(prev_tok, set()).add(cur_tok)

    # States are numbered from 1, assigned lazily in order of first reference.
    state_ids = {}

    def _state(word):
        sid = state_ids.get(word, len(state_ids) + 1)
        state_ids[word] = sid
        return sid

    lines = []
    for src_word in sorted(successors):
        src_id = _state(src_word)
        arcs = successors[src_word]
        # Uniform distribution over the observed successors, as a -log weight.
        # (Keeps the sign of -log(1.0), i.e. "-0.000000", exactly as emitted.)
        cost = -math.log(1.0 / len(arcs)) if arcs else 0
        for dst_word in sorted(arcs):
            lines.append('%d %d %s %s %f' % (src_id, _state(dst_word),
                                             dst_word, dst_word, cost))

    # The highest-numbered state is marked final with weight 0.
    lines.append('%d 0' % len(state_ids))
    return '\n'.join(lines) + '\n'
| 51 | + |
def make_bigram_language_model(kaldi_seq, proto_langdir='PROTO_LANGDIR'):
    """Generates a language model to fit the text.

    Returns the filename of the generated language model FST (HCLG graph).
    The caller is responsible for removing the generated file.

    `proto_langdir` is a path to a directory containing prototype model data
    `kaldi_seq` is a list of words within kaldi's vocabulary.
    """

    # Generate a textual FST and persist it so mkgraph can read it.
    # mode='w' so writing a str works on both Python 2 and 3 (the default
    # binary mode raises TypeError for str on Python 3).
    txt_fst = make_bigram_lm_fst(kaldi_seq)
    txt_fst_file = tempfile.NamedTemporaryFile(mode='w', delete=False)
    try:
        txt_fst_file.write(txt_fst)
    finally:
        txt_fst_file.close()

    # Pre-create the output file instead of tempfile.mktemp(), which is
    # insecure: another process could claim the name before mkgraph does.
    hclg_file = tempfile.NamedTemporaryFile(suffix='_HCLG.fst', delete=False)
    hclg_filename = hclg_file.name
    hclg_file.close()

    try:
        subprocess.check_output([MKGRAPH_PATH,
                                 proto_langdir,
                                 txt_fst_file.name,
                                 hclg_filename])
    except Exception:
        # Don't leak a (possibly partial) output file; guard the unlink so a
        # missing file can't mask the real error.
        if os.path.exists(hclg_filename):
            os.unlink(hclg_filename)
        raise  # bare raise preserves the original traceback (py2 and py3)
    finally:
        os.unlink(txt_fst_file.name)

    return hclg_filename
49 | 81 |
|
if __name__ == '__main__':
    import sys
    # Bug fix: the function expects a LIST of words, not a raw string --
    # make_bigram_lm_fst prepends a list, so passing the file contents
    # unsplit raised TypeError (list + str). Split on whitespace, and close
    # the file deterministically with a context manager.
    with open(sys.argv[1]) as text_file:
        make_bigram_language_model(text_file.read().split())
0 commit comments