Skip to content

Commit de26d56

Browse files
Robert M Ochshorn
authored and committed
refactor multipass
1 parent e40bfbb commit de26d56

File tree

3 files changed

+173
-116
lines changed

3 files changed

+173
-116
lines changed

gentle/multipass.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import logging
2+
from multiprocessing.pool import ThreadPool as Pool
3+
import os
4+
import wave
5+
6+
from gentle import standard_kaldi
7+
from gentle import metasentence
8+
from gentle import language_model
9+
from gentle.paths import get_resource
10+
from gentle import diff_align
11+
12+
# XXX: refactor out somewhere
# Module-level Kaldi resources shared by every realignment pass:
# the prototype language directory and the recognizer vocabulary,
# loaded once from the graph's words.txt symbol table.
proto_langdir = get_resource('PROTO_LANGDIR')
vocab_path = os.path.join(proto_langdir, "graphdir/words.txt")
with open(vocab_path) as f:
    vocab = metasentence.load_vocabulary(f)
17+
18+
def prepare_multipass(alignment):
    """Group consecutive unaligned words into chunks for a second pass.

    Scans `alignment` (a list of word dicts carrying a 'case' key) and
    collects every maximal run of 'not-found-in-audio' words, bracketed by
    the nearest successfully aligned words on either side.

    Returns a list of chunk dicts:
        {"start": word-dict or None,   # last aligned word before the run
         "end":   word-dict or None,   # first aligned word after the run
         "words": [unaligned word dicts]}
    "start" is None when the run begins the transcript; "end" is None when
    it runs to the end.  Words with other 'case' values are ignored.
    """
    to_realign = []
    last_aligned_word = None
    cur_unaligned_words = []

    # NOTE: original looped with enumerate() but never used the index.
    for wd in alignment:
        if wd['case'] == 'not-found-in-audio':
            cur_unaligned_words.append(wd)
        elif wd['case'] == 'success':
            if len(cur_unaligned_words) > 0:
                # An aligned word closes the current unaligned run.
                to_realign.append({
                    "start": last_aligned_word,
                    "end": wd,
                    "words": cur_unaligned_words})
                cur_unaligned_words = []

            last_aligned_word = wd

    # Flush a trailing run that reaches the end of the transcript.
    if len(cur_unaligned_words) > 0:
        to_realign.append({
            "start": last_aligned_word,
            "end": None,
            "words": cur_unaligned_words})

    return to_realign
43+
44+
def realign(wavfile, alignment, ms, nthreads=4, progress_cb=None):
    """Second-pass alignment: re-decode the unaligned stretches of `alignment`.

    Each unaligned chunk (from prepare_multipass) is decoded against a small
    bigram language model built from just that chunk's transcript slice, and
    the resulting words are spliced back into the word list.

    wavfile     -- path to the wav file used for the first pass
    alignment   -- first-pass word dicts (diff_align.align output)
    ms          -- MetaSentence covering the full transcript
    nthreads    -- number of concurrent realignment workers
    progress_cb -- optional callable receiving {"percent": float}

    Returns a new word list; `alignment` itself is not mutated.
    """
    to_realign = prepare_multipass(alignment)
    realignments = []

    # Renamed from `realign`: the worker previously shadowed this function's
    # own name.
    def _realign_chunk(chunk):
        # Single open + guaranteed close (the original opened the wav twice
        # per chunk and never closed either handle).
        wav_obj = wave.open(wavfile, 'r')
        try:
            # A chunk spans from the end of the preceding aligned word (or 0
            # at transcript start) to the start of the following aligned word
            # (or end-of-audio).
            start_t = (chunk["start"] or {"end": 0})["end"]
            end_t = chunk["end"]
            if end_t is None:
                end_t = wav_obj.getnframes() / float(wav_obj.getframerate())
            else:
                end_t = end_t["start"]

            duration = end_t - start_t
            # Too short: no usable audio; too long: expensive and usually
            # silence/noise.  Skip both.
            if duration < 0.01 or duration > 60:
                logging.debug("cannot realign %d words with duration %f" % (len(chunk['words']), duration))
                return

            # Create a language model restricted to this chunk's transcript
            offset_offset = chunk['words'][0]['startOffset']
            chunk_len = chunk['words'][-1]['endOffset'] - offset_offset
            chunk_transcript = ms.raw_sentence[offset_offset:offset_offset+chunk_len].encode("utf-8")
            chunk_ms = metasentence.MetaSentence(chunk_transcript, vocab)
            chunk_ks = chunk_ms.get_kaldi_sequence()

            chunk_gen_hclg_filename = language_model.make_bigram_language_model(chunk_ks, proto_langdir)
            k = standard_kaldi.Kaldi(
                get_resource('data/nnet_a_gpu_online'),
                chunk_gen_hclg_filename,
                proto_langdir)

            # Read only this chunk's audio.
            wav_obj.setpos(int(start_t * wav_obj.getframerate()))
            buf = wav_obj.readframes(int(duration * wav_obj.getframerate()))
        finally:
            wav_obj.close()

        k.push_chunk(buf)
        ret = k.get_final()
        k.stop()

        word_alignment = diff_align.align(ret, chunk_ms)

        # Adjust startOffset, endOffset, and timing to match originals
        for wd in word_alignment:
            if wd.get("end"):
                # Apply timing offset
                wd['start'] += start_t
                wd['end'] += start_t

            if wd.get("endOffset"):
                wd['startOffset'] += offset_offset
                wd['endOffset'] += offset_offset

        # "chunk" should be replaced by "words"
        realignments.append({"chunk": chunk, "words": word_alignment})

        if progress_cb is not None:
            progress_cb({"percent": len(realignments) / float(len(to_realign))})

    pool = Pool(nthreads)
    pool.map(_realign_chunk, to_realign)
    pool.close()

    # Sub in the replacements
    o_words = alignment
    for ret in realignments:
        st_idx = o_words.index(ret["chunk"]["words"][0])
        end_idx = o_words.index(ret["chunk"]["words"][-1]) + 1
        logging.debug('splice in: "%s' % (str(ret["words"])))
        logging.debug('splice out: "%s' % (str(o_words[st_idx:end_idx])))
        o_words = o_words[:st_idx] + ret["words"] + o_words[end_idx:]

    return o_words

