Skip to content

Commit 885fda0

Browse files
Internal change
PiperOrigin-RevId: 424187808
1 parent 159697a commit 885fda0

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

official/nlp/modeling/layers/text_layers.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
"""Keras Layers for BERT-specific preprocessing."""
1616
# pylint: disable=g-import-not-at-top
17-
from typing import Any, Dict, List, Optional, Union
17+
from typing import Any, Dict, List, Mapping, Optional, Text, Union
1818

1919
from absl import logging
2020
import tensorflow as tf
@@ -71,8 +71,9 @@ class BertTokenizer(tf.keras.layers.Layer):
7171

7272
def __init__(self, *,
7373
vocab_file: str,
74-
lower_case: bool,
74+
lower_case: Optional[bool] = None,
7575
tokenize_with_offsets: bool = False,
76+
tokenizer_kwargs: Optional[Mapping[Text, Any]] = None,
7677
**kwargs):
7778
"""Initialize a `BertTokenizer` layer.
7879
@@ -81,15 +82,18 @@ def __init__(self, *,
8182
This is a text file with newline-separated wordpiece tokens.
8283
This layer initializes a lookup table from it that gets used with
8384
`text.BertTokenizer`.
84-
lower_case: A Python boolean forwarded to `text.BertTokenizer`.
85+
lower_case: Optional boolean forwarded to `text.BertTokenizer`.
8586
If true, input text is converted to lower case (where applicable)
8687
before tokenization. This must be set to match the way in which
87-
the `vocab_file` was created.
88+
the `vocab_file` was created. If passed, this overrides whatever value
89+
may have been passed in `tokenizer_kwargs`.
8890
tokenize_with_offsets: A Python boolean. If true, this layer calls
8991
`text.BertTokenizer.tokenize_with_offsets()` instead of plain
9092
`text.BertTokenizer.tokenize()` and outputs a triple of
9193
`(tokens, start_offsets, limit_offsets)`
9294
insead of just tokens.
95+
tokenizer_kwargs: Optional mapping with keyword arguments to forward to
96+
`text.BertTokenizer`'s constructor.
9397
**kwargs: Standard arguments to `Layer()`.
9498
9599
Raises:
@@ -111,8 +115,11 @@ def __init__(self, *,
111115
self._special_tokens_dict = self._create_special_tokens_dict(
112116
self._vocab_table, vocab_file)
113117
super().__init__(**kwargs)
114-
self._bert_tokenizer = text.BertTokenizer(
115-
self._vocab_table, lower_case=lower_case)
118+
tokenizer_kwargs = dict(tokenizer_kwargs or {})
119+
if lower_case is not None:
120+
tokenizer_kwargs["lower_case"] = lower_case
121+
self._bert_tokenizer = text.BertTokenizer(self._vocab_table,
122+
**tokenizer_kwargs)
116123

117124
@property
118125
def vocab_size(self):

0 commit comments

Comments
 (0)