@@ -366,7 +366,18 @@ class ExternalLMDecoder:
     Integrates an external LM decoder into an ASR decoder
     """

-    def __init__(self, asr_decoder, ext_lm_opts, beam_size, dec_type, prior_lm_opts=None, length_normalization=True):
+    def __init__(
+        self,
+        asr_decoder,
+        ext_lm_opts,
+        beam_size,
+        dec_type,
+        prior_lm_opts=None,
+        length_normalization=True,
+        mask_layer_name=None,
+        eos_cond_layer_name=None,
+        renorm_wo_eos=False,
+    ):
         self.asr_decoder = copy.deepcopy(asr_decoder)
         self.am_output_prob = self.asr_decoder.output_prob
         self.target = self.asr_decoder.target
@@ -375,6 +386,9 @@ def __init__(self, asr_decoder, ext_lm_opts, beam_size, dec_type, prior_lm_opts=
         self.prior_lm_opts = prior_lm_opts
         self.dec_type = dec_type
         self.length_normalization = length_normalization
+        self.mask_layer_name = mask_layer_name
+        self.eos_cond_layer_name = eos_cond_layer_name
+        self.renorm_wo_eos = renorm_wo_eos

         self.network = None

@@ -402,11 +416,48 @@ def _create_external_lm_net(self) -> dict:
402416 ), "load_on_init opts or lm_model are missing for loading subnet."
403417 assert "filename" in self .ext_lm_opts ["load_on_init_opts" ], "Checkpoint missing for loading subnet."
404418 load_on_init = self .ext_lm_opts ["load_on_init_opts" ]
405- lm_net_out .add_subnetwork (
406- "lm_output" , "prev:output" , subnetwork_net = ext_lm_subnet , load_on_init = load_on_init
407- )
419+
420+ if self .mask_layer_name :
421+ lm_output = lm_net_out .add_masked_computation_layer (
422+ "lm_output_masked" ,
423+ "prev:output" ,
424+ mask = self .mask_layer_name ,
425+ unit = {
426+ "class" : "subnetwork" ,
427+ "from" : "data" ,
428+ "subnetwork" : ext_lm_subnet ,
429+ "load_on_init" : load_on_init ,
430+ },
431+ )
432+ else :
433+ lm_output = lm_net_out .add_subnetwork (
434+ "lm_output" , "prev:output" , subnetwork_net = ext_lm_subnet , load_on_init = load_on_init
435+ )
408436 lm_output_prob = lm_net_out .add_activation_layer (
409- "lm_output_prob" , "lm_output" , activation = "softmax" , target = self .target
437+ "lm_output_prob" , lm_output , activation = "softmax" , target = self .target
438+ )
439+
440+ if self .eos_cond_layer_name :
441+ # so this means that eos prob is only used when the condition is true
442+ lm_output_prob_wo_eos_ = lm_net_out .add_slice_layer (
443+ "lm_output_prob_wo_eos_" , lm_output_prob , axis = "F" , slice_start = 1
444+ ) # [B,V-1]
445+ lm_output_prob_eos = lm_net_out .add_slice_layer (
446+ "lm_output_prob_eos" , lm_output_prob , axis = "F" , slice_start = 0 , slice_end = 1
447+ ) # [B,1]
448+ prob_1_const = lm_net_out .add_eval_layer (
449+ "prob_1_const" , lm_output_prob_eos , eval = "tf.ones_like(source(0))"
450+ ) # convert to ones
451+ lm_output_prob_wo_eos = lm_net_out .add_generic_layer (
452+ "lm_output_prob_wo_eos" ,
453+ cls = "concat" ,
454+ source = [(prob_1_const , "F" ), (lm_output_prob_wo_eos_ , "F" )],
455+ )
456+ lm_output_prob = lm_net_out .add_switch_layer (
457+ "lm_output_prob_cond" ,
458+ condition = self .eos_cond_layer_name ,
459+ true_from = lm_output_prob ,
460+ false_from = lm_output_prob_wo_eos ,
410461 )
411462
412463 fusion_str = "safe_log(source(0)) + {} * safe_log(source(1))" .format (ext_lm_scale ) # shallow fusion
@@ -416,7 +467,6 @@ def _create_external_lm_net(self) -> dict:
             fusion_str = "{} * ".format(self.ext_lm_opts["am_scale"]) + fusion_str  # add am_scale for local fusion

         if self.prior_lm_opts:
-
             if self.dec_type == "lstm":
                 ilm_decoder = LSTMILMDecoder(self.asr_decoder, self.prior_lm_opts)
             elif self.dec_type == "transformer":
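A minimal usage sketch of the extended constructor follows. Everything in it is a placeholder, not part of this commit: the asr_decoder wrapper and ext_lm_opts dict are assumed to come from the surrounding ASR/LM recipe, and the names passed as mask_layer_name and eos_cond_layer_name are hypothetical layer names that must already exist in the decoder network.

# Hypothetical setup; asr_decoder and ext_lm_opts are built elsewhere in the recipe.
ext_lm_decoder = ExternalLMDecoder(
    asr_decoder=asr_decoder,              # placeholder: existing attention decoder wrapper
    ext_lm_opts=ext_lm_opts,              # placeholder: e.g. lm_subnet, lm_scale, load_on_init_opts with "filename"
    beam_size=12,
    dec_type="lstm",
    mask_layer_name="lm_step_mask",       # placeholder: run the LM as a masked computation over this mask
    eos_cond_layer_name="is_last_step",   # placeholder: LM EOS prob is only used when this condition is true
    renorm_wo_eos=False,
)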