serve.py

Lines changed: 50 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from gentle import diff_align
2525
from gentle import language_model
2626
from gentle import metasentence
27+
from gentle import multipass
2728
from gentle import standard_kaldi
2829
import gentle
2930

@@ -37,15 +38,32 @@ def render_GET(self, req):
3738
return json.dumps(self.status_dict)
3839

3940
class Transcriber():
40-
def __init__(self, data_dir, nthreads=4):
41+
def __init__(self, data_dir, nthreads=4, ntranscriptionthreads=2):
4142
self.data_dir = data_dir
4243
self.nthreads = nthreads
44+
self.ntranscriptionthreads = ntranscriptionthreads
4345

4446
proto_langdir = get_resource('PROTO_LANGDIR')
4547
vocab_path = os.path.join(proto_langdir, "graphdir/words.txt")
4648
with open(vocab_path) as f:
4749
self.vocab = metasentence.load_vocabulary(f)
4850

51+
# load kaldi instances for full transcription
52+
gen_hclg_filename = get_resource('data/graph/HCLG.fst')
53+
54+
if os.path.exists(gen_hclg_filename) and self.ntranscriptionthreads > 0:
55+
proto_langdir = get_resource('PROTO_LANGDIR')
56+
nnet_gpu_path = get_resource('data/nnet_a_gpu_online')
57+
58+
kaldi_queue = Queue()
59+
for i in range(self.ntranscriptionthreads):
60+
kaldi_queue.put(standard_kaldi.Kaldi(
61+
nnet_gpu_path,
62+
gen_hclg_filename,
63+
proto_langdir)
64+
)
65+
self.full_transcriber = MultiThreadedTranscriber(kaldi_queue, nthreads=self.ntranscriptionthreads)
66+
4967
self._status_dicts = {}
5068

5169
def get_status(self, uid):
@@ -99,133 +117,50 @@ def transcribe(self, uid, transcript, audio, async):
99117
status['duration'] = wav_obj.getnframes() / float(wav_obj.getframerate())
100118
status['status'] = 'TRANSCRIBING'
101119

120+
def on_progress(p):
121+
for k,v in p.items():
122+
status[k] = v
123+
102124
if len(transcript.strip()) > 0:
103125
ms = metasentence.MetaSentence(transcript, self.vocab)
104126
ks = ms.get_kaldi_sequence()
105127
gen_hclg_filename = language_model.make_bigram_language_model(ks, proto_langdir)
106-
else:
107-
# TODO: We shouldn't load full language models every time;
108-
# these should stay in-memory.
109-
gen_hclg_filename = get_resource('data/graph/HCLG.fst')
110-
if not os.path.exists(gen_hclg_filename):
111-
status["status"] = "ERROR"
112-
status["error"] = 'No transcript provided'
113-
return
114-
115-
kaldi_queue = Queue()
116-
for i in range(self.nthreads):
117-
kaldi_queue.put(standard_kaldi.Kaldi(
118-
get_resource('data/nnet_a_gpu_online'),
119-
gen_hclg_filename,
120-
proto_langdir)
121-
)
122128

