
Commit 7d8a218

No public description
PiperOrigin-RevId: 595636151
1 parent 00aa43b commit 7d8a218

2 files changed (+85, -35 lines)

official/nlp/data/create_pretraining_data.py

Lines changed: 77 additions & 27 deletions
@@ -19,6 +19,7 @@
 import random
 
 # Import libraries
+
 from absl import app
 from absl import flags
 from absl import logging
@@ -35,17 +36,37 @@
     "output_file", None,
     "Output TF example file (or comma-separated list of files).")
 
-flags.DEFINE_string("vocab_file", None,
-                    "The vocabulary file that the BERT model was trained on.")
+flags.DEFINE_enum(
+    "tokenization",
+    "WordPiece",
+    ["WordPiece", "SentencePiece"],
+    "Specifies the tokenizer implementation, i.e., whether to use WordPiece "
+    "or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, "
+    "while ALBERT uses SentencePiece tokenizer.",
+)
+
+flags.DEFINE_string(
+    "vocab_file",
+    None,
+    "For WordPiece tokenization, the vocabulary file of the tokenizer.",
+)
+
+flags.DEFINE_string(
+    "sp_model_file",
+    "",
+    "For SentencePiece tokenization, the path to the model of the tokenizer.",
+)
 
 flags.DEFINE_bool(
     "do_lower_case", True,
     "Whether to lower case the input text. Should be True for uncased "
     "models and False for cased models.")
 
 flags.DEFINE_bool(
-    "do_whole_word_mask", False,
-    "Whether to use whole word masking rather than per-WordPiece masking.")
+    "do_whole_word_mask",
+    False,
+    "Whether to use whole word masking rather than per-token masking.",
+)
 
 flags.DEFINE_integer(
     "max_ngram_size", None,
@@ -198,16 +219,19 @@ def create_float_feature(values):
   return feature
 
 
-def create_training_instances(input_files,
-                              tokenizer,
-                              max_seq_length,
-                              dupe_factor,
-                              short_seq_prob,
-                              masked_lm_prob,
-                              max_predictions_per_seq,
-                              rng,
-                              do_whole_word_mask=False,
-                              max_ngram_size=None):
+def create_training_instances(
+    input_files,
+    tokenizer,
+    processor_text_fn,
+    max_seq_length,
+    dupe_factor,
+    short_seq_prob,
+    masked_lm_prob,
+    max_predictions_per_seq,
+    rng,
+    do_whole_word_mask=False,
+    max_ngram_size=None,
+):
   """Create `TrainingInstance`s from raw text."""
   all_documents = [[]]
 
@@ -219,11 +243,8 @@ def create_training_instances(input_files,
   # that the "next sentence prediction" task doesn't span between documents.
   for input_file in input_files:
     with tf.io.gfile.GFile(input_file, "rb") as reader:
-      while True:
-        line = tokenization.convert_to_unicode(reader.readline())
-        if not line:
-          break
-        line = line.strip()
+      for line in reader:
+        line = processor_text_fn(line)
 
         # Empty lines are used as document delimiters
         if not line:
@@ -535,15 +556,16 @@ def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
   return output_ngrams
 
 
-def _wordpieces_to_grams(tokens):
+def _tokens_to_grams(tokens):
   """Reconstitue grams (words) from `tokens`.
 
   E.g.,
    tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
    grams:  [ [1,2), [2, 4), [4,5), [5, 6)]
 
   Args:
-    tokens: list of wordpieces
+    tokens: list of tokens (word pieces or sentence pieces).
+
   Returns:
     List of _Grams representing spans of whole words
     (without "[CLS]" and "[SEP]").
@@ -570,7 +592,7 @@ def create_masked_lm_predictions(tokens, masked_lm_prob,
                                  max_ngram_size=None):
   """Creates the predictions for the masked LM objective."""
   if do_whole_word_mask:
-    grams = _wordpieces_to_grams(tokens)
+    grams = _tokens_to_grams(tokens)
   else:
     # Here we consider each token to be a word to allow for sub-word masking.
     if max_ngram_size:
@@ -633,9 +655,28 @@ def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
       trunc_tokens.pop()
 
 
+def get_processor_text_fn(is_sentence_piece, do_lower_case):
+  def processor_text_fn(text):
+    text = tokenization.convert_to_unicode(text)
+    if is_sentence_piece:
+      # Additional preprocessing specific to the SentencePiece tokenizer.
+      text = tokenization.preprocess_text(text, lower=do_lower_case)
+
+    return text.strip()
+
+  return processor_text_fn
+
+
 def main(_):
-  tokenizer = tokenization.FullTokenizer(
-      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+  if FLAGS.tokenization == "WordPiece":
+    tokenizer = tokenization.FullTokenizer(
+        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
+    )
+    processor_text_fn = get_processor_text_fn(False, FLAGS.do_lower_case)
+  else:
+    assert FLAGS.tokenization == "SentencePiece"
+    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+    processor_text_fn = get_processor_text_fn(True, FLAGS.do_lower_case)
 
   input_files = []
   for input_pattern in FLAGS.input_file.split(","):
@@ -647,9 +688,18 @@ def main(_):
 
   rng = random.Random(FLAGS.random_seed)
   instances = create_training_instances(
-      input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
-      FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
-      rng, FLAGS.do_whole_word_mask, FLAGS.max_ngram_size)
+      input_files,
+      tokenizer,
+      processor_text_fn,
+      FLAGS.max_seq_length,
+      FLAGS.dupe_factor,
+      FLAGS.short_seq_prob,
+      FLAGS.masked_lm_prob,
+      FLAGS.max_predictions_per_seq,
+      rng,
+      FLAGS.do_whole_word_mask,
+      FLAGS.max_ngram_size,
+  )
 
   output_files = FLAGS.output_file.split(",")
   logging.info("*** Writing to output files ***")
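
Note on the new per-line preprocessing: input lines now flow through a processor_text_fn built in main() instead of an inline tokenization.convert_to_unicode call, so the SentencePiece path can apply its extra preprocessing in one place. Below is a minimal, self-contained sketch of that control flow; make_processor_text_fn is an illustrative name, and plain lowercasing stands in for tokenization.preprocess_text, which does more than str.lower().

def make_processor_text_fn(is_sentence_piece, do_lower_case):
  """Illustrative stand-in for get_processor_text_fn; not the module's code."""
  def processor_text_fn(text):
    # Lines arrive as bytes because the input files are opened in "rb" mode.
    if isinstance(text, bytes):
      text = text.decode("utf-8", "ignore")
    if is_sentence_piece and do_lower_case:
      # Stand-in for tokenization.preprocess_text(text, lower=do_lower_case).
      text = text.lower()
    return text.strip()
  return processor_text_fn

wordpiece_fn = make_processor_text_fn(is_sentence_piece=False, do_lower_case=True)
sentencepiece_fn = make_processor_text_fn(is_sentence_piece=True, do_lower_case=True)
print(wordpiece_fn(b"  Hello World\n"))      # "Hello World" (FullTokenizer lowercases later)
print(sentencepiece_fn(b"  Hello World\n"))  # "hello world"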

official/nlp/data/create_pretraining_data_test.py

Lines changed: 8 additions & 8 deletions
@@ -43,7 +43,7 @@ def assertTokens(self, input_tokens, output_tokens, masked_positions,
         continue
       self.fail("invalid mask value: {}".format(output_token))
 
-  def test_wordpieces_to_grams(self):
+  def test_tokens_to_grams(self):
     tests = [
         (["That", "cone"], [(0, 1), (1, 2)]),
         (["That", "cone", "##s"], [(0, 1), (1, 3)]),
@@ -52,7 +52,7 @@ def test_wordpieces_to_grams(self):
         (["[CLS]", "Up", "##dog", "[SEP]", "Down"], [(1, 3), (4, 5)]),
     ]
     for inp, expected in tests:
-      output = cpd._wordpieces_to_grams(inp)
+      output = cpd._tokens_to_grams(inp)
       self.assertEqual(expected, output)
 
   def test_window(self):
@@ -81,8 +81,8 @@ def test_create_masked_lm_predictions(self):
             rng=rng,
             do_whole_word_mask=False,
             max_ngram_size=None))
-    self.assertEqual(len(masked_positions), 3)
-    self.assertEqual(len(masked_labels), 3)
+    self.assertLen(masked_positions, 3)
+    self.assertLen(masked_labels, 3)
     self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
 
   def test_create_masked_lm_predictions_whole_word(self):
@@ -100,8 +100,8 @@ def test_create_masked_lm_predictions_whole_word(self):
             max_ngram_size=None))
     # since we can't get exactly three tokens without breaking a word we
     # only take two.
-    self.assertEqual(len(masked_positions), 2)
-    self.assertEqual(len(masked_labels), 2)
+    self.assertLen(masked_positions, 2)
+    self.assertLen(masked_labels, 2)
     self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
     # ensure that we took an entire word.
     self.assertIn(masked_labels, [["a", "##a"], ["b", "##b"], ["c", "##c"]])
@@ -119,8 +119,8 @@ def test_create_masked_lm_predictions_ngram(self):
             rng=rng,
             do_whole_word_mask=True,
             max_ngram_size=3))
-    self.assertEqual(len(masked_positions), 76)
-    self.assertEqual(len(masked_labels), 76)
+    self.assertLen(masked_positions, 76)
+    self.assertLen(masked_labels, 76)
     self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
 
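
The renamed _tokens_to_grams keeps the whole-word grouping behavior exercised by these tests: contiguous pieces of one word ("cone", then "##s") collapse into a single half-open span, and "[CLS]"/"[SEP]" are excluded. A rough, self-contained sketch of that grouping, using a hypothetical name and returning plain (start, end) tuples instead of the module's _Gram objects:

def tokens_to_spans(tokens):
  """Illustration mirroring cpd._tokens_to_grams; not the module's code."""
  spans = []
  start = None
  for i, token in enumerate(tokens + ["[SEP]"]):  # sentinel flushes the last word
    if start is not None and token.startswith("##"):
      continue  # continuation piece: still inside the current word
    if start is not None:
      spans.append((start, i))  # close the previous word
    start = None if token in ("[CLS]", "[SEP]") else i
  return spans

# Mirrors the cases in test_tokens_to_grams above.
assert tokens_to_spans(["That", "cone"]) == [(0, 1), (1, 2)]
assert tokens_to_spans(["That", "cone", "##s"]) == [(0, 1), (1, 3)]
assert tokens_to_spans(["[CLS]", "Up", "##dog", "[SEP]", "Down"]) == [(1, 3), (4, 5)]

This sketch only covers the WordPiece "##" continuation convention that the tests exercise; how sentence pieces are grouped is not shown in this diff.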
