Skip to content

Commit 786e981

Browse files
ds-hwangcopybara-github
authored andcommitted
Fix ConformerLayer's adapter_tpl for StreamStep
ConformerLayer.StreamStep doesn't respect p.adapter_tpl, while Fprop respects it. PiperOrigin-RevId: 491695042
1 parent 974f60f commit 786e981

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

lingvo/core/conformer_layer.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,13 +1013,13 @@ def zero_state(self, batch_size):
10131013
else:
10141014
return py_utils.NestedMap()
10151015

1016-
def StreamStep(self, theta, inputs, paddings, state0):
1016+
def StreamStep(self, theta, in_nmap, state0):
10171017
"""Streams t steps.
10181018
10191019
Args:
10201020
theta: A NestedMap of read-only layer params.
1021-
inputs: A tensor of shape [b, t, d].
1022-
paddings: A 0/1 valued tensor of shape [b, t].
1021+
in_nmap: A NestedMap including features tensor of shape [b, t, d], and
1022+
paddings tensor of shape [b, t].
10231023
state0: A NestedMap of tensors of the same struct as returned by
10241024
zero_state().
10251025
@@ -1033,7 +1033,7 @@ def StreamStep(self, theta, inputs, paddings, state0):
10331033
assert not p.remat
10341034

10351035
with tf.name_scope(f'{p.name}/StreamStep'):
1036-
features, aux_loss = inputs, None
1036+
features, paddings, aux_loss = in_nmap.features, in_nmap.paddings, None
10371037

10381038
if self.has_fflayer_start:
10391039
features, paddings, aux_loss = self._MoeOrFFLayer(
@@ -1065,6 +1065,12 @@ def StreamStep(self, theta, inputs, paddings, state0):
10651065
paddings, aux_loss)
10661066
outputs = self.final_ln.FProp(theta.final_ln, features)
10671067

1068+
if p.adapter_tpl:
1069+
adapter_in_map = in_nmap.DeepCopy()
1070+
adapter_in_map.features, adapter_in_map.padding = outputs, paddings
1071+
adapter_out_nmap = self.adapter.FProp(theta.adapter, adapter_in_map)
1072+
outputs, paddings = adapter_out_nmap.features, adapter_out_nmap.paddings
1073+
10681074
state1 = py_utils.NestedMap(
10691075
lconv_state=lconv_state1, atten_state=atten_state1)
10701076
return outputs, paddings, state1

lingvo/core/conformer_layer_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,8 @@ def _FProp(self, layer, inputs, paddings):
626626
py_utils.NestedMap(features=inputs, paddings=paddings))
627627

628628
def _StreamStep(self, layer, step_inputs, step_paddings, state):
629-
return layer.StreamStep(layer.theta, step_inputs, step_paddings, state)
629+
in_nmap = py_utils.NestedMap(features=step_inputs, paddings=step_paddings)
630+
return layer.StreamStep(layer.theta, in_nmap, state)
630631

631632
def _GetFPropOutput(self, fprop_out):
632633
return fprop_out.features, fprop_out.paddings

0 commit comments

Comments
 (0)