2424from gentle import diff_align
2525from gentle import language_model
2626from gentle import metasentence
27+ from gentle import multipass
2728from gentle import standard_kaldi
2829import gentle
2930
@@ -37,15 +38,32 @@ def render_GET(self, req):
3738 return json .dumps (self .status_dict )
3839
3940class Transcriber ():
40- def __init__ (self , data_dir , nthreads = 4 ):
41+ def __init__ (self , data_dir , nthreads = 4 , ntranscriptionthreads = 2 ):
4142 self .data_dir = data_dir
4243 self .nthreads = nthreads
44+ self .ntranscriptionthreads = ntranscriptionthreads
4345
4446 proto_langdir = get_resource ('PROTO_LANGDIR' )
4547 vocab_path = os .path .join (proto_langdir , "graphdir/words.txt" )
4648 with open (vocab_path ) as f :
4749 self .vocab = metasentence .load_vocabulary (f )
4850
51+ # load kaldi instances for full transcription
52+ gen_hclg_filename = get_resource ('data/graph/HCLG.fst' )
53+
54+ if os .path .exists (gen_hclg_filename ) and self .ntranscriptionthreads > 0 :
55+ proto_langdir = get_resource ('PROTO_LANGDIR' )
56+ nnet_gpu_path = get_resource ('data/nnet_a_gpu_online' )
57+
58+ kaldi_queue = Queue ()
59+ for i in range (self .ntranscriptionthreads ):
60+ kaldi_queue .put (standard_kaldi .Kaldi (
61+ nnet_gpu_path ,
62+ gen_hclg_filename ,
63+ proto_langdir )
64+ )
65+ self .full_transcriber = MultiThreadedTranscriber (kaldi_queue , nthreads = self .ntranscriptionthreads )
66+
4967 self ._status_dicts = {}
5068
5169 def get_status (self , uid ):
@@ -99,133 +117,50 @@ def transcribe(self, uid, transcript, audio, async):
99117 status ['duration' ] = wav_obj .getnframes () / float (wav_obj .getframerate ())
100118 status ['status' ] = 'TRANSCRIBING'
101119
120+ def on_progress (p ):
121+ for k ,v in p .items ():
122+ status [k ] = v
123+
102124 if len (transcript .strip ()) > 0 :
103125 ms = metasentence .MetaSentence (transcript , self .vocab )
104126 ks = ms .get_kaldi_sequence ()
105127 gen_hclg_filename = language_model .make_bigram_language_model (ks , proto_langdir )
106- else :
107- # TODO: We shouldn't load full language models every time;
108- # these should stay in-memory.
109- gen_hclg_filename = get_resource ('data/graph/HCLG.fst' )
110- if not os .path .exists (gen_hclg_filename ):
111- status ["status" ] = "ERROR"
112- status ["error" ] = 'No transcript provided'
113- return
114-
115- kaldi_queue = Queue ()
116- for i in range (self .nthreads ):
117- kaldi_queue .put (standard_kaldi .Kaldi (
118- get_resource ('data/nnet_a_gpu_online' ),
119- gen_hclg_filename ,
120- proto_langdir )
121- )
122128
123- def on_progress (p ):
124- for k ,v in p .items ():
125- status [k ] = v
129+ kaldi_queue = Queue ()
130+ for i in range (self .nthreads ):
131+ kaldi_queue .put (standard_kaldi .Kaldi (
132+ get_resource ('data/nnet_a_gpu_online' ),
133+ gen_hclg_filename ,
134+ proto_langdir )
135+ )
126136
127- mtt = MultiThreadedTranscriber (kaldi_queue , nthreads = self .nthreads )
128- words = mtt .transcribe (wavfile , progress_cb = on_progress )
137+ mtt = MultiThreadedTranscriber (kaldi_queue , nthreads = self .nthreads )
138+ elif hasattr (self , 'full_transcriber' ):
139+ mtt = self .full_transcriber
140+ else :
141+ status ['status' ] = 'ERROR'
142+ status ['error' ] = 'No transcript provided and no language model for full transcription'
143+ return
129144
130- # Clear queue
131- for i in range (self .nthreads ):
132- k = kaldi_queue .get ()
133- k .stop ()
145+ words = mtt .transcribe (wavfile , progress_cb = on_progress )
134146
135147 output = {}
136148 if len (transcript .strip ()) > 0 :
149+ # Clear queue (would this be gc'ed?)
150+ for i in range (self .nthreads ):
151+ k = kaldi_queue .get ()
152+ k .stop ()
153+
137154 # Align words
138155 output ['words' ] = diff_align .align (words , ms )
139156 output ['transcript' ] = transcript
140157
141158 # Perform a second-pass with unaligned words
142159 logging .info ("%d unaligned words (of %d)" % (len ([X for X in output ['words' ] if X .get ("case" ) == "not-found-in-audio" ]), len (output ['words' ])))
143160
144- to_realign = []
145- last_aligned_word = None
146- cur_unaligned_words = []
147-
148- for wd_idx ,wd in enumerate (output ['words' ]):
149- if wd ['case' ] == 'not-found-in-audio' :
150- cur_unaligned_words .append (wd )
151- elif wd ['case' ] == 'success' :
152- if len (cur_unaligned_words ) > 0 :
153- to_realign .append ({
154- "start" : last_aligned_word ,
155- "end" : wd ,
156- "words" : cur_unaligned_words })
157- cur_unaligned_words = []
158-
159- last_aligned_word = wd
160-
161- if len (cur_unaligned_words ) > 0 :
162- to_realign .append ({
163- "start" : last_aligned_word ,
164- "end" : None ,
165- "words" : cur_unaligned_words })
166-
167- realignments = []
168-
169- def realign (chunk ):
170- start_t = (chunk ["start" ] or {"end" : 0 })["end" ]
171- end_t = (chunk ["end" ] or {"start" : status ["duration" ]})["start" ]
172- duration = end_t - start_t
173- if duration < 0.01 or duration > 60 :
174- logging .info ("cannot realign %d words with duration %f" % (len (chunk ['words' ]), duration ))
175- return
176-
177- # Create a language model
178- offset_offset = chunk ['words' ][0 ]['startOffset' ]
179- chunk_len = chunk ['words' ][- 1 ]['endOffset' ] - offset_offset
180- chunk_transcript = ms .raw_sentence [offset_offset :offset_offset + chunk_len ].encode ("utf-8" )
181- chunk_ms = metasentence .MetaSentence (chunk_transcript , self .vocab )
182- chunk_ks = chunk_ms .get_kaldi_sequence ()
183-
184- chunk_gen_hclg_filename = language_model .make_bigram_language_model (chunk_ks , proto_langdir )
185-
186- k = standard_kaldi .Kaldi (
187- get_resource ('data/nnet_a_gpu_online' ),
188- chunk_gen_hclg_filename ,
189- proto_langdir )
190-
191- wav_obj = wave .open (wavfile , 'r' )
192- wav_obj .setpos (int (start_t * wav_obj .getframerate ()))
193- buf = wav_obj .readframes (int (duration * wav_obj .getframerate ()))
194-
195- k .push_chunk (buf )
196- ret = k .get_final ()
197- k .stop ()
161+ status ['status' ] = 'ALIGNING'
198162
199- word_alignment = diff_align .align (ret , chunk_ms )
200-
201- # Adjust startOffset, endOffset, and timing to match originals
202- for wd in word_alignment :
203- if wd .get ("end" ):
204- # Apply timing offset
205- wd ['start' ] += start_t
206- wd ['end' ] += start_t
207-
208- if wd .get ("endOffset" ):
209- wd ['startOffset' ] += offset_offset
210- wd ['endOffset' ] += offset_offset
211-
212- # "chunk" should be replaced by "words"
213- realignments .append ({"chunk" : chunk , "words" : word_alignment })
214-
215- pool = Pool (self .nthreads )
216- pool .map (realign , to_realign )
217- pool .close ()
218-
219- # Sub in the replacements
220- o_words = output ['words' ]
221- for ret in realignments :
222- st_idx = o_words .index (ret ["chunk" ]["words" ][0 ])
223- end_idx = o_words .index (ret ["chunk" ]["words" ][- 1 ])+ 1
224- logging .debug ('splice in: "%s' % (str (ret ["words" ])))
225- logging .debug ('splice out: "%s' % (str (o_words [st_idx :end_idx ])))
226- o_words = o_words [:st_idx ] + ret ["words" ] + o_words [end_idx :]
227-
228- output ['words' ] = o_words
163+ output ['words' ] = multipass .realign (wavfile , output ['words' ], ms , nthreads = self .nthreads , progress_cb = on_progress )
229164
230165 logging .info ("after 2nd pass: %d unaligned words (of %d)" % (len ([X for X in output ['words' ] if X .get ("case" ) == "not-found-in-audio" ]), len (output ['words' ])))
231166
@@ -361,7 +296,7 @@ def make_transcription_alignment(trans):
361296 trans ["words" ] = words
362297 return trans
363298
364- def serve (port = 8765 , interface = '0.0.0.0' , installSignalHandlers = 0 , nthreads = 4 , data_dir = get_datadir ('webdata' )):
299+ def serve (port = 8765 , interface = '0.0.0.0' , installSignalHandlers = 0 , nthreads = 4 , ntranscriptionthreads = 2 , data_dir = get_datadir ('webdata' )):
365300 logging .info ("SERVE %d, %s, %d" , port , interface , installSignalHandlers )
366301
367302 if not os .path .exists (data_dir ):
@@ -377,7 +312,7 @@ def serve(port=8765, interface='0.0.0.0', installSignalHandlers=0, nthreads=4, d
377312 f .putChild ('status.html' , File (get_resource ('www/status.html' )))
378313 f .putChild ('preloader.gif' , File (get_resource ('www/preloader.gif' )))
379314
380- trans = Transcriber (data_dir , nthreads = nthreads )
315+ trans = Transcriber (data_dir , nthreads = nthreads , ntranscriptionthreads = ntranscriptionthreads )
381316 trans_ctrl = TranscriptionsController (trans )
382317 f .putChild ('transcriptions' , trans_ctrl )
383318
@@ -403,6 +338,8 @@ def serve(port=8765, interface='0.0.0.0', installSignalHandlers=0, nthreads=4, d
403338 help = 'port number to run http server on' )
404339 parser .add_argument ('--nthreads' , default = multiprocessing .cpu_count (), type = int ,
405340 help = 'number of alignment threads' )
341+ parser .add_argument ('--ntranscriptionthreads' , default = 2 , type = int ,
342+ help = 'number of full-transcription threads (memory intensive)' )
406343 parser .add_argument ('--log' , default = "INFO" ,
407344 help = 'the log level (DEBUG, INFO, WARNING, ERROR, or CRITICAL)' )
408345
@@ -414,4 +351,4 @@ def serve(port=8765, interface='0.0.0.0', installSignalHandlers=0, nthreads=4, d
414351 logging .info ('gentle %s' % (gentle .__version__ ))
415352 logging .info ('listening at %s:%d\n ' % (args .host , args .port ))
416353
417- serve (args .port , args .host , nthreads = args .nthreads , installSignalHandlers = 1 )
354+ serve (args .port , args .host , nthreads = args .nthreads , ntranscriptionthreads = args .ntranscriptionthreads , installSignalHandlers = 1 )