Skip to content

Commit 1f1735b

Browse files
lingvo-botcopybara-github
authored andcommitted
Use tf.tensordot for projecting the last dimension
PiperOrigin-RevId: 489078820
1 parent cb46259 commit 1f1735b

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

lingvo/core/py_utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5290,11 +5290,7 @@ def ProjectLastDim(inputs, weight, input_dim, output_dim, use_einsum=True):
52905290
outputs = tf.einsum('{0}y,yz->{0}z'.format(s[:r - 1]), inputs, weight)
52915291
else:
52925292
# not use_einsum or not use_tpu() or inputs.shape.rank >= 26
5293-
outputs = Matmul(tf.reshape(inputs, [-1, input_dim]), weight)
5294-
outputs = tf.reshape(
5295-
outputs,
5296-
tf.concat([tf.cast(GetShape(inputs)[:-1], tf.int32), [output_dim]],
5297-
axis=0))
5293+
outputs = tf.tensordot(inputs, weight, 1)
52985294

52995295
return outputs
53005296

0 commit comments

Comments
 (0)