
Commit 782e39e

saberkun authored and tensorflower-gardener committed
Release the fast tokenizer BERT wrapper on GitHub: https://arxiv.org/abs/2012.15524
PiperOrigin-RevId: 416671104
1 parent c8bb9aa · commit 782e39e

File tree: 3 files changed, +246 −1 lines changed


official/nlp/modeling/layers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -46,6 +46,7 @@
 from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
 from official.nlp.modeling.layers.text_layers import BertPackInputs
 from official.nlp.modeling.layers.text_layers import BertTokenizer
+from official.nlp.modeling.layers.text_layers import FastWordpieceBertTokenizer
 from official.nlp.modeling.layers.text_layers import SentencepieceTokenizer
 from official.nlp.modeling.layers.tn_transformer_expand_condense import TNTransformerExpandCondense
 from official.nlp.modeling.layers.transformer import *
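
With this export, the new layer becomes available from the public layers package. Below is a minimal usage sketch, not part of the commit; it assumes tf-models-official and tensorflow-text are installed, and "./vocab.txt" is a hypothetical wordpiece vocabulary file (one token per line).

# Hypothetical usage sketch of the newly exported symbol.
import tensorflow as tf
from official.nlp.modeling import layers

# "./vocab.txt" is an assumed local wordpiece vocab file, one token per line.
tokenizer = layers.FastWordpieceBertTokenizer(
    vocab_file="./vocab.txt", lower_case=True)
# Returns a RaggedTensor of shape [batch_size, (words), (pieces_per_word)].
token_ids = tokenizer(tf.constant(["hello world"]))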

official/nlp/modeling/layers/text_layers.py

Lines changed: 141 additions & 1 deletion

@@ -13,18 +13,22 @@
 # limitations under the License.

 """Keras Layers for BERT-specific preprocessing."""
+# pylint: disable=g-import-not-at-top
 from typing import Any, Dict, List, Optional, Union

 from absl import logging
 import tensorflow as tf

 try:
-  import tensorflow_text as text  # pylint: disable=g-import-not-at-top
+  import tensorflow_text as text
+  from tensorflow_text.python.ops import bert_tokenizer
 except ImportError:
   text = None
+  bert_tokenizer = None
 except tf.errors.NotFoundError as e:
   logging.warn("Encountered error when importing tensorflow_text: %s", e)
   text = None
+  bert_tokenizer = None


 def _check_if_tf_text_installed():

@@ -587,3 +591,139 @@ def _reshape(t):
     return dict(input_word_ids=_reshape(input_word_ids),
                 input_mask=_reshape(input_mask),
                 input_type_ids=_reshape(input_type_ids))
+
+
+class FastWordpieceBertTokenizer(tf.keras.layers.Layer):
+  """A BERT tokenizer Keras layer using text.FastWordpieceTokenizer.
+
+  See details: "Fast WordPiece Tokenization" (https://arxiv.org/abs/2012.15524)
+  """
+
+  def __init__(self,
+               *,
+               vocab_file: str,
+               lower_case: bool,
+               tokenize_with_offsets: bool = False,
+               **kwargs):
+    """Initializes a FastWordpieceBertTokenizer layer.
+
+    Args:
+      vocab_file: A Python string with the path of the vocabulary file. This is
+        a text file with newline-separated wordpiece tokens. This layer loads
+        a list of tokens from it to create text.FastWordpieceTokenizer.
+      lower_case: A Python boolean forwarded to text.BasicTokenizer. If true,
+        input text is converted to lower case (where applicable) before
+        tokenization. This must be set to match the way in which the vocab_file
+        was created.
+      tokenize_with_offsets: A Python boolean. If true, this layer calls
+        FastWordpieceTokenizer.tokenize_with_offsets() instead of plain
+        .tokenize() and outputs a triple of (tokens, start_offsets,
+        limit_offsets) instead of just tokens.
+      **kwargs: standard arguments to Layer().
+    """
+    super().__init__(**kwargs)
+    logging.info("Initialize a FastWordpieceBertTokenizer.")
+    self.tokenize_with_offsets = tokenize_with_offsets
+    self._basic_tokenizer = bert_tokenizer.BasicTokenizer(lower_case=lower_case)
+
+    # Read the vocab file into a list of tokens to create `fast_wp_tokenizer`.
+    self._vocab = [line.rstrip() for line in tf.io.gfile.GFile(vocab_file)]
+    self._fast_wp_tokenizer = text.FastWordpieceTokenizer(
+        vocab=self._vocab, token_out_type=tf.int32, no_pretokenization=True)
+    self._special_tokens_dict = self._create_special_tokens_dict()
+
+  @property
+  def vocab_size(self):
+    return len(self._vocab)
+
+  def get_config(self):
+    # Skipped in tf.saved_model.save(); fails if called directly.
+    # We cannot just put the original, user-supplied vocab file name into
+    # the config, because the path has to change as the SavedModel is copied
+    # around.
+    raise NotImplementedError("Not implemented yet.")
+
+  def get_special_tokens_dict(self):
+    """Returns dict of token ids, keyed by standard names for their purpose.
+
+    Returns:
+      A dict from Python strings to Python integers. Each key is a standard
+      name for a special token describing its use. (For example, "padding_id"
+      is what BERT traditionally calls "[PAD]" but others may call "<pad>".)
+      The corresponding value is the integer token id. If a special token
+      is not found, its entry is omitted from the dict.
+
+      The supported keys and tokens are:
+        * start_of_sequence_id: looked up from "[CLS]"
+        * end_of_segment_id: looked up from "[SEP]"
+        * padding_id: looked up from "[PAD]"
+        * mask_id: looked up from "[MASK]"
+        * vocab_size: one past the largest token id used
+    """
+    return self._special_tokens_dict
+
+  def _create_special_tokens_dict(self):
+    """Creates dict of token ids, keyed by standard names for their purpose."""
+    special_tokens = {"vocab_size": self.vocab_size}
+
+    def add_special_token(key, token):
+      try:
+        token_id = self._vocab.index(token)
+        special_tokens[key] = token_id
+      except ValueError:
+        # Similar to nlp.modeling.layers.BertTokenizer, if a special token
+        # is not found, its entry is omitted from the dict.
+        logging.warning("Could not find %s as token \"%s\" in vocab file", key,
+                        token)
+
+    add_special_token("start_of_sequence_id", "[CLS]")
+    add_special_token("end_of_segment_id", "[SEP]")
+    add_special_token("padding_id", "[PAD]")
+    add_special_token("mask_id", "[MASK]")
+    return special_tokens
+
+  def _tokenize_with_offsets(self, text_input: tf.Tensor):
+    tokens, begin, _ = self._basic_tokenizer.tokenize_with_offsets(text_input)
+    wordpieces, wp_begin, wp_end = (
+        self._fast_wp_tokenizer.tokenize_with_offsets(tokens))
+    begin_expanded = tf.expand_dims(begin, axis=2)
+    final_begin = begin_expanded + wp_begin
+    final_end = begin_expanded + wp_end
+    return wordpieces, final_begin, final_end
+
+  def _tokenize(self, text_input: tf.Tensor):
+    tokens = self._basic_tokenizer.tokenize(text_input)
+    return self._fast_wp_tokenizer.tokenize(tokens)
+
+  def call(self, inputs: tf.Tensor):
+    """Calls the basic and fast wordpiece tokenizers on inputs.
+
+    Args:
+      inputs: A string Tensor of shape [batch_size].
+
+    Returns:
+      One or three RaggedTensors, depending on whether tokenize_with_offsets
+      is False or True. These are
+      tokens: A RaggedTensor of shape [batch_size, (words), (pieces_per_word)]
+        and type int32. tokens[i,j,k] contains the k-th wordpiece of the
+        j-th word in the i-th input.
+      start_offsets, limit_offsets: If tokenize_with_offsets is True,
+        RaggedTensors of type int64 with the same indices as tokens.
+        Element [i,j,k] contains the byte offset at the start, or past the
+        end, resp., for the k-th wordpiece of the j-th word in the i-th input.
+    """
+    # Prepare to reshape the result to work around broken shape inference.
+    batch_size = tf.shape(inputs)[0]
+
+    def _reshape(rt):
+      values = rt.values
+      row_splits = rt.row_splits
+      row_splits = tf.reshape(row_splits, [batch_size + 1])
+      return tf.RaggedTensor.from_row_splits(values, row_splits)
+
+    if self.tokenize_with_offsets:
+      tokens, start_offsets, limit_offsets = self._tokenize_with_offsets(inputs)
+      return _reshape(tokens), _reshape(start_offsets), _reshape(limit_offsets)
+    else:
+      tokens = self._tokenize(inputs)
+      return _reshape(tokens)
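
To make the layer's behavior concrete, here is a small standalone sketch modeled on test_uncased in the test file below; the toy vocabulary and the expected values are taken from that test, while the temporary-file handling is an illustrative choice and assumes the same packages as above are installed.

# Sketch: tokenize with offsets using a toy vocabulary (values mirror test_uncased).
import tempfile
import tensorflow as tf
from official.nlp.modeling import layers

vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "xy"]
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
  f.write("\n".join(vocab) + "\n")  # One wordpiece token per line.

tokenizer = layers.FastWordpieceBertTokenizer(
    vocab_file=f.name, lower_case=True, tokenize_with_offsets=True)
tokens, starts, limits = tokenizer(tf.constant(["abc def", "ABC DEF d"]))
# tokens -> [[[6], [4, 5]], [[6], [4, 5], [4]]]; starts/limits are byte offsets
# into each input string, e.g. wordpiece "##ef" in "abc def" spans bytes [5, 7).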

official/nlp/modeling/layers/text_layers_test.py

Lines changed: 104 additions & 0 deletions

@@ -442,5 +442,109 @@ def test_special_tokens_dict(self):
                          [1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]]))


+# This test covers the in-process behavior of the FastWordpieceBertTokenizer layer.
+class FastWordPieceBertTokenizerTest(tf.test.TestCase):
+
+  def _make_vocab_file(self, vocab, filename="vocab.txt"):
+    path = os.path.join(
+        tempfile.mkdtemp(dir=self.get_temp_dir()),  # New subdir each time.
+        filename)
+    with tf.io.gfile.GFile(path, "w") as f:
+      f.write("\n".join(vocab + [""]))
+    return path
+
+  def test_uncased(self):
+    vocab_file = self._make_vocab_file(
+        ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "xy"])
+    bert_tokenize = text_layers.FastWordpieceBertTokenizer(
+        vocab_file=vocab_file, lower_case=True)
+    inputs = tf.constant(["abc def", "ABC DEF d"])
+    token_ids = bert_tokenize(inputs)
+    self.assertAllEqual(token_ids, tf.ragged.constant([[[6], [4, 5]],
+                                                       [[6], [4, 5], [4]]]))
+    bert_tokenize.tokenize_with_offsets = True
+    token_ids_2, start_offsets, limit_offsets = bert_tokenize(inputs)
+    self.assertAllEqual(token_ids, token_ids_2)
+    self.assertAllEqual(start_offsets, tf.ragged.constant([[[0], [4, 5]],
+                                                           [[0], [4, 5], [8]]]))
+    self.assertAllEqual(limit_offsets, tf.ragged.constant([[[3], [5, 7]],
+                                                           [[3], [5, 7], [9]]]))
+    self.assertEqual(bert_tokenize.vocab_size, 8)
+
+  # Repeat the above and test that case matters with lower_case=False.
+  def test_cased(self):
+    vocab_file = self._make_vocab_file(
+        ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "ABC"])
+    bert_tokenize = text_layers.FastWordpieceBertTokenizer(
+        vocab_file=vocab_file, lower_case=False, tokenize_with_offsets=True)
+    inputs = tf.constant(["abc def", "ABC DEF"])
+    token_ids, start_offsets, limit_offsets = bert_tokenize(inputs)
+    self.assertAllEqual(token_ids, tf.ragged.constant([[[6], [4, 5]],
+                                                       [[7], [1]]]))
+    self.assertAllEqual(start_offsets, tf.ragged.constant([[[0], [4, 5]],
+                                                           [[0], [4]]]))
+    self.assertAllEqual(limit_offsets, tf.ragged.constant([[[3], [5, 7]],
+                                                           [[3], [7]]]))
+
+  def test_special_tokens_complete(self):
+    vocab_file = self._make_vocab_file(
+        ["foo", "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "xy"])
+    bert_tokenize = text_layers.FastWordpieceBertTokenizer(
+        vocab_file=vocab_file, lower_case=True)
+    self.assertDictEqual(bert_tokenize.get_special_tokens_dict(),
+                         dict(padding_id=1,
+                              start_of_sequence_id=3,
+                              end_of_segment_id=4,
+                              mask_id=5,
+                              vocab_size=7))
+
+  def test_special_tokens_partial(self):
+    # The [UNK] token is required by the fast wordpiece tokenizer.
+    vocab_file = self._make_vocab_file(
+        ["[PAD]", "[CLS]", "[SEP]", "[UNK]"])
+    bert_tokenize = text_layers.FastWordpieceBertTokenizer(
+        vocab_file=vocab_file, lower_case=True)
+    self.assertDictEqual(bert_tokenize.get_special_tokens_dict(),
+                         dict(padding_id=0,
+                              start_of_sequence_id=1,
+                              end_of_segment_id=2,
+                              vocab_size=4))  # No mask_id.
+
+  def test_special_tokens_in_estimator(self):
+    """Tests getting special tokens without an Eager init context."""
+    vocab_file = self._make_vocab_file(
+        ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "d", "##ef", "abc", "xy"])
+
+    def input_fn():
+      with tf.init_scope():
+        self.assertFalse(tf.executing_eagerly())
+      # Build a preprocessing Model.
+      sentences = tf.keras.layers.Input(shape=[], dtype=tf.string)
+      bert_tokenizer = text_layers.FastWordpieceBertTokenizer(
+          vocab_file=vocab_file, lower_case=True)
+      special_tokens_dict = bert_tokenizer.get_special_tokens_dict()
+      for k, v in special_tokens_dict.items():
+        self.assertIsInstance(v, int, "Unexpected type for {}".format(k))
+      tokens = bert_tokenizer(sentences)
+      packed_inputs = text_layers.BertPackInputs(
+          4, special_tokens_dict=special_tokens_dict)(tokens)
+      preprocessing = tf.keras.Model(sentences, packed_inputs)
+      # Map the dataset.
+      ds = tf.data.Dataset.from_tensors(
+          (tf.constant(["abc", "DEF"]), tf.constant([0, 1])))
+      ds = ds.map(lambda features, labels: (preprocessing(features), labels))
+      return ds
+
+    def model_fn(features, labels, mode):
+      del labels  # Unused.
+      return tf.estimator.EstimatorSpec(mode=mode,
+                                        predictions=features["input_word_ids"])
+
+    estimator = tf.estimator.Estimator(model_fn=model_fn)
+    outputs = list(estimator.predict(input_fn))
+    self.assertAllEqual(outputs, np.array([[2, 6, 3, 0],
+                                           [2, 4, 5, 3]]))
+
+
 if __name__ == "__main__":
   tf.test.main()
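
The estimator test above exercises the same pattern one would use in a plain Keras preprocessing model. Here is a hedged sketch of that pattern; vocab_file, seq_length, and the helper name build_preprocessing are illustrative assumptions, not part of the commit.

# Sketch: end-to-end BERT preprocessing built from the new layer, following
# the pattern in test_special_tokens_in_estimator. vocab_file and seq_length
# are caller-supplied assumptions.
import tensorflow as tf
from official.nlp.modeling import layers

def build_preprocessing(vocab_file, seq_length=128, lower_case=True):
  sentences = tf.keras.layers.Input(shape=[], dtype=tf.string)
  tokenizer = layers.FastWordpieceBertTokenizer(
      vocab_file=vocab_file, lower_case=lower_case)
  tokens = tokenizer(sentences)
  # Pack [CLS]/[SEP]/[PAD] ids around the wordpieces, up to seq_length.
  packed = layers.BertPackInputs(
      seq_length,
      special_tokens_dict=tokenizer.get_special_tokens_dict())(tokens)
  return tf.keras.Model(sentences, packed)

Because the layer's get_config() deliberately raises NotImplementedError, a model like this would be persisted with tf.saved_model.save() rather than Keras config-based serialization.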
