Skip to content

Commit dc81229

Browse files
committed
eos renorm + ilm eos handling
1 parent 4ad7565 commit dc81229

File tree

1 file changed

+41
-22
lines changed

1 file changed

+41
-22
lines changed

users/zeineldeen/models/lm/external_lm_decoder.py

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ def __init__(
376376
length_normalization=True,
377377
mask_layer_name=None,
378378
eos_cond_layer_name=None,
379+
handle_eos_for_ilm=False,
379380
renorm_wo_eos=False,
380381
):
381382
self.asr_decoder = copy.deepcopy(asr_decoder)
@@ -388,10 +389,41 @@ def __init__(
388389
self.length_normalization = length_normalization
389390
self.mask_layer_name = mask_layer_name
390391
self.eos_cond_layer_name = eos_cond_layer_name
392+
self.handle_eos_for_ilm = handle_eos_for_ilm
391393
self.renorm_wo_eos = renorm_wo_eos
392394

393395
self.network = None
394396

397+
def _handle_EOS(self, lm_net_out, lm_output_prob, prefix=""):
    """
    Add layers which make the EOS entry of a probability distribution only take
    effect when the layer named by ``self.eos_cond_layer_name`` is true.

    When the condition is false, the entry at feature index 0 (the slice treated
    as EOS here) is replaced by a constant 1 and the remaining entries are kept,
    optionally passed through a softmax if ``self.renorm_wo_eos`` is set.

    :param lm_net_out: network builder the new layers are added to
    :param lm_output_prob: name of the layer holding the probabilities over the feature axis
    :param str prefix: prepended to every layer name created here, so the method can be
        applied to more than one distribution (e.g. external LM and ILM) in the same
        network without layer name clashes
    :return: the result of ``add_switch_layer`` for the conditional probability layer
    """
    # Split the distribution into the non-EOS part and the EOS entry.
    lm_output_prob_wo_eos_ = lm_net_out.add_slice_layer(
        f"{prefix}lm_output_prob_wo_eos_", lm_output_prob, axis="F", slice_start=1
    )  # [B,V-1]
    lm_output_prob_eos = lm_net_out.add_slice_layer(
        f"{prefix}lm_output_prob_eos", lm_output_prob, axis="F", slice_start=0, slice_end=1
    )  # [B,1]

    if self.renorm_wo_eos:
        # The renormalization has to act on the non-EOS entries (a softmax over the single
        # EOS entry would be constant 1), and the layer name must carry the prefix so that
        # a second call with another prefix does not reuse the same layer name.
        # NOTE(review): the input is a probability layer; softmax over probabilities is not
        # the same as dividing by their sum -- confirm this is the intended renormalization.
        lm_output_prob_eos_renorm = lm_net_out.add_activation_layer(
            f"{prefix}lm_output_prob_eos_renorm", lm_output_prob_wo_eos_, activation="softmax"
        )  # [B,V-1]
    else:
        lm_output_prob_eos_renorm = lm_output_prob_wo_eos_  # [B,V-1]

    # Constant 1 in the EOS slot, i.e. the EOS entry carries no score of its own.
    prob_1_const = lm_net_out.add_eval_layer(
        f"{prefix}prob_1_const", lm_output_prob_eos, eval="tf.ones_like(source(0))"
    )  # [B,1]
    lm_output_prob_wo_eos = lm_net_out.add_generic_layer(
        f"{prefix}lm_output_prob_wo_eos",
        cls="concat",
        source=[(prob_1_const, "F"), (lm_output_prob_eos_renorm, "F")],
    )  # [B,V]

    # Use the full distribution only where the EOS condition holds.
    lm_output_prob_cond = lm_net_out.add_switch_layer(
        f"{prefix}lm_output_prob_cond",
        condition=self.eos_cond_layer_name,
        true_from=lm_output_prob,
        false_from=lm_output_prob_wo_eos,
    )
    return lm_output_prob_cond
426+
395427
def _create_external_lm_net(self) -> dict:
396428
lm_net_out = ReturnnNetwork()
397429

@@ -438,27 +470,7 @@ def _create_external_lm_net(self) -> dict:
438470
)
439471

440472
if self.eos_cond_layer_name:
441-
# so this means that eos prob is only used when the condition is true
442-
lm_output_prob_wo_eos_ = lm_net_out.add_slice_layer(
443-
"lm_output_prob_wo_eos_", lm_output_prob, axis="F", slice_start=1
444-
) # [B,V-1]
445-
lm_output_prob_eos = lm_net_out.add_slice_layer(
446-
"lm_output_prob_eos", lm_output_prob, axis="F", slice_start=0, slice_end=1
447-
) # [B,1]
448-
prob_1_const = lm_net_out.add_eval_layer(
449-
"prob_1_const", lm_output_prob_eos, eval="tf.ones_like(source(0))"
450-
) # convert to ones
451-
lm_output_prob_wo_eos = lm_net_out.add_generic_layer(
452-
"lm_output_prob_wo_eos",
453-
cls="concat",
454-
source=[(prob_1_const, "F"), (lm_output_prob_wo_eos_, "F")],
455-
)
456-
lm_output_prob = lm_net_out.add_switch_layer(
457-
"lm_output_prob_cond",
458-
condition=self.eos_cond_layer_name,
459-
true_from=lm_output_prob,
460-
false_from=lm_output_prob_wo_eos,
461-
)
473+
lm_output_prob = self._handle_EOS(lm_net_out, lm_output_prob)
462474

463475
fusion_str = "safe_log(source(0)) + {} * safe_log(source(1))".format(ext_lm_scale) # shallow fusion
464476
fusion_source = [self.am_output_prob, lm_output_prob]
@@ -475,8 +487,15 @@ def _create_external_lm_net(self) -> dict:
475487
raise ValueError("dec type: {} is not valid".format(self.dec_type))
476488

477489
ilm_decoder.create_network() # add ILM
490+
491+
if self.handle_eos_for_ilm:
492+
assert self.eos_cond_layer_name
493+
ilm_output_prob = self._handle_EOS(lm_net_out, ilm_decoder.output_prob_name, prefix="ilm_")
494+
else:
495+
ilm_output_prob = ilm_decoder.output_prob_name
496+
478497
fusion_str += " - {} * safe_log(source(2))".format(self.prior_lm_opts["scale"])
479-
fusion_source += [ilm_decoder.output_prob_name]
498+
fusion_source += [ilm_output_prob]
480499

481500
if self.ext_lm_opts.get("local_norm", False):
482501
fusion_str = f"{fusion_str} - tf.math.reduce_logsumexp({fusion_str}, axis=-1, keepdims=True)"

0 commit comments

Comments
 (0)