22
22
import tensorflow_hub as hub
23
23
24
24
25
- def get_encoder_from_hub (hub_model ) -> tf .keras .Model :
25
+ def get_encoder_from_hub (hub_model_path : str ) -> tf .keras .Model :
26
26
"""Gets an encoder from hub.
27
27
28
28
Args:
29
- hub_model: A tfhub model loaded by `hub.load(...)` .
29
+ hub_model_path: The path to the tfhub model .
30
30
31
31
Returns:
32
32
A tf.keras.Model.
@@ -37,7 +37,7 @@ def get_encoder_from_hub(hub_model) -> tf.keras.Model:
37
37
shape = (None ,), dtype = tf .int32 , name = 'input_mask' )
38
38
input_type_ids = tf .keras .layers .Input (
39
39
shape = (None ,), dtype = tf .int32 , name = 'input_type_ids' )
40
- hub_layer = hub .KerasLayer (hub_model , trainable = True )
40
+ hub_layer = hub .KerasLayer (hub_model_path , trainable = True )
41
41
output_dict = {}
42
42
dict_input = dict (
43
43
input_word_ids = input_word_ids ,
@@ -49,6 +49,7 @@ def get_encoder_from_hub(hub_model) -> tf.keras.Model:
49
49
# as input and returns a dict.
50
50
# TODO(chendouble): Remove the support of legacy hub model when the new ones
51
51
# are released.
52
+ hub_model = hub .load (hub_model_path )
52
53
hub_output_signature = hub_model .signatures ['serving_default' ].outputs
53
54
if len (hub_output_signature ) == 2 :
54
55
logging .info ('Use the legacy hub module with list as input/output.' )
0 commit comments