@@ -170,7 +170,7 @@ def FPropMeta(cls, p, inputs, *args):
     return py_utils.NestedMap(flops=0, out_shapes=(inputs,))
 
 
-# TODO(yonghui/jonathanasdf): Remove the forwarded links.
+# TODO(yonghui, jonathanasdf): Remove the forwarded links.
 _ComputeConvOutputShape = conv_layers_with_time_padding.ComputeConvOutputShape
 _ComputeConvOutputPadding = (
     conv_layers_with_time_padding.ComputeConvOutputPadding)
@@ -939,7 +939,7 @@ def __init__(self, params):
     act_multiplier = activations.DimMultiplier(p.activation)
     self._internal_output_dim = p.output_dim * act_multiplier
     if act_multiplier > 1:
-      assert not p.affine_last, ('Affine last does not support GLU variants.')
+      assert not p.affine_last, 'Affine last does not support GLU variants.'
       assert not p.use_blocked_matmul, (
           'Blocked matmul does not support GLU variants.')
       assert not p.use_block_diagonal_matmul, (
@@ -995,10 +995,8 @@ def _GetBlockedWeightMatrix(self, w):
       w = tf.slice(w, [0, 0, 0], [p.input_dim, w_om, block_dim])
     return w
 
-  def _GetBlockDiagonalInitScale(self, num_blocks, dense_shape, dtype=None):
+  def _GetBlockDiagonalInitScale(self, num_blocks, dense_shape):
     m, n = dense_shape
-    if not dtype:
-      dtype = tf.float32
     scale = math.sqrt(6.0 / (m // num_blocks + n // num_blocks))
     return scale
 
@@ -1017,8 +1015,8 @@ def _CreateLayerVariables(self):
           shape=(p.bd_num_blocks, p.input_dim // p.bd_num_blocks,
                  p.output_dim // p.bd_num_blocks),
           init=py_utils.WeightInit.Xavier(
-              scale=self._GetBlockDiagonalInitScale(
-                  p.bd_num_blocks, (p.input_dim, p.output_dim), dtype=p.dtype)),
+              scale=self._GetBlockDiagonalInitScale(p.bd_num_blocks, (
+                  p.input_dim, p.output_dim))),
           dtype=p.dtype,
           device_mesh=p.device_mesh,
           tensor_split_dims_mapping=p.weight_split_dims_mapping,
@@ -1585,7 +1583,6 @@ def __init__(self, params):
     self.CreateChildren('fc', params_fc_layers)
     self.CreateChildren('dropout', params_dropout_layers)
 
-    memory_params = None
     if p.memory_augmentation:
       assert p.memory is not None
       memory_params = p.memory.Copy()
@@ -3907,10 +3904,6 @@ def XentLossFromLogits(self,
     assert logits is not None
     per_example_argmax = py_utils.ArgMax(logits)
 
-    # For compatibility with other softmax implementations.
-    if (class_weights is not None and
-        py_utils.GetRank(class_weights) == py_utils.GetRank(logits)):
-      class_weights = tf.squeeze(class_weights, -1)
     if (class_ids is not None and
         py_utils.GetRank(class_ids) == py_utils.GetRank(logits)):
       class_ids = tf.squeeze(class_ids, -1)
@@ -4641,7 +4634,7 @@ def FPropMeta(cls, p, inputs):
         flops=inputs.num_elements() * 10, out_shapes=(inputs,))
 
 
-# TODO(shibow/wangtao) remove this after b/174094694 is done.
+# TODO(shibow, wangtao) remove this after b/174094694 is done.
 class ReshapedLayerNorm(LayerNorm):
   """Customized LayerNorm with model dim D reshaped as Md."""
 
@@ -5273,7 +5266,7 @@ def __init__(self, params):
     if not p.name:
       raise ValueError('Layer must have a specified name!')
 
-    assert p.num_sources > 0, ('Must specify num_sources > 0.')
+    assert p.num_sources > 0, 'Must specify num_sources > 0.'
 
     if p.weighted_merger_dropout_prob > 0.0:
       dropout_tpl = DropoutLayer.Params()
@@ -6076,11 +6069,10 @@ def FProp(self, theta, inputs, paddings=None):
   def FPropMeta(cls, p, inputs, paddings=None):
     py_utils.CheckShapes((inputs,))
     assert inputs[-1] == p.input_dim
-    flops = 0
     in_dim = inputs[-1]
     other_dims = inputs.num_elements() / in_dim
     flops = 5 * other_dims * in_dim * p.hidden_layer_dim
-    flops = 5 * other_dims * p.num_outputs * p.hidden_layer_dim
+    flops += 5 * other_dims * p.num_outputs * p.hidden_layer_dim
     out_shape = tshape.Shape(inputs[:-1] + [symbolic.ToStatic(p.num_outputs)])
     return py_utils.NestedMap(flops=flops, out_shapes=(out_shape,))
 
@@ -6319,24 +6311,21 @@ def FProp(self, theta, inputs):
       bucket_one_hot = tf.one_hot(bucket_ids, num_buckets, axis=-1)
       summed_one_hot = tf.math.reduce_mean(bucket_one_hot, axis=-3)
 
+    outputs = 0.0
     if p.rank > 0:
       u_mat = tf.einsum('...n,nir->...ir',
                         tf.cast(summed_one_hot, theta.lsh_u_emb.dtype),
                         theta.lsh_u_emb)
+      outputs = tf.einsum('...si,...kir->...skr', inputs, u_mat)
+      outputs = activations.GetFn(p.memory_act)(outputs)
       v_mat = tf.einsum('...n,nro->...ro',
                         tf.cast(summed_one_hot, theta.lsh_v_emb.dtype),
                         theta.lsh_v_emb)
+      outputs = tf.einsum('...skr,...kro->...so', outputs, v_mat)
     if p.add_bias:
       bias = tf.einsum('...n,nd->...d',
                        tf.cast(summed_one_hot, theta.lsh_b_emb.dtype),
                        theta.lsh_b_emb)
-
-    outputs = 0.0
-    if p.rank > 0:
-      outputs = tf.einsum('...si,...kir->...skr', inputs, u_mat)
-      outputs = activations.GetFn(p.memory_act)(outputs)
-      outputs = tf.einsum('...skr,...kro->...so', outputs, v_mat)
-    if p.add_bias:
       outputs += tf.expand_dims(tf.math.reduce_sum(bias, axis=-2), axis=-2)
     return outputs
 
@@ -6544,7 +6533,7 @@ def _Branch1False(cmask):
     index = tf.constant(0)
     time2 = time * time
     seq_mask = tf.TensorArray(dtype=tf.float32, size=time2)
-    index, seq_mask = tf.while_loop(
+    _, seq_mask = tf.while_loop(
         lambda idx, mask: idx < time2,
         ComputeSequenceMaskElement,
         loop_vars=[index, seq_mask])