@@ -79,17 +79,29 @@ def __init__(self, params):
79
79
def _decode (self , record : tf .Tensor ):
80
80
"""Decodes a serialized tf.Example."""
81
81
name_to_features = {
82
- 'input_ids' : tf .io .VarLenFeature (tf .int64 ),
83
82
'input_mask' : tf .io .VarLenFeature (tf .int64 ),
84
- 'segment_ids' : tf .io .VarLenFeature (tf .int64 ),
85
83
'masked_lm_positions' : tf .io .VarLenFeature (tf .int64 ),
86
84
'masked_lm_ids' : tf .io .VarLenFeature (tf .int64 ),
87
85
'masked_lm_weights' : tf .io .VarLenFeature (tf .float32 ),
88
86
}
87
+ if self ._params .use_v2_feature_names :
88
+ input_ids_key = 'input_word_ids'
89
+ segment_key = 'input_type_ids'
90
+ name_to_features .update ({
91
+ input_ids_key : tf .io .VarLenFeature (tf .int64 ),
92
+ segment_key : tf .io .VarLenFeature (tf .int64 ),
93
+ })
94
+ else :
95
+ input_ids_key = 'input_ids'
96
+ segment_key = 'segment_ids'
97
+ name_to_features .update ({
98
+ input_ids_key : tf .io .VarLenFeature (tf .int64 ),
99
+ segment_key : tf .io .VarLenFeature (tf .int64 ),
100
+ })
89
101
if self ._use_next_sentence_label :
90
102
name_to_features ['next_sentence_labels' ] = tf .io .FixedLenFeature ([1 ],
91
103
tf .int64 )
92
- dynamic_keys = ['input_ids' , 'input_mask' , 'segment_ids' ]
104
+ dynamic_keys = [input_ids_key , 'input_mask' , segment_key ]
93
105
if self ._use_position_id :
94
106
name_to_features ['position_ids' ] = tf .io .VarLenFeature (tf .int64 )
95
107
dynamic_keys .append ('position_ids' )
@@ -102,7 +114,7 @@ def _decode(self, record: tf.Tensor):
102
114
# sequence length dimension.
103
115
# Pad before the first non pad from the back should not be removed.
104
116
mask = tf .math .greater (
105
- tf .math .cumsum (example ['input_ids' ], reverse = True ), 0 )
117
+ tf .math .cumsum (example [input_ids_key ], reverse = True ), 0 )
106
118
for key in dynamic_keys :
107
119
example [key ] = tf .boolean_mask (example [key ], mask )
108
120
0 commit comments