13 | 13 | # limitations under the License.
14 | 14 |
15 | 15 | """Keras Layers for BERT-specific preprocessing."""
| 16 | +# pylint: disable=g-import-not-at-top |
16 | 17 | from typing import Any, Dict, List, Optional, Union
17 | 18 |
18 | 19 | from absl import logging
19 | 20 | import tensorflow as tf
20 | 21 |
21 | 22 | try:
22 | | - import tensorflow_text as text # pylint: disable=g-import-not-at-top |
| 23 | + import tensorflow_text as text |
| 24 | + from tensorflow_text.python.ops import bert_tokenizer |
23 | 25 | except ImportError:
24 | 26 | text = None
| 27 | + bert_tokenizer = None |
25 | 28 | except tf.errors.NotFoundError as e:
26 | 29 | logging.warn("Encountered error when importing tensorflow_text: %s", e)
27 | 30 | text = None
| 31 | + bert_tokenizer = None |
28 | 32 |
29 | 33 |
30 | 34 | def _check_if_tf_text_installed():
@@ -587,3 +591,139 @@ def _reshape(t):
587 | 591 | return dict(input_word_ids=_reshape(input_word_ids),
588 | 592 | input_mask=_reshape(input_mask),
589 | 593 | input_type_ids=_reshape(input_type_ids))
| 594 | + |
| 595 | + |
| 596 | +class FastWordpieceBertTokenizer(tf.keras.layers.Layer): |
| 597 | + """A BERT tokenizer Keras layer using text.FastWordpieceTokenizer. |
| 598 | + |
| 599 | + See details: "Fast WordPiece Tokenization" (https://arxiv.org/abs/2012.15524) |
| 600 | + """ |
| 601 | + |
| 602 | + def __init__(self, |
| 603 | + *, |
| 604 | + vocab_file: str, |
| 605 | + lower_case: bool, |
| 606 | + tokenize_with_offsets: bool = False, |
| 607 | + **kwargs): |
| 608 | + """Initializes a FastWordpieceBertTokenizer layer. |
| 609 | + |
| 610 | + Args: |
| 611 | + vocab_file: A Python string with the path of the vocabulary file. This is |
| 612 | + a text file with newline-separated wordpiece tokens. This layer loads |
| 613 | + a list of tokens from it to create text.FastWordpieceTokenizer. |
| 614 | + lower_case: A Python boolean forwarded to text.BasicTokenizer. If true, |
| 615 | + input text is converted to lower case (where applicable) before |
| 616 | + tokenization. This must be set to match the way in which the vocab_file |
| 617 | + was created. |
| 618 | + tokenize_with_offsets: A Python boolean. If true, this layer calls |
| 619 | + FastWordpieceTokenizer.tokenize_with_offsets() instead of plain |
| 620 | + .tokenize() and outputs a triple of (tokens, start_offsets, |
| 621 | + limit_offsets) instead of just tokens. |
| 622 | + **kwargs: standard arguments to Layer(). |
| 623 | + """ |
| 624 | + super().__init__(**kwargs) |
| 625 | + logging.info("Initialize a FastWordpieceBertTokenizer.") |
| 626 | + self.tokenize_with_offsets = tokenize_with_offsets |
| 627 | + self._basic_tokenizer = bert_tokenizer.BasicTokenizer(lower_case=lower_case) |
| 628 | + |
| 629 | + # Read the vocab file into a list of tokens to create `fast_wp_tokenizer`. |
| 630 | + self._vocab = [line.rstrip() for line in tf.io.gfile.GFile(vocab_file)] |
| 631 | + self._fast_wp_tokenizer = text.FastWordpieceTokenizer( |
| 632 | + vocab=self._vocab, token_out_type=tf.int32, no_pretokenization=True) |
| 633 | + self._special_tokens_dict = self._create_special_tokens_dict() |
| 634 | + |
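
For orientation, here is a minimal usage sketch of the constructor added above. The vocabulary path is hypothetical, and tensorflow_text must be installed for the import block at the top of this file to succeed.

# Sketch only: "/tmp/bert_vocab.txt" is a hypothetical path to a
# newline-separated wordpiece vocabulary created with lower_case=True.
tokenizer = FastWordpieceBertTokenizer(
    vocab_file="/tmp/bert_vocab.txt",
    lower_case=True,
    tokenize_with_offsets=False)
print(tokenizer.vocab_size)  # Number of lines (wordpieces) in the vocab file.
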
| 635 | + @property |
| 636 | + def vocab_size(self): |
| 637 | + return len(self._vocab) |
| 638 | + |
| 639 | + def get_config(self): |
| 640 | + # Skip in tf.saved_model.save(); fail if called directly. |
| 641 | + # We cannot just put the original, user-supplied vocab file name into |
| 642 | + # the config, because the path has to change as the SavedModel is copied |
| 643 | + # around. |
| 644 | + raise NotImplementedError("Not implemented yet.") |
| 645 | + |
| 646 | + def get_special_tokens_dict(self): |
| 647 | + """Returns dict of token ids, keyed by standard names for their purpose. |
| 648 | + |
| 649 | + Returns: |
| 650 | + A dict from Python strings to Python integers. Each key is a standard |
| 651 | + name for a special token describing its use. (For example, "padding_id" |
| 652 | + is what BERT traditionally calls "[PAD]" but others may call "<pad>".) |
| 653 | + The corresponding value is the integer token id. If a special token |
| 654 | + is not found, its entry is omitted from the dict. |
| 655 | + |
| 656 | + The supported keys and tokens are: |
| 657 | + * start_of_sequence_id: looked up from "[CLS]" |
| 658 | + * end_of_segment_id: looked up from "[SEP]" |
| 659 | + * padding_id: looked up from "[PAD]" |
| 660 | + * mask_id: looked up from "[MASK]" |
| 661 | + * vocab_size: one past the largest token id used |
| 662 | + """ |
| 663 | + return self._special_tokens_dict |
| 664 | + |
| 665 | + def _create_special_tokens_dict(self): |
| 666 | + """Creates dict of token ids, keyed by standard names for their purpose.""" |
| 667 | + special_tokens = {"vocab_size": self.vocab_size} |
| 668 | + |
| 669 | + def add_special_token(key, token): |
| 670 | + try: |
| 671 | + token_id = self._vocab.index(token) |
| 672 | + special_tokens[key] = token_id |
| 673 | + except ValueError: |
| 674 | + # As in nlp.modeling.layers.BertTokenizer, if a special token |
| 675 | + # is not found, its entry is omitted from the dict. |
| 676 | + logging.warning("Could not find %s as token \"%s\" in vocab file", key, |
| 677 | + token) |
| 678 | + |
| 679 | + add_special_token("start_of_sequence_id", "[CLS]") |
| 680 | + add_special_token("end_of_segment_id", "[SEP]") |
| 681 | + add_special_token("padding_id", "[PAD]") |
| 682 | + add_special_token("mask_id", "[MASK]") |
| 683 | + return special_tokens |
| 684 | + |
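
To make the lookup above concrete, the following self-contained sketch replicates the logic of _create_special_tokens_dict() over a small, hypothetical vocabulary; real token ids depend entirely on the order of tokens in vocab_file.

# Hypothetical vocabulary; ids are just list positions.
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "hello", "world"]
special_tokens = {"vocab_size": len(vocab)}
for key, token in [("start_of_sequence_id", "[CLS]"),
                   ("end_of_segment_id", "[SEP]"),
                   ("padding_id", "[PAD]"),
                   ("mask_id", "[MASK]")]:
  if token in vocab:  # Missing tokens are simply omitted, as in the layer.
    special_tokens[key] = vocab.index(token)
print(special_tokens)
# {'vocab_size': 7, 'start_of_sequence_id': 2, 'end_of_segment_id': 3,
#  'padding_id': 0, 'mask_id': 4}
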
| 685 | + def _tokenize_with_offsets(self, text_input: tf.Tensor): |
| 686 | + tokens, begin, _ = self._basic_tokenizer.tokenize_with_offsets(text_input) |
| 687 | + wordpieces, wp_begin, wp_end = ( |
| 688 | + self._fast_wp_tokenizer.tokenize_with_offsets(tokens)) |
| 689 | + begin_expanded = tf.expand_dims(begin, axis=2) |
| 690 | + final_begin = begin_expanded + wp_begin |
| 691 | + final_end = begin_expanded + wp_end |
| 692 | + return wordpieces, final_begin, final_end |
| 693 | + |
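
The offset arithmetic in _tokenize_with_offsets() composes word-level and wordpiece-level positions: each word's byte offset is broadcast-added to the offsets of its pieces. A plain-Python illustration with made-up numbers (not the output of any real tokenizer):

# Suppose one input was split into two words starting at byte offsets 0 and 6,
# and the wordpiece offsets within those words are [0] and [0, 3].
word_begin = [0, 6]
wp_begin_per_word = [[0], [0, 3]]
final_begin = [[b + wp for wp in wps]
               for b, wps in zip(word_begin, wp_begin_per_word)]
print(final_begin)  # [[0], [6, 9]] -- byte offsets into the original string.
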
| 694 | + def _tokenize(self, text_input: tf.Tensor): |
| 695 | + tokens = self._basic_tokenizer.tokenize(text_input) |
| 696 | + return self._fast_wp_tokenizer.tokenize(tokens) |
| 697 | + |
| 698 | + def call(self, inputs: tf.Tensor): |
| 699 | + """Calls the basic tokenizer and text.FastWordpieceTokenizer on inputs. |
| 700 | + |
| 701 | + Args: |
| 702 | + inputs: A string Tensor of shape [batch_size]. |
| 703 | + |
| 704 | + Returns: |
| 705 | + One or three RaggedTensors if tokenize_with_offsets is False or True, |
| 706 | + respectively. These are |
| 707 | + tokens: A RaggedTensor of shape [batch_size, (words), (pieces_per_word)] |
| 708 | + and type int32. tokens[i,j,k] contains the k-th wordpiece of the |
| 709 | + j-th word in the i-th input. |
| 710 | + start_offsets, limit_offsets: If tokenize_with_offsets is True, |
| 711 | + RaggedTensors of type int64 with the same indices as tokens. |
| 712 | + Element [i,j,k] contains the byte offset at the start, or past the |
| 713 | + end, resp., for the k-th wordpiece of the j-th word in the i-th input. |
| 714 | + """ |
| 715 | + # Prepare to reshape the result to work around broken shape inference. |
| 716 | + batch_size = tf.shape(inputs)[0] |
| 717 | + |
| 718 | + def _reshape(rt): |
| 719 | + values = rt.values |
| 720 | + row_splits = rt.row_splits |
| 721 | + row_splits = tf.reshape(row_splits, [batch_size + 1]) |
| 722 | + return tf.RaggedTensor.from_row_splits(values, row_splits) |
| 723 | + |
| 724 | + if self.tokenize_with_offsets: |
| 725 | + tokens, start_offsets, limit_offsets = self._tokenize_with_offsets(inputs) |
| 726 | + return _reshape(tokens), _reshape(start_offsets), _reshape(limit_offsets) |
| 727 | + else: |
| 728 | + tokens = self._tokenize(inputs) |
| 729 | + return _reshape(tokens) |
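
A minimal end-to-end sketch of calling the layer, assuming tensorflow_text is installed and "/tmp/bert_vocab.txt" (a hypothetical path) holds a vocabulary that matches lower_case=True:

import tensorflow as tf

tokenizer = FastWordpieceBertTokenizer(
    vocab_file="/tmp/bert_vocab.txt", lower_case=True)
inputs = tf.constant(["hello world", "tensorflow text"])
tokens = tokenizer(inputs)
# tokens is a tf.RaggedTensor of shape [2, (words), (pieces_per_word)] with
# int32 wordpiece ids; tokens[i, j, k] is the k-th piece of the j-th word of
# the i-th input string.
print(tokens.shape)  # (2, None, None)

With tokenize_with_offsets=True, the same call would instead return the triple (tokens, start_offsets, limit_offsets) described in the docstring above.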