20
20
import tensorflow as tf
21
21
22
22
from tensorflow .python .keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
23
- from official .nlp .modeling .networks import bert_dense_encoder
23
+ from official .nlp .modeling .networks import bert_encoder
24
24
25
25
26
26
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
27
27
# guarantees forward compatibility of this code for the V2 switchover.
28
28
@keras_parameterized .run_all_keras_modes
29
- class BertDenseEncoderTest (keras_parameterized .TestCase ):
29
+ class BertEncoderV2Test (keras_parameterized .TestCase ):
30
30
31
31
def tearDown (self ):
32
- super (BertDenseEncoderTest , self ).tearDown ()
32
+ super (BertEncoderV2Test , self ).tearDown ()
33
33
tf .keras .mixed_precision .set_global_policy ("float32" )
34
34
35
35
def test_dict_outputs_network_creation (self ):
36
36
hidden_size = 32
37
37
sequence_length = 21
38
38
dense_sequence_length = 20
39
- # Create a small dense BertDenseEncoder for testing.
39
+ # Create a small dense BertEncoderV2 for testing.
40
40
kwargs = {}
41
- test_network = bert_dense_encoder . BertDenseEncoder (
41
+ test_network = bert_encoder . BertEncoderV2 (
42
42
vocab_size = 100 ,
43
43
hidden_size = hidden_size ,
44
44
num_attention_heads = 2 ,
45
45
num_layers = 3 ,
46
+ with_dense_inputs = True ,
46
47
** kwargs )
47
48
# Create the inputs (note that the first dimension is implicit).
48
49
word_ids = tf .keras .Input (shape = (sequence_length ,), dtype = tf .int32 )
@@ -86,12 +87,13 @@ def test_dict_outputs_all_encoder_outputs_network_creation(self):
86
87
sequence_length = 21
87
88
dense_sequence_length = 20
88
89
# Create a small BertEncoder for testing.
89
- test_network = bert_dense_encoder . BertDenseEncoder (
90
+ test_network = bert_encoder . BertEncoderV2 (
90
91
vocab_size = 100 ,
91
92
hidden_size = hidden_size ,
92
93
num_attention_heads = 2 ,
93
94
num_layers = 3 ,
94
- dict_outputs = True )
95
+ dict_outputs = True ,
96
+ with_dense_inputs = True )
95
97
# Create the inputs (note that the first dimension is implicit).
96
98
word_ids = tf .keras .Input (shape = (sequence_length ,), dtype = tf .int32 )
97
99
mask = tf .keras .Input (shape = (sequence_length ,), dtype = tf .int32 )
@@ -134,12 +136,13 @@ def test_dict_outputs_network_creation_with_float16_dtype(self):
134
136
dense_sequence_length = 20
135
137
tf .keras .mixed_precision .set_global_policy ("mixed_float16" )
136
138
# Create a small BertEncoder for testing.
137
- test_network = bert_dense_encoder . BertDenseEncoder (
139
+ test_network = bert_encoder . BertEncoderV2 (
138
140
vocab_size = 100 ,
139
141
hidden_size = hidden_size ,
140
142
num_attention_heads = 2 ,
141
143
num_layers = 3 ,
142
- dict_outputs = True )
144
+ dict_outputs = True ,
145
+ with_dense_inputs = True )
143
146
# Create the inputs (note that the first dimension is implicit).
144
147
word_ids = tf .keras .Input (shape = (sequence_length ,), dtype = tf .int32 )
145
148
mask = tf .keras .Input (shape = (sequence_length ,), dtype = tf .int32 )
@@ -176,9 +179,8 @@ def test_dict_outputs_network_creation_with_float16_dtype(self):
176
179
self .assertAllEqual (tf .float16 , pooled .dtype )
177
180
178
181
@parameterized .named_parameters (
179
- ("all_sequence_encoder_v2" , bert_dense_encoder .BertDenseEncoder , None ,
180
- 41 ),
181
- ("output_range_encoder_v2" , bert_dense_encoder .BertDenseEncoder , 1 , 1 ),
182
+ ("all_sequence_encoder_v2" , bert_encoder .BertEncoderV2 , None , 41 ),
183
+ ("output_range_encoder_v2" , bert_encoder .BertEncoderV2 , 1 , 1 ),
182
184
)
183
185
def test_dict_outputs_network_invocation (
184
186
self , encoder_cls , output_range , out_seq_len ):
@@ -195,7 +197,8 @@ def test_dict_outputs_network_invocation(
195
197
num_layers = 3 ,
196
198
type_vocab_size = num_types ,
197
199
output_range = output_range ,
198
- dict_outputs = True )
200
+ dict_outputs = True ,
201
+ with_dense_inputs = True )
199
202
# Create the inputs (note that the first dimension is implicit).
200
203
word_ids = tf .keras .Input (shape = (sequence_length ,), dtype = tf .int32 )
201
204
mask = tf .keras .Input (shape = (sequence_length ,), dtype = tf .int32 )
@@ -276,7 +279,7 @@ def test_dict_outputs_network_invocation(
276
279
277
280
# Creates a BertEncoder with embedding_width != hidden_size
278
281
embedding_width = 16
279
- test_network = bert_dense_encoder . BertDenseEncoder (
282
+ test_network = bert_encoder . BertEncoderV2 (
280
283
vocab_size = vocab_size ,
281
284
hidden_size = hidden_size ,
282
285
max_sequence_length = max_sequence_length ,
@@ -316,11 +319,12 @@ def test_embeddings_as_inputs(self):
316
319
sequence_length = 21
317
320
dense_sequence_length = 20
318
321
# Create a small BertEncoder for testing.
319
- test_network = bert_dense_encoder . BertDenseEncoder (
322
+ test_network = bert_encoder . BertEncoderV2 (
320
323
vocab_size = 100 ,
321
324
hidden_size = hidden_size ,
322
325
num_attention_heads = 2 ,
323
- num_layers = 3 )
326
+ num_layers = 3 ,
327
+ with_dense_inputs = True )
324
328
# Create the inputs (note that the first dimension is implicit).
325
329
word_ids = tf .keras .Input (shape = (sequence_length ), dtype = tf .int32 )
326
330
mask = tf .keras .Input (shape = (sequence_length ,), dtype = tf .int32 )
0 commit comments