
Commit 48e4f56

lingvo-bot authored and copybara-github committed
Fix some pylint warnings
PiperOrigin-RevId: 479674243
1 parent b33f92d commit 48e4f56


1 file changed: +13 -24 lines changed


lingvo/core/layers.py

Lines changed: 13 additions & 24 deletions
@@ -170,7 +170,7 @@ def FPropMeta(cls, p, inputs, *args):
     return py_utils.NestedMap(flops=0, out_shapes=(inputs,))


-# TODO(yonghui/jonathanasdf): Remove the forwarded links.
+# TODO(yonghui, jonathanasdf): Remove the forwarded links.
 _ComputeConvOutputShape = conv_layers_with_time_padding.ComputeConvOutputShape
 _ComputeConvOutputPadding = (
     conv_layers_with_time_padding.ComputeConvOutputPadding)
@@ -939,7 +939,7 @@ def __init__(self, params):
     act_multiplier = activations.DimMultiplier(p.activation)
     self._internal_output_dim = p.output_dim * act_multiplier
     if act_multiplier > 1:
-      assert not p.affine_last, ('Affine last does not support GLU variants.')
+      assert not p.affine_last, 'Affine last does not support GLU variants.'
       assert not p.use_blocked_matmul, (
           'Blocked matmul does not support GLU variants.')
       assert not p.use_block_diagonal_matmul, (
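Note: this hunk (and the matching num_sources hunk further down) removes parentheses around a single-string assert message, the kind of redundant-parentheses style issue pylint reports. A minimal standalone sketch of the before/after pattern, using a plain boolean affine_last as a stand-in for the real layer param:

# affine_last is an illustrative stand-in, not the real params object.
affine_last = False

# Before: parentheses around a one-line message are redundant.
assert not affine_last, ('Affine last does not support GLU variants.')

# After: the bare string is enough for a single-line message.
assert not affine_last, 'Affine last does not support GLU variants.'

# Parentheses still earn their keep when the message spans lines:
assert not affine_last, (
    'Affine last does not support GLU variants; leave this flag unset '
    'when a GLU activation is configured.')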
@@ -995,10 +995,8 @@ def _GetBlockedWeightMatrix(self, w):
     w = tf.slice(w, [0, 0, 0], [p.input_dim, w_om, block_dim])
     return w

-  def _GetBlockDiagonalInitScale(self, num_blocks, dense_shape, dtype=None):
+  def _GetBlockDiagonalInitScale(self, num_blocks, dense_shape):
     m, n = dense_shape
-    if not dtype:
-      dtype = tf.float32
     scale = math.sqrt(6.0 / (m // num_blocks + n // num_blocks))
     return scale

@@ -1017,8 +1015,8 @@ def _CreateLayerVariables(self):
           shape=(p.bd_num_blocks, p.input_dim // p.bd_num_blocks,
                  p.output_dim // p.bd_num_blocks),
           init=py_utils.WeightInit.Xavier(
-              scale=self._GetBlockDiagonalInitScale(
-                  p.bd_num_blocks, (p.input_dim, p.output_dim), dtype=p.dtype)),
+              scale=self._GetBlockDiagonalInitScale(p.bd_num_blocks, (
+                  p.input_dim, p.output_dim))),
           dtype=p.dtype,
           device_mesh=p.device_mesh,
           tensor_split_dims_mapping=p.weight_split_dims_mapping,
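These two hunks drop the unused dtype argument from _GetBlockDiagonalInitScale and tidy its call site; the value itself is a Xavier/Glorot-style uniform limit computed from the per-block fan-in and fan-out. A small self-contained sketch of the same arithmetic (an illustrative function, not the library code):

import math

def block_diagonal_init_scale(num_blocks, dense_shape):
  """Xavier-style uniform limit for one block of a block-diagonal weight.

  Each of the num_blocks blocks has shape
  (input_dim // num_blocks, output_dim // num_blocks), so the fan-in and
  fan-out in sqrt(6 / (fan_in + fan_out)) are the per-block dimensions.
  """
  m, n = dense_shape
  return math.sqrt(6.0 / (m // num_blocks + n // num_blocks))

# Example: a 1024 x 4096 weight split into 4 diagonal blocks of 256 x 1024;
# the per-block limit is sqrt(6 / (256 + 1024)) ≈ 0.068.
print(block_diagonal_init_scale(4, (1024, 4096)))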
@@ -1585,7 +1583,6 @@ def __init__(self, params):
     self.CreateChildren('fc', params_fc_layers)
     self.CreateChildren('dropout', params_dropout_layers)

-    memory_params = None
     if p.memory_augmentation:
       assert p.memory is not None
       memory_params = p.memory.Copy()
@@ -3907,10 +3904,6 @@ def XentLossFromLogits(self,
     assert logits is not None
     per_example_argmax = py_utils.ArgMax(logits)

-    # For compatibility with other softmax implementations.
-    if (class_weights is not None and
-        py_utils.GetRank(class_weights) == py_utils.GetRank(logits)):
-      class_weights = tf.squeeze(class_weights, -1)
     if (class_ids is not None and
         py_utils.GetRank(class_ids) == py_utils.GetRank(logits)):
       class_ids = tf.squeeze(class_ids, -1)
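The guard kept above squeezes a trailing singleton axis off class_ids only when its rank matches that of logits; the comparable class_weights squeeze is dropped by this commit. A toy illustration of the rank check, with invented shapes; the library's py_utils.GetRank also handles dynamic ranks, and the static rank stands in for it here:

import tensorflow as tf

logits = tf.zeros([8, 1000])             # [batch, num_classes], rank 2
class_ids = tf.zeros([8, 1], tf.int32)   # [batch, 1], rank 2 -> needs squeeze

# Mirrors the guarded squeeze: only drop the last axis when class_ids has
# the same rank as logits.
if class_ids.shape.rank == logits.shape.rank:
  class_ids = tf.squeeze(class_ids, -1)

print(class_ids.shape)  # (8,)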
@@ -4641,7 +4634,7 @@ def FPropMeta(cls, p, inputs):
         flops=inputs.num_elements() * 10, out_shapes=(inputs,))


-# TODO(shibow/wangtao) remove this after b/174094694 is done.
+# TODO(shibow, wangtao) remove this after b/174094694 is done.
 class ReshapedLayerNorm(LayerNorm):
   """Customized LayerNorm with model dim D reshaped as Md."""

@@ -5273,7 +5266,7 @@ def __init__(self, params):
     if not p.name:
       raise ValueError('Layer must have a specified name!')

-    assert p.num_sources > 0, ('Must specify num_sources > 0.')
+    assert p.num_sources > 0, 'Must specify num_sources > 0.'

     if p.weighted_merger_dropout_prob > 0.0:
       dropout_tpl = DropoutLayer.Params()
@@ -6076,11 +6069,10 @@ def FProp(self, theta, inputs, paddings=None):
   def FPropMeta(cls, p, inputs, paddings=None):
     py_utils.CheckShapes((inputs,))
     assert inputs[-1] == p.input_dim
-    flops = 0
     in_dim = inputs[-1]
     other_dims = inputs.num_elements() / in_dim
     flops = 5 * other_dims * in_dim * p.hidden_layer_dim
-    flops = 5 * other_dims * p.num_outputs * p.hidden_layer_dim
+    flops += 5 * other_dims * p.num_outputs * p.hidden_layer_dim
     out_shape = tshape.Shape(inputs[:-1] + [symbolic.ToStatic(p.num_outputs)])
     return py_utils.NestedMap(flops=flops, out_shapes=(out_shape,))
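Beyond removing the never-read flops = 0, this hunk fixes a real accounting bug: the second assignment overwrote the first, so only the hidden-to-output term was counted. With +=, the estimate becomes 5 * other_dims * (in_dim + num_outputs) * hidden_layer_dim. A worked example with invented sizes:

# Hypothetical dimensions, for illustration only.
batch, seq_len = 32, 128
in_dim, hidden_layer_dim, num_outputs = 512, 1024, 256

other_dims = batch * seq_len  # all elements except the feature axis

# Old (buggy): the second assignment discarded the first term.
flops_old = 5 * other_dims * num_outputs * hidden_layer_dim

# New: both matmul terms are accumulated.
flops_new = 5 * other_dims * in_dim * hidden_layer_dim
flops_new += 5 * other_dims * num_outputs * hidden_layer_dim

print(flops_old)  # 5368709120
print(flops_new)  # 16106127360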

@@ -6319,24 +6311,21 @@ def FProp(self, theta, inputs):
     bucket_one_hot = tf.one_hot(bucket_ids, num_buckets, axis=-1)
     summed_one_hot = tf.math.reduce_mean(bucket_one_hot, axis=-3)

+    outputs = 0.0
     if p.rank > 0:
       u_mat = tf.einsum('...n,nir->...ir',
                         tf.cast(summed_one_hot, theta.lsh_u_emb.dtype),
                         theta.lsh_u_emb)
+      outputs = tf.einsum('...si,...kir->...skr', inputs, u_mat)
+      outputs = activations.GetFn(p.memory_act)(outputs)
       v_mat = tf.einsum('...n,nro->...ro',
                         tf.cast(summed_one_hot, theta.lsh_v_emb.dtype),
                         theta.lsh_v_emb)
+      outputs = tf.einsum('...skr,...kro->...so', outputs, v_mat)
     if p.add_bias:
       bias = tf.einsum('...n,nd->...d',
                        tf.cast(summed_one_hot, theta.lsh_b_emb.dtype),
                        theta.lsh_b_emb)
-
-    outputs = 0.0
-    if p.rank > 0:
-      outputs = tf.einsum('...si,...kir->...skr', inputs, u_mat)
-      outputs = activations.GetFn(p.memory_act)(outputs)
-      outputs = tf.einsum('...skr,...kro->...so', outputs, v_mat)
-    if p.add_bias:
       outputs += tf.expand_dims(tf.math.reduce_sum(bias, axis=-2), axis=-2)
     return outputs
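This reordering folds the output einsums into the same if p.rank > 0: / if p.add_bias: blocks that build u_mat, v_mat, and bias, removing the duplicated conditionals (and the variables pylint sees as possibly undefined) without changing the math: the input is projected down through a rank-r bottleneck, passed through the memory activation, projected back up, and an optional bias is summed over the hash rounds. A shape-only sketch of that computation with invented dimensions, where tf.nn.relu stands in for activations.GetFn(p.memory_act):

import tensorflow as tf

# Invented sizes: batch b, sequence s, input dim i, output dim o,
# rank r, hash rounds k, buckets n.
b, s, i, o, r, k, n = 2, 5, 8, 8, 3, 4, 16

inputs = tf.random.normal([b, s, i])
summed_one_hot = tf.random.uniform([b, k, n])  # soft bucket weights per round
lsh_u_emb = tf.random.normal([n, i, r])        # per-bucket down-projection
lsh_v_emb = tf.random.normal([n, r, o])        # per-bucket up-projection
lsh_b_emb = tf.random.normal([n, o])           # per-bucket bias

u_mat = tf.einsum('...n,nir->...ir', summed_one_hot, lsh_u_emb)  # [b, k, i, r]
outputs = tf.einsum('...si,...kir->...skr', inputs, u_mat)       # [b, s, k, r]
outputs = tf.nn.relu(outputs)                                    # memory activation
v_mat = tf.einsum('...n,nro->...ro', summed_one_hot, lsh_v_emb)  # [b, k, r, o]
outputs = tf.einsum('...skr,...kro->...so', outputs, v_mat)      # [b, s, o]
bias = tf.einsum('...n,nd->...d', summed_one_hot, lsh_b_emb)     # [b, k, o]
outputs += tf.expand_dims(tf.math.reduce_sum(bias, axis=-2), axis=-2)
print(outputs.shape)  # (2, 5, 8)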

@@ -6544,7 +6533,7 @@ def _Branch1False(cmask):
       index = tf.constant(0)
       time2 = time * time
       seq_mask = tf.TensorArray(dtype=tf.float32, size=time2)
-      index, seq_mask = tf.while_loop(
+      _, seq_mask = tf.while_loop(
          lambda idx, mask: idx < time2,
          ComputeSequenceMaskElement,
          loop_vars=[index, seq_mask])
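The final hunk binds the unused loop counter returned by tf.while_loop to _, the usual fix for an unused-variable warning, since only the TensorArray result is read afterwards. A self-contained sketch of the same pattern; the body here simply writes squares into the array, and the names are illustrative:

import tensorflow as tf

time2 = 6

def compute_element(idx, mask):
  # One write per iteration; only `mask` is needed after the loop.
  mask = mask.write(idx, tf.cast(idx * idx, tf.float32))
  return idx + 1, mask

index = tf.constant(0)
seq_mask = tf.TensorArray(dtype=tf.float32, size=time2)

# The final counter value is unused, so it is bound to _.
_, seq_mask = tf.while_loop(
    lambda idx, mask: idx < time2,
    compute_element,
    loop_vars=[index, seq_mask])

print(seq_mask.stack().numpy())  # [ 0.  1.  4.  9. 16. 25.]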
