Skip to content

Commit b045ce7

Browse files
saberkun authored and tensorflower-gardener committed
Internal change
PiperOrigin-RevId: 272777104
1 parent 0f176f6 commit b045ce7

File tree

5 files changed

+145
-150
lines changed

5 files changed

+145
-150
lines changed

official/nlp/xlnet/data_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@
4343
SEP_ID = special_symbols["<sep>"]
4444
MASK_ID = special_symbols["<mask>"]
4545
EOD_ID = special_symbols["<eod>"]
46+
SEG_ID_P = 0
47+
SEG_ID_Q = 1
48+
SEG_ID_CLS = 2
49+
SEG_ID_PAD = 3
4650

4751

4852
def file_based_input_fn_builder(input_file, name_to_features, batch_size,

official/nlp/xlnet/run_pretrain.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,11 @@
4848

4949

5050
def get_pretrainxlnet_model(model_config, run_config):
51-
model = modeling.PretrainingXLNetModel(model_config, run_config, name="model")
52-
return model
51+
return modeling.PretrainingXLNetModel(
52+
use_proj=True,
53+
xlnet_config=model_config,
54+
run_config=run_config,
55+
name="model")
5356

5457

5558
def main(unused_argv):
@@ -69,8 +72,7 @@ def main(unused_argv):
6972
if strategy:
7073
logging.info("***** Number of cores used : %d",
7174
strategy.num_replicas_in_sync)
72-
logging.info("***** Number of hosts used : %d",
73-
num_hosts)
75+
logging.info("***** Number of hosts used : %d", num_hosts)
7476
train_input_fn = functools.partial(
7577
data_utils.get_pretrain_input_data, FLAGS.train_batch_size, FLAGS.seq_len,
7678
strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size,

official/nlp/xlnet/squad_utils.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,6 @@
3636

3737
SPIECE_UNDERLINE = u"▁"
3838

39-
SEG_ID_P = 0
40-
SEG_ID_Q = 1
41-
SEG_ID_CLS = 2
42-
SEG_ID_PAD = 3
43-
4439

4540
class InputFeatures(object):
4641
"""A single set of features of data."""
@@ -705,28 +700,28 @@ def _piece_to_id(x):
705700
split_token_index)
706701
token_is_max_context[len(tokens)] = is_max_context
707702
tokens.append(all_doc_tokens[split_token_index])
708-
segment_ids.append(SEG_ID_P)
703+
segment_ids.append(data_utils.SEG_ID_P)
709704
p_mask.append(0)
710705

711706
paragraph_len = len(tokens)
712707

713708
tokens.append(data_utils.SEP_ID)
714-
segment_ids.append(SEG_ID_P)
709+
segment_ids.append(data_utils.SEG_ID_P)
715710
p_mask.append(1)
716711

717712
# note(zhiliny): we put P before Q
718713
# because during pretraining, B is always shorter than A
719714
for token in query_tokens:
720715
tokens.append(token)
721-
segment_ids.append(SEG_ID_Q)
716+
segment_ids.append(data_utils.SEG_ID_Q)
722717
p_mask.append(1)
723718
tokens.append(data_utils.SEP_ID)
724-
segment_ids.append(SEG_ID_Q)
719+
segment_ids.append(data_utils.SEG_ID_Q)
725720
p_mask.append(1)
726721

727722
cls_index = len(segment_ids)
728723
tokens.append(data_utils.CLS_ID)
729-
segment_ids.append(SEG_ID_CLS)
724+
segment_ids.append(data_utils.SEG_ID_CLS)
730725
p_mask.append(0)
731726

732727
input_ids = tokens
@@ -739,7 +734,7 @@ def _piece_to_id(x):
739734
while len(input_ids) < max_seq_length:
740735
input_ids.append(0)
741736
input_mask.append(1)
742-
segment_ids.append(SEG_ID_PAD)
737+
segment_ids.append(data_utils.SEG_ID_PAD)
743738
p_mask.append(1)
744739

745740
assert len(input_ids) == max_seq_length

official/nlp/xlnet_config.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ def create_run_config(is_training, is_finetune, flags):
3030
kwargs = dict(
3131
is_training=is_training,
3232
use_tpu=flags.use_tpu,
33-
use_bfloat16=flags.use_bfloat16,
3433
dropout=flags.dropout,
3534
dropout_att=flags.dropout_att,
3635
init_method=flags.init_method,
@@ -49,6 +48,7 @@ def create_run_config(is_training, is_finetune, flags):
4948
return RunConfig(**kwargs)
5049

5150

51+
# TODO(hongkuny): refactor XLNetConfig and RunConfig.
5252
class XLNetConfig(object):
5353
"""Configs for XLNet model.
5454
@@ -131,7 +131,6 @@ class RunConfig(object):
131131
def __init__(self,
132132
is_training,
133133
use_tpu,
134-
use_bfloat16,
135134
dropout,
136135
dropout_att,
137136
init_method='normal',
@@ -141,13 +140,13 @@ def __init__(self,
141140
reuse_len=None,
142141
bi_data=False,
143142
clamp_len=-1,
144-
same_length=False):
143+
same_length=False,
144+
use_cls_mask=True):
145145
"""Initializes RunConfig.
146146
147147
Args:
148148
is_training: bool, whether in training mode.
149149
use_tpu: bool, whether TPUs are used.
150-
use_bfloat16: bool, use bfloat16 instead of float32.
151150
dropout: float, dropout rate.
152151
dropout_att: float, dropout rate on attention probabilities.
153152
init_method: str, the initialization scheme, either "normal" or "uniform".
@@ -164,6 +163,7 @@ def __init__(self,
164163
-1 means no clamping.
165164
same_length: bool, whether to use the same attention length
166165
for each token.
166+
use_cls_mask: bool, whether to introduce cls mask.
167167
"""
168168

169169
self.init_method = init_method
@@ -173,9 +173,9 @@ def __init__(self,
173173
self.dropout = dropout
174174
self.dropout_att = dropout_att
175175
self.use_tpu = use_tpu
176-
self.use_bfloat16 = use_bfloat16
177176
self.mem_len = mem_len
178177
self.reuse_len = reuse_len
179178
self.bi_data = bi_data
180179
self.clamp_len = clamp_len
181180
self.same_length = same_length
181+
self.use_cls_mask = use_cls_mask

0 commit comments

Comments (0)