24
24
from typing import Text
25
25
26
26
from official .nlp import bert_modeling
27
+ from official .nlp import bert_models
27
28
28
29
# Command-line interface for the exporter: paths to the BERT config,
# the checkpoint to restore, the TF-Hub destination, and the vocab file.
FLAGS = flags.FLAGS

flags.DEFINE_string(
    "bert_config_file", None,
    "Bert configuration file to define core bert layers.")
flags.DEFINE_string(
    "model_checkpoint_path", None,
    "File path to TF model checkpoint.")
flags.DEFINE_string(
    "export_path", None,
    "TF-Hub SavedModel destination path.")
flags.DEFINE_string(
    "vocab_file", None,
    "The vocabulary file that the BERT model was trained on.")
38
38
@@ -53,21 +53,23 @@ def create_bert_model(bert_config: bert_modeling.BertConfig):
53
53
shape = (None ,), dtype = tf .int32 , name = "input_mask" )
54
54
input_type_ids = tf .keras .layers .Input (
55
55
shape = (None ,), dtype = tf .int32 , name = "input_type_ids" )
56
- return bert_modeling .get_bert_model (
57
- input_word_ids ,
58
- input_mask ,
59
- input_type_ids ,
60
- config = bert_config ,
61
- name = "bert_model" ,
62
- float_type = tf .float32 )
56
+ transformer_encoder = bert_models .get_transformer_encoder (
57
+ bert_config , sequence_length = None , float_dtype = tf .float32 )
58
+ sequence_output , pooled_output = transformer_encoder (
59
+ [input_word_ids , input_mask , input_type_ids ])
60
+ # To keep consistent with legacy hub modules, the outputs are
61
+ # "pooled_output" and "sequence_output".
62
+ return tf .keras .Model (
63
+ inputs = [input_word_ids , input_mask , input_type_ids ],
64
+ outputs = [pooled_output , sequence_output ]), transformer_encoder
63
65
64
66
65
67
def export_bert_tfhub (bert_config : bert_modeling .BertConfig ,
66
68
model_checkpoint_path : Text , hub_destination : Text ,
67
69
vocab_file : Text ):
68
70
"""Restores a tf.keras.Model and saves for TF-Hub."""
69
- core_model = create_bert_model (bert_config )
70
- checkpoint = tf .train .Checkpoint (model = core_model )
71
+ core_model , encoder = create_bert_model (bert_config )
72
+ checkpoint = tf .train .Checkpoint (model = encoder )
71
73
checkpoint .restore (model_checkpoint_path ).assert_consumed ()
72
74
core_model .vocab_file = tf .saved_model .Asset (vocab_file )
73
75
core_model .do_lower_case = tf .Variable (
@@ -79,8 +81,8 @@ def main(_):
79
81
assert tf .version .VERSION .startswith ('2.' )
80
82
81
83
bert_config = bert_modeling .BertConfig .from_json_file (FLAGS .bert_config_file )
82
- export_bert_tfhub (bert_config , FLAGS .model_checkpoint_path ,
83
- FLAGS .export_path , FLAGS . vocab_file )
84
+ export_bert_tfhub (bert_config , FLAGS .model_checkpoint_path , FLAGS . export_path ,
85
+ FLAGS .vocab_file )
84
86
85
87
86
88
if __name__ == "__main__" :
0 commit comments