
Commit 553a4f4

aichendouble authored and tensorflower-gardener committed

Internal change

PiperOrigin-RevId: 286634090
1 parent 5f7bdb1 commit 553a4f4

3 files changed: +32 -23 lines

official/nlp/bert/export_tfhub.py

Lines changed: 15 additions & 13 deletions
@@ -24,15 +24,15 @@
 from typing import Text
 
 from official.nlp import bert_modeling
+from official.nlp import bert_models
 
 FLAGS = flags.FLAGS
 
 flags.DEFINE_string("bert_config_file", None,
                     "Bert configuration file to define core bert layers.")
 flags.DEFINE_string("model_checkpoint_path", None,
                     "File path to TF model checkpoint.")
-flags.DEFINE_string("export_path", None,
-                    "TF-Hub SavedModel destination path.")
+flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
 flags.DEFINE_string("vocab_file", None,
                     "The vocabulary file that the BERT model was trained on.")
 
@@ -53,21 +53,23 @@ def create_bert_model(bert_config: bert_modeling.BertConfig):
       shape=(None,), dtype=tf.int32, name="input_mask")
   input_type_ids = tf.keras.layers.Input(
       shape=(None,), dtype=tf.int32, name="input_type_ids")
-  return bert_modeling.get_bert_model(
-      input_word_ids,
-      input_mask,
-      input_type_ids,
-      config=bert_config,
-      name="bert_model",
-      float_type=tf.float32)
+  transformer_encoder = bert_models.get_transformer_encoder(
+      bert_config, sequence_length=None, float_dtype=tf.float32)
+  sequence_output, pooled_output = transformer_encoder(
+      [input_word_ids, input_mask, input_type_ids])
+  # To keep consistent with legacy hub modules, the outputs are
+  # "pooled_output" and "sequence_output".
+  return tf.keras.Model(
+      inputs=[input_word_ids, input_mask, input_type_ids],
+      outputs=[pooled_output, sequence_output]), transformer_encoder
 
 
 def export_bert_tfhub(bert_config: bert_modeling.BertConfig,
                       model_checkpoint_path: Text, hub_destination: Text,
                       vocab_file: Text):
   """Restores a tf.keras.Model and saves for TF-Hub."""
-  core_model = create_bert_model(bert_config)
-  checkpoint = tf.train.Checkpoint(model=core_model)
+  core_model, encoder = create_bert_model(bert_config)
+  checkpoint = tf.train.Checkpoint(model=encoder)
   checkpoint.restore(model_checkpoint_path).assert_consumed()
   core_model.vocab_file = tf.saved_model.Asset(vocab_file)
   core_model.do_lower_case = tf.Variable(
@@ -79,8 +81,8 @@ def main(_):
   assert tf.version.VERSION.startswith('2.')
 
   bert_config = bert_modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
-  export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path,
-                    FLAGS.export_path, FLAGS.vocab_file)
+  export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, FLAGS.export_path,
+                    FLAGS.vocab_file)
 
 
 if __name__ == "__main__":
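
For orientation, here is a minimal sketch of invoking the updated export path directly from Python rather than through the command-line flags; every file path below is a hypothetical placeholder, not a value from this commit.

# Hedged sketch, not part of the commit: call export_bert_tfhub() directly.
# All paths are hypothetical placeholders.
from official.nlp import bert_modeling
from official.nlp.bert import export_tfhub

bert_config = bert_modeling.BertConfig.from_json_file(
    "/path/to/bert_config.json")
export_tfhub.export_bert_tfhub(
    bert_config,
    model_checkpoint_path="/path/to/bert_model.ckpt",
    hub_destination="/tmp/bert_hub",
    vocab_file="/path/to/vocab.txt")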

official/nlp/bert/export_tfhub_test.py

Lines changed: 10 additions & 3 deletions
@@ -39,9 +39,9 @@ def test_export_tfhub(self):
         max_position_embeddings=128,
         num_attention_heads=2,
         num_hidden_layers=1)
-    bert_model = export_tfhub.create_bert_model(bert_config)
+    bert_model, encoder = export_tfhub.create_bert_model(bert_config)
     model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
-    checkpoint = tf.train.Checkpoint(model=bert_model)
+    checkpoint = tf.train.Checkpoint(model=encoder)
     checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
     model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
 
@@ -70,10 +70,17 @@ def test_export_tfhub(self):
     dummy_ids = np.zeros((2, 10), dtype=np.int32)
     hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
     source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
+
+    # The outputs of the hub module are "pooled_output" and "sequence_output",
+    # while the outputs of the encoder are in reversed order, i.e.,
+    # "sequence_output" and "pooled_output".
+    encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
     self.assertEqual(hub_outputs[0].shape, (2, 16))
     self.assertEqual(hub_outputs[1].shape, (2, 10, 16))
-    for source_output, hub_output in zip(source_outputs, hub_outputs):
+    for source_output, hub_output, encoder_output in zip(
+        source_outputs, hub_outputs, encoder_outputs):
       self.assertAllClose(source_output.numpy(), hub_output.numpy())
+      self.assertAllClose(source_output.numpy(), encoder_output.numpy())
 
 
 if __name__ == "__main__":
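
For readers unfamiliar with how the exported artifact is consumed, a short sketch mirroring the test above; "/tmp/bert_hub" is a hypothetical export path, and hub.KerasLayer is the TensorFlow Hub loader the test already uses.

# Hedged usage sketch (not part of the commit), mirroring test_export_tfhub:
# load the exported SavedModel and run it on dummy token ids.
import numpy as np
import tensorflow_hub as hub

hub_layer = hub.KerasLayer("/tmp/bert_hub", trainable=True)  # hypothetical path
dummy_ids = np.zeros((2, 10), dtype=np.int32)
# Outputs follow the legacy hub module order: pooled_output, sequence_output.
pooled_output, sequence_output = hub_layer([dummy_ids, dummy_ids, dummy_ids])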

official/nlp/bert_models.py

Lines changed: 7 additions & 7 deletions
@@ -134,9 +134,9 @@ def call(self, inputs):
     return final_loss
 
 
-def _get_transformer_encoder(bert_config,
-                             sequence_length,
-                             float_dtype=tf.float32):
+def get_transformer_encoder(bert_config,
+                            sequence_length,
+                            float_dtype=tf.float32):
   """Gets a 'TransformerEncoder' object.
 
   Args:
@@ -206,7 +206,7 @@ def pretrain_model(bert_config,
   next_sentence_labels = tf.keras.layers.Input(
       shape=(1,), name='next_sentence_labels', dtype=tf.int32)
 
-  transformer_encoder = _get_transformer_encoder(bert_config, seq_length)
+  transformer_encoder = get_transformer_encoder(bert_config, seq_length)
   if initializer is None:
     initializer = tf.keras.initializers.TruncatedNormal(
         stddev=bert_config.initializer_range)
@@ -294,8 +294,8 @@ def squad_model(bert_config,
   initializer = tf.keras.initializers.TruncatedNormal(
       stddev=bert_config.initializer_range)
   if not hub_module_url:
-    bert_encoder = _get_transformer_encoder(bert_config, max_seq_length,
-                                            float_type)
+    bert_encoder = get_transformer_encoder(bert_config, max_seq_length,
+                                           float_type)
   return bert_span_labeler.BertSpanLabeler(
       network=bert_encoder, initializer=initializer), bert_encoder
 
@@ -359,7 +359,7 @@ def classifier_model(bert_config,
       stddev=bert_config.initializer_range)
 
   if not hub_module_url:
-    bert_encoder = _get_transformer_encoder(bert_config, max_seq_length)
+    bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
   return bert_classifier.BertClassifier(
       bert_encoder,
       num_classes=num_labels,
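
Because _get_transformer_encoder is promoted to the public name get_transformer_encoder in this file, callers such as export_tfhub.py can now build the encoder themselves. A hedged sketch of that usage follows; the BertConfig values are illustrative assumptions (loosely modeled on the test above), not part of this commit.

# Hedged sketch: build an encoder with the newly public helper and run it.
# The config values below are illustrative assumptions only.
import numpy as np

from official.nlp import bert_modeling
from official.nlp import bert_models

bert_config = bert_modeling.BertConfig(
    vocab_size=100,
    hidden_size=16,
    intermediate_size=32,
    num_attention_heads=2,
    num_hidden_layers=1,
    max_position_embeddings=128)
encoder = bert_models.get_transformer_encoder(bert_config, sequence_length=None)
dummy_ids = np.zeros((2, 10), dtype=np.int32)
# Note the encoder's output order: sequence_output first, then pooled_output.
sequence_output, pooled_output = encoder([dummy_ids, dummy_ids, dummy_ids])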