123-
def on_progress(p):
124-
for k,v in p.items():
125-
status[k] = v
129+
kaldi_queue = Queue()
130+
for i in range(self.nthreads):
131+
kaldi_queue.put(standard_kaldi.Kaldi(
132+
get_resource('data/nnet_a_gpu_online'),
133+
gen_hclg_filename,
134+
proto_langdir)
135+
)
126136

127-
mtt = MultiThreadedTranscriber(kaldi_queue, nthreads=self.nthreads)
128-
words = mtt.transcribe(wavfile, progress_cb=on_progress)
137+
mtt = MultiThreadedTranscriber(kaldi_queue, nthreads=self.nthreads)
138+
elif hasattr(self, 'full_transcriber'):
139+
mtt = self.full_transcriber
140+
else:
141+
status['status'] = 'ERROR'
142+
status['error'] = 'No transcript provided and no language model for full transcription'
143+
return
129144

130-
# Clear queue
131-
for i in range(self.nthreads):
132-
k = kaldi_queue.get()
133-
k.stop()
145+
words = mtt.transcribe(wavfile, progress_cb=on_progress)
134146

135147
output = {}
136148
if len(transcript.strip()) > 0:
149+
# Clear queue (would this be gc'ed?)
150+
for i in range(self.nthreads):
151+
k = kaldi_queue.get()
152+
k.stop()
153+
137154
# Align words
138155
output['words'] = diff_align.align(words, ms)
139156
output['transcript'] = transcript
140157

141158
# Perform a second-pass with unaligned words
142159
logging.info("%d unaligned words (of %d)" % (len([X for X in output['words'] if X.get("case") == "not-found-in-audio"]), len(output['words'])))
143160

144-
to_realign = []
145-
last_aligned_word = None
146-
cur_unaligned_words = []
147-
148-
for wd_idx,wd in enumerate(output['words']):
149-
if wd['case'] == 'not-found-in-audio':
150-
cur_unaligned_words.append(wd)
151-
elif wd['case'] == 'success':
152-
if len(cur_unaligned_words) > 0:
153-
to_realign.append({
154-
"start": last_aligned_word,
155-
"end": wd,
156-
"words": cur_unaligned_words})
157-
cur_unaligned_words = []
158-
159-
last_aligned_word = wd
160-
161-
if len(cur_unaligned_words) > 0:
162-
to_realign.append({
163-
"start": last_aligned_word,
164-
"end": None,
165-
"words": cur_unaligned_words})
166-
167-
realignments = []
168-
169-
def realign(chunk):
170-
start_t = (chunk["start"] or {"end": 0})["end"]
171-
end_t = (chunk["end"] or {"start": status["duration"]})["start"]
172-
duration = end_t - start_t
173-
if duration < 0.01 or duration > 60:
174-
logging.info("cannot realign %d words with duration %f" % (len(chunk['words']), duration))
175-
return
176-
177-
# Create a language model
178-
offset_offset = chunk['words'][0]['startOffset']
179-
chunk_len = chunk['words'][-1]['endOffset'] - offset_offset
180-
chunk_transcript = ms.raw_sentence[offset_offset:offset_offset+chunk_len].encode("utf-8")
181-
chunk_ms = metasentence.MetaSentence(chunk_transcript, self.vocab)
182-
chunk_ks = chunk_ms.get_kaldi_sequence()
183-
184-
chunk_gen_hclg_filename = language_model.make_bigram_language_model(chunk_ks, proto_langdir)
185-
186-
k = standard_kaldi.Kaldi(
187-
get_resource('data/nnet_a_gpu_online'),
188-
chunk_gen_hclg_filename,
189-
proto_langdir)
190-
191-
wav_obj = wave.open(wavfile, 'r')
192-
wav_obj.setpos(int(start_t * wav_obj.getframerate()))
193-
buf = wav_obj.readframes(int(duration * wav_obj.getframerate()))
194-
195-
k.push_chunk(buf)
196-
ret = k.get_final()
197-
k.stop()
161+
status['status'] = 'ALIGNING'
198162

