Skip to content

Commit 3d62e89

Browse files
Internal change
PiperOrigin-RevId: 381540013
1 parent c8a9178 commit 3d62e89

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

official/nlp/keras_nlp/layers/position_embedding_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,12 @@ def test_non_default_axis_static(self):
4848
test_layer = position_embedding.PositionEmbedding(
4949
max_length=sequence_length, seq_axis=2)
5050
width = 30
51-
input_tensor = tf.keras.Input(shape=(sequence_length, width, width))
51+
input_tensor = tf.keras.Input(shape=(width, sequence_length, width))
5252
output_tensor = test_layer(input_tensor)
5353

5454
# When using static positional embedding shapes, the output is expected
5555
# to be the same as the input shape in all dimensions save batch.
56-
expected_output_shape = [None, sequence_length, width, width]
56+
expected_output_shape = [None, width, sequence_length, width]
5757
self.assertEqual(expected_output_shape, output_tensor.shape.as_list())
5858
# The default output dtype for this layer should be tf.float32.
5959
self.assertEqual(tf.float32, output_tensor.dtype)

0 commit comments

Comments
 (0)