@@ -331,17 +331,14 @@ def test_several_attention_axes(self, attention_axes):
     self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())
 
   @parameterized.named_parameters(
-      ('plain', False, False, False),
-      ('plain_returnscore', False, True, False),
-      ('plain_with_relative_pe', False, False, True),
-      ('reuse_all', True, False, False),
-      ('reuse_all_returnscore', True, True, False),
-      ('reuse_all_with_relative_pe', True, False, True),
-      ('reuse_5', 5, False, False),
-      ('reuse_5_returnscore', 5, True, False),
-      ('reuse_5_with_relative_pe', 5, False, True),)
-  def test_layer_invocation_with_mask(self, reuse_attention,
-                                      return_attention_scores, use_relative_pe):
+      ('plain_returnscore', False, False),
+      ('plain_with_relative_pe', False, True),
+      ('reuse_all_returnscore', True, False),
+      ('reuse_all_with_relative_pe', True, True),
+      ('reuse_5_returnscore', 5, False),
+      ('reuse_5_with_relative_pe', 5, True),
+  )
+  def test_layer_invocation_with_mask(self, reuse_attention, use_relative_pe):
     test_layer = reuse_transformer.ReuseTransformer(
         num_attention_heads=10,
         inner_dim=2048,
@@ -354,16 +351,20 @@ def test_layer_invocation_with_mask(self, reuse_attention,
     data_tensor = tf_keras.Input(shape=(sequence_length, width))
     # Create a 2-dimensional input (the first dimension is implicit).
     mask_tensor = tf_keras.Input(shape=(sequence_length, sequence_length))
-    return_scores_tensor = tf_keras.Input(shape=(1,))
     reuse_attention_scores = tf_keras.Input(
         shape=(10, sequence_length, sequence_length))
     output_tensor, _ = test_layer(
         [data_tensor, mask_tensor, reuse_attention_scores])
 
     # Create a model from the test layer.
     model = tf_keras.Model(
-        ([data_tensor, mask_tensor, reuse_attention_scores],
-         return_scores_tensor), output_tensor)
+        [
+            data_tensor,
+            mask_tensor,
+            reuse_attention_scores,
+        ],
+        output_tensor,
+    )
 
     # Invoke the model on test data. We can't validate the output data itself
     # (the NN is too complex) but this will rule out structural runtime errors.
@@ -376,8 +377,7 @@ def test_layer_invocation_with_mask(self, reuse_attention,
         2, size=(batch_size, sequence_length, sequence_length))
     reuse_scores = np.random.rand(
         batch_size, 10, sequence_length, sequence_length)
-    _ = model.predict([input_data, mask_data, reuse_scores],
-                      return_attention_scores)
+    _ = model.predict([input_data, mask_data, reuse_scores])
 
   @parameterized.named_parameters(
       ('without_relative_pe_with_pe_max_seq_length_10', False, 10),