199-
word_alignment = diff_align.align(ret, chunk_ms)
200-
201-
# Adjust startOffset, endOffset, and timing to match originals
202-
for wd in word_alignment:
203-
if wd.get("end"):
204-
# Apply timing offset
205-
wd['start'] += start_t
206-
wd['end'] += start_t
207-
208-
if wd.get("endOffset"):
209-
wd['startOffset'] += offset_offset
210-
wd['endOffset'] += offset_offset
211-
212-
# "chunk" should be replaced by "words"
213-
realignments.append({"chunk": chunk, "words": word_alignment})
214-
215-
pool = Pool(self.nthreads)
216-
pool.map(realign, to_realign)
217-
pool.close()
218-
219-
# Sub in the replacements
220-
o_words = output['words']
221-
for ret in realignments:
222-
st_idx = o_words.index(ret["chunk"]["words"][0])
223-
end_idx= o_words.index(ret["chunk"]["words"][-1])+1
224-
logging.debug('splice in: "%s' % (str(ret["words"])))
225-
logging.debug('splice out: "%s' % (str(o_words[st_idx:end_idx])))
226-
o_words = o_words[:st_idx] + ret["words"] + o_words[end_idx:]
227-
228-
output['words'] = o_words
163+
output['words'] = multipass.realign(wavfile, output['words'], ms, nthreads=self.nthreads, progress_cb=on_progress)
229164

230165
logging.info("after 2nd pass: %d unaligned words (of %d)" % (len([X for X in output['words'] if X.get("case") == "not-found-in-audio"]), len(output['words'])))
231166

@@ -361,7 +296,7 @@ def make_transcription_alignment(trans):
361296
trans["words"] = words
362297
return trans
363298

364-
def serve(port=8765, interface='0.0.0.0', installSignalHandlers=0, nthreads=4, data_dir=get_datadir('webdata')):
299+
def serve(port=8765, interface='0.0.0.0', installSignalHandlers=0, nthreads=4, ntranscriptionthreads=2, data_dir=get_datadir('webdata')):
365300
logging.info("SERVE %d, %s, %d", port, interface, installSignalHandlers)
366301

367302
if not os.path.exists(data_dir):
@@ -377,7 +312,7 @@ def serve(port=8765, interface='0.0.0.0', installSignalHandlers=0, nthreads=4, d
377312
f.putChild('status.html', File(get_resource('www/status.html')))
378313
f.putChild('preloader.gif', File(get_resource('www/preloader.gif')))
379314

380-
trans = Transcriber(data_dir, nthreads=nthreads)
315+
trans = Transcriber(data_dir, nthreads=nthreads, ntranscriptionthreads=ntranscriptionthreads)
381316
trans_ctrl = TranscriptionsController(trans)
382317
f.putChild('transcriptions', trans_ctrl)
383318

@@ -403,6 +338,8 @@ def serve(port=8765, interface='0.0.0.0', installSignalHandlers=0, nthreads=4, d
403338
help='port number to run http server on')
404339
parser.add_argument('--nthreads', default=multiprocessing.cpu_count(), type=int,
405340
help='number of alignment threads')
341+
parser.add_argument('--ntranscriptionthreads', default=2, type=int,
342+
help='number of full-transcription threads (memory intensive)')
406343
parser.add_argument('--log', default="INFO",
407344
help='the log level (DEBUG, INFO, WARNING, ERROR, or CRITICAL)')
408345

@@ -414,4 +351,4 @@ def serve(port=8765, interface='0.0.0.0', installSignalHandlers=0, nthreads=4, d
414351
logging.info('gentle %s' % (gentle.__version__))
415352
logging.info('listening at %s:%d\n' % (args.host, args.port))
416353

417-
serve(args.port, args.host, nthreads=args.nthreads, installSignalHandlers=1)
354+
serve(args.port, args.host, nthreads=args.nthreads, ntranscriptionthreads=args.ntranscriptionthreads, installSignalHandlers=1)

www/view_alignment.html

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,12 @@ <h1 class="home"><a href="/">Gentle</a></h1>
333333

334334
status_init = true;
335335
}
336-
337-
if(ret.percent && (status_log.length == 0 || status_log[status_log.length-1].percent+0.0001 < ret.percent)) {
336+
if(ret.status !== "TRANSCRIBING") {
337+
if(ret.percent) {
338+
$status_pro.value = (100*ret.percent);
339+
}
340+
}
341+
else if(ret.percent && (status_log.length == 0 || status_log[status_log.length-1].percent+0.0001 < ret.percent)) {
338342
// New entry
339343
var $entry = document.createElement("div");
340344
$entry.className = "entry";
@@ -369,7 +373,7 @@ <h1 class="home"><a href="/">Gentle</a></h1>
369373
if (ret.status == 'ERROR') {
370374
$preloader.style.visibility = 'hidden';
371375
$trans.innerHTML = '<b>' + ret.status + ': ' + ret.error + '</b>';
372-
} else if (ret.status == 'TRANSCRIBING') {
376+
} else if (ret.status == 'TRANSCRIBING' || ret.status == 'ALIGNING') {
373377
$preloader.style.visibility = 'visible';
374378
render_status(ret);
375379
setTimeout(update, 2000);

0 commit comments

Comments
 (0)