Skip to content

Commit abfa0db

Browse files
committed
Merge pull request #38 from maxhawkins/dupe
clean up language model generation
2 parents 06dbc62 + 50c2f24 commit abfa0db

File tree

5 files changed

+3247
-3293
lines changed

5 files changed

+3247
-3293
lines changed

gentle/generate_wp.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

gentle/language_model.py

Lines changed: 63 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,84 @@
11
import logging
2+
import math
23
import os
4+
import shutil
35
import subprocess
46
import sys
57
import tempfile
68

79
from paths import get_binary
8-
from generate_wp import language_model_from_word_sequence
10+
from metasentence import MetaSentence
911

1012
MKGRAPH_PATH = get_binary("mkgraph")
1113

12-
def get_language_model(kaldi_seq, proto_langdir='PROTO_LANGDIR'):
13-
"""Generates a language model to fit the text
14+
def make_bigram_lm_fst(word_sequence):
15+
'''
16+
Use the given token sequence to make a bigram language model
17+
in OpenFST plain text format.
18+
'''
19+
word_sequence = ['[oov]', '[oov]'] + word_sequence + ['[oov]']
20+
21+
bigrams = {}
22+
prev_word = word_sequence[0]
23+
for word in word_sequence[1:]:
24+
bigrams.setdefault(prev_word, set()).add(word)
25+
prev_word = word
26+
27+
node_ids = {}
28+
def get_node_id(word):
29+
node_id = node_ids.get(word, len(node_ids) + 1)
30+
node_ids[word] = node_id
31+
return node_id
32+
33+
output = ""
34+
for from_word in sorted(bigrams.keys()):
35+
from_id = get_node_id(from_word)
36+
37+
successors = bigrams[from_word]
38+
if len(successors) > 0:
39+
weight = -math.log(1.0 / len(successors))
40+
else:
41+
weight = 0
42+
43+
for to_word in sorted(successors):
44+
to_id = get_node_id(to_word)
45+
output += '%d %d %s %s %f' % (from_id, to_id, to_word, to_word, weight)
46+
output += "\n"
47+
48+
output += "%d 0\n" % (len(node_ids))
49+
50+
return output
51+
52+
def make_bigram_language_model(kaldi_seq, proto_langdir='PROTO_LANGDIR'):
53+
"""Generates a language model to fit the text.
54+
55+
Returns the filename of the generated language model FST.
56+
The caller is responsible for removing the generated file.
1457
1558
`proto_langdir` is a path to a directory containing prototype model data
1659
`kaldi_seq` is a list of words within kaldi's vocabulary.
1760
"""
1861

19-
# Create a language model directory
20-
lang_model_dir = tempfile.mkdtemp()
21-
logging.info('saving language model to %s', lang_model_dir)
22-
23-
# Symlink in necessary files from the prototype directory
24-
for dirpath, dirnames, filenames in os.walk(proto_langdir, followlinks=True):
25-
for dirname in dirnames:
26-
relpath = os.path.relpath(os.path.join(dirpath, dirname), proto_langdir)
27-
os.makedirs(os.path.join(lang_model_dir, relpath))
28-
for filename in filenames:
29-
abspath = os.path.abspath(os.path.join(dirpath, filename))
30-
relpath = os.path.relpath(os.path.join(dirpath, filename), proto_langdir)
31-
dstpath = os.path.join(lang_model_dir, relpath)
32-
os.symlink(abspath, dstpath)
33-
3462
# Generate a textual FST
35-
txt_fst = language_model_from_word_sequence(kaldi_seq)
36-
txt_fst_file = os.path.join(lang_model_dir, 'G.txt')
37-
open(txt_fst_file, 'w').write(txt_fst)
63+
txt_fst = make_bigram_lm_fst(kaldi_seq)
64+
txt_fst_file = tempfile.NamedTemporaryFile(delete=False)
65+
txt_fst_file.write(txt_fst)
66+
txt_fst_file.close()
3867

39-
words_file = os.path.join(proto_langdir, "graphdir/words.txt")
40-
subprocess.check_output([MKGRAPH_PATH,
41-
os.path.join(lang_model_dir, 'langdir'),
42-
os.path.join(lang_model_dir, 'modeldir'),
43-
txt_fst_file,
44-
words_file,
45-
os.path.join(lang_model_dir, 'graphdir', 'HCLG.fst')])
68+
hclg_filename = tempfile.mktemp(suffix='_HCLG.fst')
69+
try:
70+
subprocess.check_output([MKGRAPH_PATH,
71+
proto_langdir,
72+
txt_fst_file.name,
73+
hclg_filename])
74+
except Exception, e:
75+
os.unlink(hclg_filename)
76+
raise e
77+
finally:
78+
os.unlink(txt_fst_file.name)
4679

47-
# Return the language model directory
48-
return lang_model_dir
80+
return hclg_filename
4981

5082
if __name__=='__main__':
5183
import sys
52-
get_language_model(open(sys.argv[1]).read())
84+
make_bigram_language_model(open(sys.argv[1]).read())

gentle/language_model_transcribe.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,17 @@ def lm_transcribe(audio_f, transcript, proto_langdir, nnet_dir,
2727

2828
ks = ms.get_kaldi_sequence()
2929

30-
gen_model_dir = language_model.get_language_model(ks, proto_langdir)
30+
gen_hclg_filename = language_model.make_bigram_language_model(ks, proto_langdir)
31+
try:
32+
k = standard_kaldi.Kaldi(nnet_dir, gen_hclg_filename, proto_langdir)
3133

32-
gen_hclg_path = os.path.join(gen_model_dir, 'graphdir', 'HCLG.fst')
33-
k = standard_kaldi.Kaldi(nnet_dir, gen_hclg_path, proto_langdir)
34+
trans = standard_kaldi.transcribe(k, audio_f,
35+
partial_results_cb=partial_cb,
36+
partial_results_kwargs=partial_kwargs)
3437

35-
trans = standard_kaldi.transcribe(k, audio_f,
36-
partial_results_cb=partial_cb,
37-
partial_results_kwargs=partial_kwargs)
38-
39-
ret = diff_align.align(trans["words"], ms)
40-
41-
shutil.rmtree(gen_model_dir)
38+
ret = diff_align.align(trans["words"], ms)
39+
finally:
40+
os.unlink(gen_hclg_filename)
4241

4342
return {
4443
"transcript": transcript,

mkgraph.cc

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ int main(int argc, char *argv[]) {
1313
using namespace fst;
1414
using fst::script::ArcSort;
1515
try {
16-
const char *usage = "Usage: ./mkgraph [options] <lang-dir> <model-dir> <grammar-fst> <words-txt> <out-fst>\n";
16+
const char *usage = "Usage: ./mkgraph [options] <proto-dir> <grammar-fst> <out-fst>\n";
1717

1818
ParseOptions po(usage);
1919
po.Read(argc, argv);
20-
if (po.NumArgs() != 5) {
20+
if (po.NumArgs() != 3) {
2121
po.PrintUsage();
2222
return 1;
2323
}
@@ -27,17 +27,16 @@ int main(int argc, char *argv[]) {
2727
float self_loop_scale = 0.1;
2828
bool reverse = false;
2929

30-
std::string lang_dir = po.GetArg(1),
31-
model_dir = po.GetArg(2),
32-
grammar_fst_filename = po.GetArg(3),
33-
words_filename = po.GetArg(4),
34-
out_filename = po.GetArg(5);
35-
36-
std::string lang_fst_filename = lang_dir + "/L.fst",
37-
lang_disambig_fst_filename = lang_dir + "/L_disambig.fst",
38-
disambig_phones_filename = lang_dir + "/phones/disambig.int",
39-
model_filename = model_dir + "/final.mdl",
40-
tree_filename = model_dir + "/tree";
30+
std::string proto_dir = po.GetArg(1),
31+
grammar_fst_filename = po.GetArg(2),
32+
out_filename = po.GetArg(3);
33+
34+
std::string lang_fst_filename = proto_dir + "/langdir/L.fst",
35+
lang_disambig_fst_filename = proto_dir + "/langdir/L_disambig.fst",
36+
disambig_phones_filename = proto_dir + "/langdir/phones/disambig.int",
37+
model_filename = proto_dir + "/modeldir/final.mdl",
38+
tree_filename = proto_dir + "/modeldir/tree",
39+
words_filename = proto_dir + "/graphdir/words.txt";
4140

4241
if (!std::ifstream(lang_fst_filename.c_str())) {
4342
std::cerr << "expected " << lang_fst_filename << " to exist" << std::endl;

0 commit comments

Comments
 (0)