Skip to content

Commit 7cffe10

Browse files
aichendoubletensorflower-gardener
authored andcommitted
internal change
PiperOrigin-RevId: 338095907
1 parent ebfc313 commit 7cffe10

File tree

7 files changed

+13
-21
lines changed

7 files changed

+13
-21
lines changed

official/nlp/tasks/question_answering.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import dataclasses
2323
import orbit
2424
import tensorflow as tf
25-
import tensorflow_hub as hub
2625

2726
from official.core import base_task
2827
from official.core import config_definitions as cfg
@@ -87,11 +86,8 @@ def build_model(self):
8786
raise ValueError('At most one of `hub_module_url` and '
8887
'`init_checkpoint` can be specified.')
8988
if self.task_config.hub_module_url:
90-
hub_module = hub.load(self.task_config.hub_module_url)
91-
else:
92-
hub_module = None
93-
if hub_module:
94-
encoder_network = utils.get_encoder_from_hub(hub_module)
89+
encoder_network = utils.get_encoder_from_hub(
90+
self.task_config.hub_module_url)
9591
else:
9692
encoder_network = encoders.build_encoder(self.task_config.model.encoder)
9793
encoder_cfg = self.task_config.model.encoder.get()

official/nlp/tasks/question_answering_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def _run_task(self, config):
104104
logs = task.aggregate_logs(step_outputs=logs)
105105
metrics = task.reduce_aggregated_logs(logs)
106106
self.assertIn("final_f1", metrics)
107+
model.save(os.path.join(self.get_temp_dir(), "saved_model"))
107108

108109
@parameterized.parameters(
109110
itertools.product(

official/nlp/tasks/sentence_prediction.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from scipy import stats
2424
from sklearn import metrics as sklearn_metrics
2525
import tensorflow as tf
26-
import tensorflow_hub as hub
2726

2827
from official.core import base_task
2928
from official.core import config_definitions as cfg
@@ -77,11 +76,8 @@ def build_model(self):
7776
raise ValueError('At most one of `hub_module_url` and '
7877
'`init_checkpoint` can be specified.')
7978
if self.task_config.hub_module_url:
80-
hub_module = hub.load(self.task_config.hub_module_url)
81-
else:
82-
hub_module = None
83-
if hub_module:
84-
encoder_network = utils.get_encoder_from_hub(hub_module)
79+
encoder_network = utils.get_encoder_from_hub(
80+
self.task_config.hub_module_url)
8581
else:
8682
encoder_network = encoders.build_encoder(self.task_config.model.encoder)
8783
encoder_cfg = self.task_config.model.encoder.get()

official/nlp/tasks/sentence_prediction_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def _run_task(self, config):
8686
iterator = iter(dataset)
8787
optimizer = tf.keras.optimizers.SGD(lr=0.1)
8888
task.train_step(next(iterator), model, optimizer, metrics=metrics)
89+
model.save(os.path.join(self.get_temp_dir(), "saved_model"))
8990
return task.validation_step(next(iterator), model, metrics=metrics)
9091

9192
@parameterized.named_parameters(

official/nlp/tasks/tagging.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from seqeval import metrics as seqeval_metrics
2323

2424
import tensorflow as tf
25-
import tensorflow_hub as hub
2625

2726
from official.core import base_task
2827
from official.core import config_definitions as cfg
@@ -89,11 +88,8 @@ def build_model(self):
8988
raise ValueError('At most one of `hub_module_url` and '
9089
'`init_checkpoint` can be specified.')
9190
if self.task_config.hub_module_url:
92-
hub_module = hub.load(self.task_config.hub_module_url)
93-
else:
94-
hub_module = None
95-
if hub_module:
96-
encoder_network = utils.get_encoder_from_hub(hub_module)
91+
encoder_network = utils.get_encoder_from_hub(
92+
self.task_config.hub_module_url)
9793
else:
9894
encoder_network = encoders.build_encoder(self.task_config.model.encoder)
9995

official/nlp/tasks/tagging_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def _run_task(self, config):
7373
optimizer = tf.keras.optimizers.SGD(lr=0.1)
7474
task.train_step(next(iterator), model, optimizer, metrics=metrics)
7575
task.validation_step(next(iterator), model, metrics=metrics)
76+
model.save(os.path.join(self.get_temp_dir(), "saved_model"))
7677

7778
def test_task(self):
7879
# Saves a checkpoint.

official/nlp/tasks/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
import tensorflow_hub as hub
2323

2424

25-
def get_encoder_from_hub(hub_model) -> tf.keras.Model:
25+
def get_encoder_from_hub(hub_model_path: str) -> tf.keras.Model:
2626
"""Gets an encoder from hub.
2727
2828
Args:
29-
hub_model: A tfhub model loaded by `hub.load(...)`.
29+
hub_model_path: The path to the tfhub model.
3030
3131
Returns:
3232
A tf.keras.Model.
@@ -37,7 +37,7 @@ def get_encoder_from_hub(hub_model) -> tf.keras.Model:
3737
shape=(None,), dtype=tf.int32, name='input_mask')
3838
input_type_ids = tf.keras.layers.Input(
3939
shape=(None,), dtype=tf.int32, name='input_type_ids')
40-
hub_layer = hub.KerasLayer(hub_model, trainable=True)
40+
hub_layer = hub.KerasLayer(hub_model_path, trainable=True)
4141
output_dict = {}
4242
dict_input = dict(
4343
input_word_ids=input_word_ids,
@@ -49,6 +49,7 @@ def get_encoder_from_hub(hub_model) -> tf.keras.Model:
4949
# as input and returns a dict.
5050
# TODO(chendouble): Remove the support of legacy hub model when the new ones
5151
# are released.
52+
hub_model = hub.load(hub_model_path)
5253
hub_output_signature = hub_model.signatures['serving_default'].outputs
5354
if len(hub_output_signature) == 2:
5455
logging.info('Use the legacy hub module with list as input/output.')

0 commit comments

Comments
 (0)