Skip to content

Commit 9a940e7

Browse files
yeqingli and tensorflower-gardener
authored and committed
No public description
PiperOrigin-RevId: 597592696
1 parent 5c0617a commit 9a940e7

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

official/nlp/modeling/layers/reuse_transformer_test.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -331,17 +331,14 @@ def test_several_attention_axes(self, attention_axes):
331331
self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
332332

333333
@parameterized.named_parameters(
334-
('plain', False, False, False),
335-
('plain_returnscore', False, True, False),
336-
('plain_with_relative_pe', False, False, True),
337-
('reuse_all', True, False, False),
338-
('reuse_all_returnscore', True, True, False),
339-
('reuse_all_with_relative_pe', True, False, True),
340-
('reuse_5', 5, False, False),
341-
('reuse_5_returnscore', 5, True, False),
342-
('reuse_5_with_relative_pe', 5, False, True),)
343-
def test_layer_invocation_with_mask(self, reuse_attention,
344-
return_attention_scores, use_relative_pe):
334+
('plain_returnscore', False, False),
335+
('plain_with_relative_pe', False, True),
336+
('reuse_all_returnscore', True, False),
337+
('reuse_all_with_relative_pe', True, True),
338+
('reuse_5_returnscore', 5, False),
339+
('reuse_5_with_relative_pe', 5, True),
340+
)
341+
def test_layer_invocation_with_mask(self, reuse_attention, use_relative_pe):
345342
test_layer = reuse_transformer.ReuseTransformer(
346343
num_attention_heads=10,
347344
inner_dim=2048,
@@ -354,16 +351,20 @@ def test_layer_invocation_with_mask(self, reuse_attention,
354351
data_tensor = tf_keras.Input(shape=(sequence_length, width))
355352
# Create a 2-dimensional input (the first dimension is implicit).
356353
mask_tensor = tf_keras.Input(shape=(sequence_length, sequence_length))
357-
return_scores_tensor = tf_keras.Input(shape=(1,))
358354
reuse_attention_scores = tf_keras.Input(
359355
shape=(10, sequence_length, sequence_length))
360356
output_tensor, _ = test_layer(
361357
[data_tensor, mask_tensor, reuse_attention_scores])
362358

363359
# Create a model from the test layer.
364360
model = tf_keras.Model(
365-
([data_tensor, mask_tensor, reuse_attention_scores],
366-
return_scores_tensor), output_tensor)
361+
[
362+
data_tensor,
363+
mask_tensor,
364+
reuse_attention_scores,
365+
],
366+
output_tensor,
367+
)
367368

368369
# Invoke the model on test data. We can't validate the output data itself
369370
# (the NN is too complex) but this will rule out structural runtime errors.
@@ -376,8 +377,7 @@ def test_layer_invocation_with_mask(self, reuse_attention,
376377
2, size=(batch_size, sequence_length, sequence_length))
377378
reuse_scores = np.random.rand(
378379
batch_size, 10, sequence_length, sequence_length)
379-
_ = model.predict([input_data, mask_data, reuse_scores],
380-
return_attention_scores)
380+
_ = model.predict([input_data, mask_data, reuse_scores])
381381

382382
@parameterized.named_parameters(
383383
('without_relative_pe_with_pe_max_seq_length_10', False, 10),

0 commit comments

Comments (0)