Skip to content

Commit 993dbf5

Browse files
Jiayu Yetensorflower-gardener
authored andcommitted
Internal change
PiperOrigin-RevId: 420368411
1 parent 98a558b commit 993dbf5

File tree

3 files changed

+49
-296
lines changed

3 files changed

+49
-296
lines changed

official/nlp/modeling/networks/bert_dense_encoder.py

Lines changed: 0 additions & 276 deletions
This file was deleted.

official/nlp/modeling/networks/bert_dense_encoder_test.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,30 @@
2020
import tensorflow as tf
2121

2222
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
23-
from official.nlp.modeling.networks import bert_dense_encoder
23+
from official.nlp.modeling.networks import bert_encoder
2424

2525

2626
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
2727
# guarantees forward compatibility of this code for the V2 switchover.
2828
@keras_parameterized.run_all_keras_modes
29-
class BertDenseEncoderTest(keras_parameterized.TestCase):
29+
class BertEncoderV2Test(keras_parameterized.TestCase):
3030

3131
def tearDown(self):
32-
super(BertDenseEncoderTest, self).tearDown()
32+
super(BertEncoderV2Test, self).tearDown()
3333
tf.keras.mixed_precision.set_global_policy("float32")
3434

3535
def test_dict_outputs_network_creation(self):
3636
hidden_size = 32
3737
sequence_length = 21
3838
dense_sequence_length = 20
39-
# Create a small dense BertDenseEncoder for testing.
39+
# Create a small dense BertEncoderV2 for testing.
4040
kwargs = {}
41-
test_network = bert_dense_encoder.BertDenseEncoder(
41+
test_network = bert_encoder.BertEncoderV2(
4242
vocab_size=100,
4343
hidden_size=hidden_size,
4444
num_attention_heads=2,
4545
num_layers=3,
46+
with_dense_inputs=True,
4647
**kwargs)
4748
# Create the inputs (note that the first dimension is implicit).
4849
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
@@ -86,12 +87,13 @@ def test_dict_outputs_all_encoder_outputs_network_creation(self):
8687
sequence_length = 21
8788
dense_sequence_length = 20
8889
# Create a small BertEncoder for testing.
89-
test_network = bert_dense_encoder.BertDenseEncoder(
90+
test_network = bert_encoder.BertEncoderV2(
9091
vocab_size=100,
9192
hidden_size=hidden_size,
9293
num_attention_heads=2,
9394
num_layers=3,
94-
dict_outputs=True)
95+
dict_outputs=True,
96+
with_dense_inputs=True)
9597
# Create the inputs (note that the first dimension is implicit).
9698
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
9799
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
@@ -134,12 +136,13 @@ def test_dict_outputs_network_creation_with_float16_dtype(self):
134136
dense_sequence_length = 20
135137
tf.keras.mixed_precision.set_global_policy("mixed_float16")
136138
# Create a small BertEncoder for testing.
137-
test_network = bert_dense_encoder.BertDenseEncoder(
139+
test_network = bert_encoder.BertEncoderV2(
138140
vocab_size=100,
139141
hidden_size=hidden_size,
140142
num_attention_heads=2,
141143
num_layers=3,
142-
dict_outputs=True)
144+
dict_outputs=True,
145+
with_dense_inputs=True)
143146
# Create the inputs (note that the first dimension is implicit).
144147
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
145148
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
@@ -176,9 +179,8 @@ def test_dict_outputs_network_creation_with_float16_dtype(self):
176179
self.assertAllEqual(tf.float16, pooled.dtype)
177180

178181
@parameterized.named_parameters(
179-
("all_sequence_encoder_v2", bert_dense_encoder.BertDenseEncoder, None,
180-
41),
181-
("output_range_encoder_v2", bert_dense_encoder.BertDenseEncoder, 1, 1),
182+
("all_sequence_encoder_v2", bert_encoder.BertEncoderV2, None, 41),
183+
("output_range_encoder_v2", bert_encoder.BertEncoderV2, 1, 1),
182184
)
183185
def test_dict_outputs_network_invocation(
184186
self, encoder_cls, output_range, out_seq_len):
@@ -195,7 +197,8 @@ def test_dict_outputs_network_invocation(
195197
num_layers=3,
196198
type_vocab_size=num_types,
197199
output_range=output_range,
198-
dict_outputs=True)
200+
dict_outputs=True,
201+
with_dense_inputs=True)
199202
# Create the inputs (note that the first dimension is implicit).
200203
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
201204
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
@@ -276,7 +279,7 @@ def test_dict_outputs_network_invocation(
276279

277280
# Creates a BertEncoder with embedding_width != hidden_size
278281
embedding_width = 16
279-
test_network = bert_dense_encoder.BertDenseEncoder(
282+
test_network = bert_encoder.BertEncoderV2(
280283
vocab_size=vocab_size,
281284
hidden_size=hidden_size,
282285
max_sequence_length=max_sequence_length,
@@ -316,11 +319,12 @@ def test_embeddings_as_inputs(self):
316319
sequence_length = 21
317320
dense_sequence_length = 20
318321
# Create a small BertEncoder for testing.
319-
test_network = bert_dense_encoder.BertDenseEncoder(
322+
test_network = bert_encoder.BertEncoderV2(
320323
vocab_size=100,
321324
hidden_size=hidden_size,
322325
num_attention_heads=2,
323-
num_layers=3)
326+
num_layers=3,
327+
with_dense_inputs=True)
324328
# Create the inputs (note that the first dimension is implicit).
325329
word_ids = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
326330
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)

0 commit comments

Comments
 (0)