Commit be05a35

Qiujia Li authored and copybara-github committed

Add various options for adapters.

* scaling factor
* option for not applying residual connection
* dropout after bottleneck layer
* option for not applying layernorm

PiperOrigin-RevId: 491746421
1 parent 786e981 commit be05a35

File tree: 1 file changed, +29 -12 lines


lingvo/core/layers.py

Lines changed: 29 additions & 12 deletions
@@ -5761,6 +5761,12 @@ def Params(cls):
         'Weight initialization for up and down projections. Only used for '
         'weights, not biases. If None, uses default weight init, which is '
         'typically Xavier with scale of 1.0.')
+    p.Define(
+        'residual_weight', 1.0,
+        'Residual weight (scaling factor) applied to the output of the adapter '
+        'as suggested in https://arxiv.org/pdf/2110.04366.pdf')
+    p.Define('apply_residual', True,
+             'Whether to add the adapter output with inputs.')
     return p
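For context (not part of the diff): a minimal sketch of how the new options might be set on the adapter params. Only residual_weight, apply_residual, and the ability to pass layer_norm_tpl=None come from this commit; the class name layers.MultitaskAdapterLayer and the other fields shown are assumptions based on the surrounding code.

# Hypothetical configuration sketch; fields other than residual_weight,
# apply_residual, and layer_norm_tpl are assumptions, not part of this commit.
from lingvo.core import layers

adapter_p = layers.MultitaskAdapterLayer.Params().Set(
    name='adapter',
    input_dim=512,         # assumed existing field (appears in the diff context)
    residual_weight=0.5,   # new: scales the adapter output before the residual add
    apply_residual=True,   # new: set False to drop the residual connection entirely
    layer_norm_tpl=None)   # new behavior: None skips the pre-adapter layer norm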

@@ -5795,10 +5801,11 @@ def __init__(self, params):
     self.CreateChild('down_proj_b', down_proj_b_params)
     self.CreateChild('up_proj_w', up_proj_w_params)
     self.CreateChild('up_proj_b', up_proj_b_params)
-    params = p.layer_norm_tpl.Copy()
-    params.name = 'adapter_ln'
-    params.input_dim = p.input_dim
-    self.CreateChild('layer_norm', params)
+    if p.layer_norm_tpl is not None:
+      params = p.layer_norm_tpl.Copy()
+      params.name = 'adapter_ln'
+      params.input_dim = p.input_dim
+      self.CreateChild('layer_norm', params)

   def FProp(self, theta, inputs, tasks):
     """Fprop for multitask adapter.
@@ -5873,7 +5880,10 @@ def FProp(self, theta, inputs, tasks):
         self.up_proj_b.EmbLookup(theta.up_proj_b, tasks), time_index)

     # Layer norm -> down-projection -> non-linearity -> up-projection
-    norm_inputs = self.layer_norm.FProp(theta.layer_norm, inputs)
+    if p.layer_norm_tpl is not None:
+      norm_inputs = self.layer_norm.FProp(theta.layer_norm, inputs)
+    else:
+      norm_inputs = inputs
     # If per_timestep_task, t = 1, b = time * batch.
     # Otherwise, t = time, b = batch.
     if p.data_format == 'TBC':
@@ -5886,8 +5896,8 @@ def FProp(self, theta, inputs, tasks):
       up_projected = tf.einsum('tbk,bkh->tbh', down_projected, up_weights)
     else:
       up_projected = tf.einsum('btk,bkh->bth', down_projected, up_weights)
-    up_projected += up_biases
-    output = inputs + up_projected
+    up_projected = (up_projected + up_biases) * p.residual_weight
+    output = inputs + up_projected if p.apply_residual else up_projected

     # Unflatten output:
     # for 'TBC': [1, time * batch, hidden] -> [time, batch, hidden]
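Taken together, the adapter forward pass now amounts to the schematic below: optional layer norm, bottleneck projection with ReLU, the adapter output scaled by residual_weight, and an optional residual add. This is a NumPy sketch of the math under assumed 2-D shapes, not the Lingvo implementation.

import numpy as np

def adapter_forward(inputs, down_w, down_b, up_w, up_b,
                    residual_weight=1.0, apply_residual=True, layer_norm=None):
  # Schematic of the adapter computation after this commit (NumPy, not Lingvo).
  x = layer_norm(inputs) if layer_norm is not None else inputs
  down = np.maximum(x @ down_w + down_b, 0.0)      # down-projection + ReLU bottleneck
  up = (down @ up_w + up_b) * residual_weight      # new: scaled adapter output
  return inputs + up if apply_residual else up     # new: residual add is optional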
@@ -5916,7 +5926,10 @@ def __init__(self, params):
     assert p.data_format == 'BTC'
     params = p.layer_norm_tpl.Copy()
     params.input_dim = p.input_dim
-    self.CreateChild('layer_norm', params)
+    if p.layer_norm_tpl is not None:
+      params = p.layer_norm_tpl.Copy()
+      params.input_dim = p.input_dim
+      self.CreateChild('layer_norm', params)

   def _CreateLayerVariables(self):
     super()._CreateLayerVariables()
@@ -6025,17 +6038,21 @@ def FProp(self, theta, inputs, tasks):
                          theta.up_b)[b_broadcaster]

     # Layer norm -> down-projection -> non-linearity -> up-projection
-    with tf.name_scope('layer_norm_feed'):
-      norm_inputs = self.layer_norm.FProp(theta.layer_norm, inputs)
+    if p.layer_norm_tpl is not None:
+      with tf.name_scope('layer_norm_feed'):
+        norm_inputs = self.layer_norm.FProp(theta.layer_norm, inputs)
+    else:
+      norm_inputs = inputs
     # [batch, time, bottleneck_dim].
     down_projected = tf.einsum(f'bti,b{t}in->btn', norm_inputs, down_w) + down_b
     # ReLU.
     down_projected = tf.nn.relu(down_projected)
     # [batch, time, input_dim].
     up_projected = tf.einsum(f'btn,b{t}ni->bti', down_projected, up_w) + up_b
+    up_projected *= p.residual_weight
     # Residual.
-    res = inputs + up_projected
-    return res
+    output = inputs + up_projected if p.apply_residual else up_projected
+    return output


 class CCTGatingNetwork(quant_utils.QuantizableLayer):
