@@ -5761,6 +5761,12 @@ def Params(cls):
57615761 'Weight initialization for up and down projections. Only used for '
57625762 'weights, not biases. If None, uses default weight init, which is '
57635763 'typically Xavier with scale of 1.0.' )
5764+ p .Define (
5765+ 'residual_weight' , 1.0 ,
5766+ 'Residual weight (scaling factor) applied to the output of the adapter '
5767+ 'as suggested in https://arxiv.org/pdf/2110.04366.pdf' )
5768+ p .Define ('apply_residual' , True ,
5769+ 'Whether to add the adapter output with inputs.' )
57645770 return p
57655771
57665772
@@ -5795,10 +5801,11 @@ def __init__(self, params):
57955801 self .CreateChild ('down_proj_b' , down_proj_b_params )
57965802 self .CreateChild ('up_proj_w' , up_proj_w_params )
57975803 self .CreateChild ('up_proj_b' , up_proj_b_params )
5798- params = p .layer_norm_tpl .Copy ()
5799- params .name = 'adapter_ln'
5800- params .input_dim = p .input_dim
5801- self .CreateChild ('layer_norm' , params )
5804+ if p .layer_norm_tpl is not None :
5805+ params = p .layer_norm_tpl .Copy ()
5806+ params .name = 'adapter_ln'
5807+ params .input_dim = p .input_dim
5808+ self .CreateChild ('layer_norm' , params )
58025809
58035810 def FProp (self , theta , inputs , tasks ):
58045811 """Fprop for multitask adapter.
@@ -5873,7 +5880,10 @@ def FProp(self, theta, inputs, tasks):
58735880 self .up_proj_b .EmbLookup (theta .up_proj_b , tasks ), time_index )
58745881
58755882 # Layer norm -> down-projection -> non-linearity -> up-projection
5876- norm_inputs = self .layer_norm .FProp (theta .layer_norm , inputs )
5883+ if p .layer_norm_tpl is not None :
5884+ norm_inputs = self .layer_norm .FProp (theta .layer_norm , inputs )
5885+ else :
5886+ norm_inputs = inputs
58775887 # If per_timestep_task, t = 1, b = time * batch.
58785888 # Otherwise, t = time, b = batch.
58795889 if p .data_format == 'TBC' :
@@ -5886,8 +5896,8 @@ def FProp(self, theta, inputs, tasks):
58865896 up_projected = tf .einsum ('tbk,bkh->tbh' , down_projected , up_weights )
58875897 else :
58885898 up_projected = tf .einsum ('btk,bkh->bth' , down_projected , up_weights )
5889- up_projected += up_biases
5890- output = inputs + up_projected
5899+ up_projected = ( up_projected + up_biases ) * p . residual_weight
5900+ output = inputs + up_projected if p . apply_residual else up_projected
58915901
58925902 # Unflatten output:
58935903 # for 'TBC': [1, time * batch, hidden] -> [time, batch, hidden]
@@ -5916,7 +5926,8 @@ def __init__(self, params):
59165926 assert p .data_format == 'BTC'
5917- params = p .layer_norm_tpl .Copy ()
5918- params .input_dim = p .input_dim
5919- self .CreateChild ('layer_norm' , params )
5927+ if p .layer_norm_tpl is not None :
5928+ params = p .layer_norm_tpl .Copy ()
5929+ params .input_dim = p .input_dim
5930+ self .CreateChild ('layer_norm' , params )
59205931
59215932 def _CreateLayerVariables (self ):
59225933 super ()._CreateLayerVariables ()
@@ -6025,17 +6036,21 @@ def FProp(self, theta, inputs, tasks):
60256036 theta .up_b )[b_broadcaster ]
60266037
60276038 # Layer norm -> down-projection -> non-linearity -> up-projection
6028- with tf .name_scope ('layer_norm_feed' ):
6029- norm_inputs = self .layer_norm .FProp (theta .layer_norm , inputs )
6039+ if p .layer_norm_tpl is not None :
6040+ with tf .name_scope ('layer_norm_feed' ):
6041+ norm_inputs = self .layer_norm .FProp (theta .layer_norm , inputs )
6042+ else :
6043+ norm_inputs = inputs
60306044 # [batch, time, bottleneck_dim].
60316045 down_projected = tf .einsum (f'bti,b{ t } in->btn' , norm_inputs , down_w ) + down_b
60326046 # ReLU.
60336047 down_projected = tf .nn .relu (down_projected )
60346048 # [batch, time, input_dim].
60356049 up_projected = tf .einsum (f'btn,b{ t } ni->bti' , down_projected , up_w ) + up_b
6050+ up_projected *= p .residual_weight
60366051 # Residual.
6037- res = inputs + up_projected
6038- return res
6052+ output = inputs + up_projected if p . apply_residual else up_projected
6053+ return output
60396054
60406055
60416056class CCTGatingNetwork (quant_utils .QuantizableLayer ):
0 commit comments