Skip to content

Commit 4ad7565

Browse files
committed
add correct EOS handling for LM combination
1 parent a5fdeca commit 4ad7565

File tree

1 file changed

+56
-6
lines changed

users/zeineldeen/models/lm/external_lm_decoder.py

Lines changed: 56 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -366,7 +366,18 @@ class ExternalLMDecoder:
366366
Integrates an external LM decoder into an ASR decoder
367367
"""
368368

369-
def __init__(self, asr_decoder, ext_lm_opts, beam_size, dec_type, prior_lm_opts=None, length_normalization=True):
369+
def __init__(
370+
self,
371+
asr_decoder,
372+
ext_lm_opts,
373+
beam_size,
374+
dec_type,
375+
prior_lm_opts=None,
376+
length_normalization=True,
377+
mask_layer_name=None,
378+
eos_cond_layer_name=None,
379+
renorm_wo_eos=False,
380+
):
370381
self.asr_decoder = copy.deepcopy(asr_decoder)
371382
self.am_output_prob = self.asr_decoder.output_prob
372383
self.target = self.asr_decoder.target
@@ -375,6 +386,9 @@ def __init__(self, asr_decoder, ext_lm_opts, beam_size, dec_type, prior_lm_opts=
375386
self.prior_lm_opts = prior_lm_opts
376387
self.dec_type = dec_type
377388
self.length_normalization = length_normalization
389+
self.mask_layer_name = mask_layer_name
390+
self.eos_cond_layer_name = eos_cond_layer_name
391+
self.renorm_wo_eos = renorm_wo_eos
378392

379393
self.network = None
380394

@@ -402,11 +416,48 @@ def _create_external_lm_net(self) -> dict:
402416
), "load_on_init opts or lm_model are missing for loading subnet."
403417
assert "filename" in self.ext_lm_opts["load_on_init_opts"], "Checkpoint missing for loading subnet."
404418
load_on_init = self.ext_lm_opts["load_on_init_opts"]
405-
lm_net_out.add_subnetwork(
406-
"lm_output", "prev:output", subnetwork_net=ext_lm_subnet, load_on_init=load_on_init
407-
)
419+
420+
if self.mask_layer_name:
421+
lm_output = lm_net_out.add_masked_computation_layer(
422+
"lm_output_masked",
423+
"prev:output",
424+
mask=self.mask_layer_name,
425+
unit={
426+
"class": "subnetwork",
427+
"from": "data",
428+
"subnetwork": ext_lm_subnet,
429+
"load_on_init": load_on_init,
430+
},
431+
)
432+
else:
433+
lm_output = lm_net_out.add_subnetwork(
434+
"lm_output", "prev:output", subnetwork_net=ext_lm_subnet, load_on_init=load_on_init
435+
)
408436
lm_output_prob = lm_net_out.add_activation_layer(
409-
"lm_output_prob", "lm_output", activation="softmax", target=self.target
437+
"lm_output_prob", lm_output, activation="softmax", target=self.target
438+
)
439+
440+
if self.eos_cond_layer_name:
441+
# so this means that eos prob is only used when the condition is true
442+
lm_output_prob_wo_eos_ = lm_net_out.add_slice_layer(
443+
"lm_output_prob_wo_eos_", lm_output_prob, axis="F", slice_start=1
444+
) # [B,V-1]
445+
lm_output_prob_eos = lm_net_out.add_slice_layer(
446+
"lm_output_prob_eos", lm_output_prob, axis="F", slice_start=0, slice_end=1
447+
) # [B,1]
448+
prob_1_const = lm_net_out.add_eval_layer(
449+
"prob_1_const", lm_output_prob_eos, eval="tf.ones_like(source(0))"
450+
) # convert to ones
451+
lm_output_prob_wo_eos = lm_net_out.add_generic_layer(
452+
"lm_output_prob_wo_eos",
453+
cls="concat",
454+
source=[(prob_1_const, "F"), (lm_output_prob_wo_eos_, "F")],
455+
)
456+
lm_output_prob = lm_net_out.add_switch_layer(
457+
"lm_output_prob_cond",
458+
condition=self.eos_cond_layer_name,
459+
true_from=lm_output_prob,
460+
false_from=lm_output_prob_wo_eos,
410461
)
411462

412463
fusion_str = "safe_log(source(0)) + {} * safe_log(source(1))".format(ext_lm_scale) # shallow fusion
@@ -416,7 +467,6 @@ def _create_external_lm_net(self) -> dict:
416467
fusion_str = "{} * ".format(self.ext_lm_opts["am_scale"]) + fusion_str # add am_scale for local fusion
417468

418469
if self.prior_lm_opts:
419-
420470
if self.dec_type == "lstm":
421471
ilm_decoder = LSTMILMDecoder(self.asr_decoder, self.prior_lm_opts)
422472
elif self.dec_type == "transformer":

0 commit comments

Comments (0)