Skip to content

Commit d3a3f14

Browse files
saberkun authored and tensorflower-gardener committed
Internal change
PiperOrigin-RevId: 423199224
1 parent 8b4d459 commit d3a3f14

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

official/nlp/data/pretrain_dynamic_dataloader.py

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -79,17 +79,29 @@ def __init__(self, params):
7979
def _decode(self, record: tf.Tensor):
8080
"""Decodes a serialized tf.Example."""
8181
name_to_features = {
82-
'input_ids': tf.io.VarLenFeature(tf.int64),
8382
'input_mask': tf.io.VarLenFeature(tf.int64),
84-
'segment_ids': tf.io.VarLenFeature(tf.int64),
8583
'masked_lm_positions': tf.io.VarLenFeature(tf.int64),
8684
'masked_lm_ids': tf.io.VarLenFeature(tf.int64),
8785
'masked_lm_weights': tf.io.VarLenFeature(tf.float32),
8886
}
87+
if self._params.use_v2_feature_names:
88+
input_ids_key = 'input_word_ids'
89+
segment_key = 'input_type_ids'
90+
name_to_features.update({
91+
input_ids_key: tf.io.VarLenFeature(tf.int64),
92+
segment_key: tf.io.VarLenFeature(tf.int64),
93+
})
94+
else:
95+
input_ids_key = 'input_ids'
96+
segment_key = 'segment_ids'
97+
name_to_features.update({
98+
input_ids_key: tf.io.VarLenFeature(tf.int64),
99+
segment_key: tf.io.VarLenFeature(tf.int64),
100+
})
89101
if self._use_next_sentence_label:
90102
name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
91103
tf.int64)
92-
dynamic_keys = ['input_ids', 'input_mask', 'segment_ids']
104+
dynamic_keys = [input_ids_key, 'input_mask', segment_key]
93105
if self._use_position_id:
94106
name_to_features['position_ids'] = tf.io.VarLenFeature(tf.int64)
95107
dynamic_keys.append('position_ids')
@@ -102,7 +114,7 @@ def _decode(self, record: tf.Tensor):
102114
# sequence length dimension.
103115
# Pad before the first non pad from the back should not be removed.
104116
mask = tf.math.greater(
105-
tf.math.cumsum(example['input_ids'], reverse=True), 0)
117+
tf.math.cumsum(example[input_ids_key], reverse=True), 0)
106118
for key in dynamic_keys:
107119
example[key] = tf.boolean_mask(example[key], mask)
108120

0 commit comments

Comments
 (0)