Skip to content

Commit 71be2db

Browse files
committed
get_language_model: just return hclg
1 parent d386f8c commit 71be2db

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

gentle/language_model.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ def get_node_id(word):
5050
return output
5151

5252
def get_language_model(kaldi_seq, proto_langdir='PROTO_LANGDIR'):
53-
"""Generates a language model to fit the text
53+
"""Generates a language model to fit the text.
54+
55+
Returns the filename of the generated language model FST.
56+
The caller is responsible for removing the generated file.
5457
5558
`proto_langdir` is a path to a directory containing prototype model data
5659
`kaldi_seq` is a list of words within kaldi's vocabulary.
@@ -62,22 +65,21 @@ def get_language_model(kaldi_seq, proto_langdir='PROTO_LANGDIR'):
6265
txt_fst_file.write(txt_fst)
6366
txt_fst_file.close()
6467

65-
out_dir = tempfile.mkdtemp()
66-
68+
hclg_filename = tempfile.mktemp(suffix='_HCLG.fst')
6769
try:
6870
subprocess.check_output([MKGRAPH_PATH,
6971
os.path.join(proto_langdir, 'langdir'),
7072
os.path.join(proto_langdir, 'modeldir'),
7173
txt_fst_file.name,
7274
os.path.join(proto_langdir, "graphdir/words.txt"),
73-
os.path.join(out_dir, 'HCLG.fst')])
75+
hclg_filename])
7476
except Exception, e:
75-
shutil.rmtree(out_dir)
77+
os.unlink(hclg_filename)
7678
raise e
7779
finally:
7880
os.unlink(txt_fst_file.name)
7981

80-
return out_dir
82+
return hclg_filename
8183

8284
if __name__=='__main__':
8385
import sys

gentle/language_model_transcribe.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,17 @@ def lm_transcribe(audio_f, transcript, proto_langdir, nnet_dir,
2727

2828
ks = ms.get_kaldi_sequence()
2929

30-
gen_model_dir = language_model.get_language_model(ks, proto_langdir)
30+
gen_hclg_filename = language_model.get_language_model(ks, proto_langdir)
3131
try:
32-
gen_hclg_path = os.path.join(gen_model_dir, 'HCLG.fst')
33-
k = standard_kaldi.Kaldi(nnet_dir, gen_hclg_path, proto_langdir)
32+
k = standard_kaldi.Kaldi(nnet_dir, gen_hclg_filename, proto_langdir)
3433

3534
trans = standard_kaldi.transcribe(k, audio_f,
3635
partial_results_cb=partial_cb,
3736
partial_results_kwargs=partial_kwargs)
3837

3938
ret = diff_align.align(trans["words"], ms)
4039
finally:
41-
shutil.rmtree(gen_model_dir)
40+
os.unlink(gen_hclg_filename)
4241

4342
return {
4443
"transcript": transcript,

0 commit comments

Comments
 (0)