 import random

 # Import libraries
+
 from absl import app
 from absl import flags
 from absl import logging
@@ ... @@
     "output_file", None,
     "Output TF example file (or comma-separated list of files).")

-flags.DEFINE_string("vocab_file", None,
-                    "The vocabulary file that the BERT model was trained on.")
+flags.DEFINE_enum(
+    "tokenization",
+    "WordPiece",
+    ["WordPiece", "SentencePiece"],
+    "Specifies the tokenizer implementation, i.e., whether to use WordPiece "
+    "or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, "
+    "while ALBERT uses SentencePiece tokenizer.",
+)
+
+flags.DEFINE_string(
+    "vocab_file",
+    None,
+    "For WordPiece tokenization, the vocabulary file of the tokenizer.",
+)
+
+flags.DEFINE_string(
+    "sp_model_file",
+    "",
+    "For SentencePiece tokenization, the path to the model of the tokenizer.",
+)

 flags.DEFINE_bool(
     "do_lower_case", True,
     "Whether to lower case the input text. Should be True for uncased "
     "models and False for cased models.")

 flags.DEFINE_bool(
-    "do_whole_word_mask", False,
-    "Whether to use whole word masking rather than per-WordPiece masking.")
+    "do_whole_word_mask",
+    False,
+    "Whether to use whole word masking rather than per-token masking.",
+)

 flags.DEFINE_integer(
     "max_ngram_size", None,
@@ -198,16 +219,19 @@ def create_float_feature(values):
   return feature


-def create_training_instances(input_files,
-                              tokenizer,
-                              max_seq_length,
-                              dupe_factor,
-                              short_seq_prob,
-                              masked_lm_prob,
-                              max_predictions_per_seq,
-                              rng,
-                              do_whole_word_mask=False,
-                              max_ngram_size=None):
+def create_training_instances(
+    input_files,
+    tokenizer,
+    processor_text_fn,
+    max_seq_length,
+    dupe_factor,
+    short_seq_prob,
+    masked_lm_prob,
+    max_predictions_per_seq,
+    rng,
+    do_whole_word_mask=False,
+    max_ngram_size=None,
+):
   """Create `TrainingInstance`s from raw text."""
   all_documents = [[]]

@@ -219,11 +243,8 @@ def create_training_instances(input_files,
   # that the "next sentence prediction" task doesn't span between documents.
   for input_file in input_files:
     with tf.io.gfile.GFile(input_file, "rb") as reader:
-      while True:
-        line = tokenization.convert_to_unicode(reader.readline())
-        if not line:
-          break
-        line = line.strip()
+      for line in reader:
+        line = processor_text_fn(line)

         # Empty lines are used as document delimiters
         if not line:
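Because the input file is opened with mode "rb", each line handed to processor_text_fn is a raw bytes object, newline included; the callback is expected to return cleaned unicode text (the get_processor_text_fn helper added further down in this diff does exactly that). A minimal stand-in, hypothetical and for illustration only:

    def simple_processor_text_fn(line):
      # Decode the raw bytes from the "rb" file handle and drop surrounding
      # whitespace, mirroring convert_to_unicode(...) followed by strip().
      return line.decode("utf-8").strip()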
@@ -535,15 +556,16 @@ def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
   return output_ngrams


-def _wordpieces_to_grams(tokens):
+def _tokens_to_grams(tokens):
   """Reconstitue grams (words) from `tokens`.

   E.g.,
      tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
      grams:  [ [1,2), [2, 4), [4,5) , [5, 6)]

   Args:
-    tokens: list of wordpieces
+    tokens: list of tokens (word pieces or sentence pieces).
+
   Returns:
     List of _Grams representing spans of whole words
     (without "[CLS]" and "[SEP]").
@@ -570,7 +592,7 @@ def create_masked_lm_predictions(tokens, masked_lm_prob,
                                  max_ngram_size=None):
   """Creates the predictions for the masked LM objective."""
   if do_whole_word_mask:
-    grams = _wordpieces_to_grams(tokens)
+    grams = _tokens_to_grams(tokens)
   else:
     # Here we consider each token to be a word to allow for sub-word masking.
     if max_ngram_size:
@@ -633,9 +655,28 @@ def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
       trunc_tokens.pop()


+def get_processor_text_fn(is_sentence_piece, do_lower_case):
+  def processor_text_fn(text):
+    text = tokenization.convert_to_unicode(text)
+    if is_sentence_piece:
+      # Additional preprocessing specific to the SentencePiece tokenizer.
+      text = tokenization.preprocess_text(text, lower=do_lower_case)
+
+    return text.strip()
+
+  return processor_text_fn
+
+
 def main(_):
-  tokenizer = tokenization.FullTokenizer(
-      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+  if FLAGS.tokenization == "WordPiece":
+    tokenizer = tokenization.FullTokenizer(
+        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
+    )
+    processor_text_fn = get_processor_text_fn(False, FLAGS.do_lower_case)
+  else:
+    assert FLAGS.tokenization == "SentencePiece"
+    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+    processor_text_fn = get_processor_text_fn(True, FLAGS.do_lower_case)

   input_files = []
   for input_pattern in FLAGS.input_file.split(","):
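A design note on the branch above: with WordPiece, lowercasing is left to FullTokenizer via do_lower_case, so the line callback only decodes and strips; with SentencePiece, the tokenizer is built from the model file alone, so lowercasing has to happen during line preprocessing via preprocess_text. A usage sketch, with illustrative values and assuming preprocess_text lowercases when lower=True:

    # Usage sketch of the helper defined in this diff (illustrative only).
    wp_fn = get_processor_text_fn(is_sentence_piece=False, do_lower_case=True)
    sp_fn = get_processor_text_fn(is_sentence_piece=True, do_lower_case=True)
    print(wp_fn(b"A Quick Brown Fox.\n"))  # "A Quick Brown Fox."; FullTokenizer lowercases later.
    print(sp_fn(b"A Quick Brown Fox.\n"))  # "a quick brown fox.", assuming preprocess_text lowercases.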
@@ -647,9 +688,18 @@ def main(_):
 
   rng = random.Random(FLAGS.random_seed)
   instances = create_training_instances(
-      input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
-      FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
-      rng, FLAGS.do_whole_word_mask, FLAGS.max_ngram_size)
+      input_files,
+      tokenizer,
+      processor_text_fn,
+      FLAGS.max_seq_length,
+      FLAGS.dupe_factor,
+      FLAGS.short_seq_prob,
+      FLAGS.masked_lm_prob,
+      FLAGS.max_predictions_per_seq,
+      rng,
+      FLAGS.do_whole_word_mask,
+      FLAGS.max_ngram_size,
+  )

   output_files = FLAGS.output_file.split(",")
   logging.info("*** Writing to output files ***")
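Since processor_text_fn is now the third positional parameter of create_training_instances, callers outside this script need the same update. A hypothetical caller sketch, with placeholder values and a tokenizer assumed to be built as in main():

    # Hypothetical external caller, updated for the new signature.
    rng = random.Random(12345)
    instances = create_training_instances(
        ["/tmp/corpus.txt"],                 # input_files (placeholder path)
        tokenizer,                           # built as in main() above
        get_processor_text_fn(False, True),  # WordPiece-style preprocessing
        128,                                 # max_seq_length
        10,                                  # dupe_factor
        0.1,                                 # short_seq_prob
        0.15,                                # masked_lm_prob
        20,                                  # max_predictions_per_seq
        rng,
    )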