@@ -376,6 +376,7 @@ def __init__(
376376 length_normalization = True ,
377377 mask_layer_name = None ,
378378 eos_cond_layer_name = None ,
379+ handle_eos_for_ilm = False ,
379380 renorm_wo_eos = False ,
380381 ):
381382 self .asr_decoder = copy .deepcopy (asr_decoder )
@@ -388,10 +389,41 @@ def __init__(
388389 self .length_normalization = length_normalization
389390 self .mask_layer_name = mask_layer_name
390391 self .eos_cond_layer_name = eos_cond_layer_name
392+ self .handle_eos_for_ilm = handle_eos_for_ilm
391393 self .renorm_wo_eos = renorm_wo_eos
392394
393395 self .network = None
394396
def _handle_EOS(self, lm_net_out, lm_output_prob, prefix=""):
    """Add layers that make the EOS entry of an LM output conditional.

    The output of ``lm_output_prob`` is split along the feature axis into its
    first entry (treated as EOS, following the slicing below) and the
    remaining entries. A second distribution is built in which the EOS entry
    is replaced by a constant 1, and a switch layer selects between the
    untouched output and that masked variant.

    :param lm_net_out: network builder the new layers are added to; must offer
        ``add_slice_layer``, ``add_activation_layer``, ``add_eval_layer``,
        ``add_generic_layer`` and ``add_switch_layer``.
    :param lm_output_prob: layer holding the (I)LM output distribution.
    :param prefix: string prepended to every layer name created here, so the
        helper can be applied to several models inside the same network
        without name clashes.
    :return: the value returned by ``add_switch_layer`` for the conditional
        output layer.
    """
    # Everything except feature index 0, i.e. the non-EOS entries.
    wo_eos_slice = lm_net_out.add_slice_layer(
        f"{prefix}lm_output_prob_wo_eos_", lm_output_prob, axis="F", slice_start=1
    )  # [B,V-1]
    # Feature index 0 only, i.e. the EOS entry.
    eos_slice = lm_net_out.add_slice_layer(
        f"{prefix}lm_output_prob_eos", lm_output_prob, axis="F", slice_start=0, slice_end=1
    )  # [B,1]

    if self.renorm_wo_eos:
        # Renormalize the non-EOS entries (not the EOS slice), and keep the
        # layer name prefixed so LM and ILM instances do not collide.
        # NOTE(review): softmax is applied to values that are named as
        # probabilities; if they are already normalized probabilities, a
        # division by their sum may be what is intended - confirm.
        wo_eos_part = lm_net_out.add_activation_layer(
            f"{prefix}lm_output_prob_wo_eos_renorm", wo_eos_slice, activation="softmax"
        )  # [B,V-1]
    else:
        wo_eos_part = wo_eos_slice  # [B,V-1]

    # Constant 1 with the shape of the EOS slice; used in place of the EOS
    # probability when the condition is false.
    ones_for_eos = lm_net_out.add_eval_layer(
        f"{prefix}prob_1_const", eos_slice, eval="tf.ones_like(source(0))"
    )
    # Reassemble along the feature axis: [1, non-EOS entries].
    masked_prob = lm_net_out.add_generic_layer(
        f"{prefix}lm_output_prob_wo_eos",
        cls="concat",
        source=[(ones_for_eos, "F"), (wo_eos_part, "F")],
    )
    # The real EOS probability is only used where the condition is true.
    return lm_net_out.add_switch_layer(
        f"{prefix}lm_output_prob_cond",
        condition=self.eos_cond_layer_name,
        true_from=lm_output_prob,
        false_from=masked_prob,
    )
426+
395427 def _create_external_lm_net (self ) -> dict :
396428 lm_net_out = ReturnnNetwork ()
397429
@@ -438,27 +470,7 @@ def _create_external_lm_net(self) -> dict:
438470 )
439471
440472 if self .eos_cond_layer_name :
441- # so this means that eos prob is only used when the condition is true
442- lm_output_prob_wo_eos_ = lm_net_out .add_slice_layer (
443- "lm_output_prob_wo_eos_" , lm_output_prob , axis = "F" , slice_start = 1
444- ) # [B,V-1]
445- lm_output_prob_eos = lm_net_out .add_slice_layer (
446- "lm_output_prob_eos" , lm_output_prob , axis = "F" , slice_start = 0 , slice_end = 1
447- ) # [B,1]
448- prob_1_const = lm_net_out .add_eval_layer (
449- "prob_1_const" , lm_output_prob_eos , eval = "tf.ones_like(source(0))"
450- ) # convert to ones
451- lm_output_prob_wo_eos = lm_net_out .add_generic_layer (
452- "lm_output_prob_wo_eos" ,
453- cls = "concat" ,
454- source = [(prob_1_const , "F" ), (lm_output_prob_wo_eos_ , "F" )],
455- )
456- lm_output_prob = lm_net_out .add_switch_layer (
457- "lm_output_prob_cond" ,
458- condition = self .eos_cond_layer_name ,
459- true_from = lm_output_prob ,
460- false_from = lm_output_prob_wo_eos ,
461- )
473+ lm_output_prob = self ._handle_EOS (lm_net_out , lm_output_prob )
462474
463475 fusion_str = "safe_log(source(0)) + {} * safe_log(source(1))" .format (ext_lm_scale ) # shallow fusion
464476 fusion_source = [self .am_output_prob , lm_output_prob ]
@@ -475,8 +487,15 @@ def _create_external_lm_net(self) -> dict:
475487 raise ValueError ("dec type: {} is not valid" .format (self .dec_type ))
476488
477489 ilm_decoder .create_network () # add ILM
490+
491+ if self .handle_eos_for_ilm :
492+ assert self .eos_cond_layer_name
493+ ilm_output_prob = self ._handle_EOS (lm_net_out , ilm_decoder .output_prob_name , prefix = "ilm_" )
494+ else :
495+ ilm_output_prob = ilm_decoder .output_prob_name
496+
478497 fusion_str += " - {} * safe_log(source(2))" .format (self .prior_lm_opts ["scale" ])
479- fusion_source += [ilm_decoder . output_prob_name ]
498+ fusion_source += [ilm_output_prob ]
480499
481500 if self .ext_lm_opts .get ("local_norm" , False ):
482501 fusion_str = f"{ fusion_str } - tf.math.reduce_logsumexp({ fusion_str } , axis=-1, keepdims=True)"
0 commit comments