Skip to content

Commit 33bd750

Browse files
ds-hwangcopybara-github
authored andcommitted
Clean up ConformerLayer.adapter_tpl.
Remove unnecessary restrictions. PiperOrigin-RevId: 492467060
1 parent 969916d commit 33bd750

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

lingvo/core/conformer_layer.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -512,13 +512,10 @@ def Params(cls):
512512
'If True, will ignore `fflayer_end_tpl`, and will make the fflayer_end '
513513
'layer as a weight-shared copy of the fflayer_start layer.')
514514
p.Define('final_ln_tpl', layers.LayerNorm.Params(), 'Final layer norm.')
515-
# If adapter_tpl is set, layer_out = adapter(conformer(layer_in))
516-
# The adapter must
517-
# 1. have instance method FProp(self, theta, in_nmap) -> out_nmap, where
518-
# {in,out}_nmap must have 'features' and 'paddings' and $adapter_p.task_ids
519-
# fields.
520-
# 2. have class method SetInputDim(cls, p, input_dim)
521-
p.Define('adapter_tpl', None, 'If set, runs an adapter layer in the end.')
515+
p.Define(
516+
'adapter_tpl', None, 'If set, runs an adapter layer in the end. '
517+
'This is expected to be an OnlineLayer, whose FProp signature is '
518+
'FProp(self, theta, in_nmap, state0=None) -> (out_nmap, state1).')
522519
# https://b/167460492#comment16
523520
p.Define(
524521
'remat', False, 'If to rematerialize the layer. If true, '
@@ -831,7 +828,7 @@ def __init__(self, params):
831828
self.CreateChild('final_ln', ln_p)
832829

833830
if p.adapter_tpl:
834-
p.adapter_tpl.cls.SetInputDim(p.adapter_tpl, p.input_dim)
831+
p.adapter_tpl.cls.SetNumInputNodes(p.adapter_tpl, p.input_dim)
835832
self.CreateChild('adapter', p.adapter_tpl)
836833

837834
# lconv and fflayer_start have the special treatment, which can be absent,
@@ -968,7 +965,7 @@ def _FProp(self, theta, in_nmap):
968965
if p.adapter_tpl:
969966
adapter_in_map = in_nmap.DeepCopy()
970967
adapter_in_map.features, adapter_in_map.padding = features, paddings
971-
adapter_out_nmap = self.adapter.FProp(theta.adapter, adapter_in_map)
968+
adapter_out_nmap, _ = self.adapter.FProp(theta.adapter, adapter_in_map)
972969
features, paddings = adapter_out_nmap.features, adapter_out_nmap.paddings
973970

974971
features, paddings = self._CastToFPropDtype((features, paddings))
@@ -1068,7 +1065,7 @@ def StreamStep(self, theta, in_nmap, state0):
10681065
if p.adapter_tpl:
10691066
adapter_in_map = in_nmap.DeepCopy()
10701067
adapter_in_map.features, adapter_in_map.padding = outputs, paddings
1071-
adapter_out_nmap = self.adapter.FProp(theta.adapter, adapter_in_map)
1068+
adapter_out_nmap, _ = self.adapter.FProp(theta.adapter, adapter_in_map)
10721069
outputs, paddings = adapter_out_nmap.features, adapter_out_nmap.paddings
10731070

10741071
state1 = py_utils.NestedMap(

0 commit comments

Comments
 (0)