Skip to content

Commit 9609524

Browse files
Add an encoder scaffold.
PiperOrigin-RevId: 286477560
1 parent 745e53a commit 9609524

File tree

6 files changed

+961
-12
lines changed

6 files changed

+961
-12
lines changed

official/nlp/modeling/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@
1818
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
1919
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
2020
from official.nlp.modeling.layers.position_embedding import PositionEmbedding
21+
from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
2122
from official.nlp.modeling.layers.transformer import Transformer
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Keras layer that creates a self-attention mask."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
# from __future__ import google_type_annotations
20+
from __future__ import print_function
21+
22+
import tensorflow as tf
23+
from official.modeling import tf_utils
24+
25+
26+
@tf.keras.utils.register_keras_serializable(package='Text')
class SelfAttentionMask(tf.keras.layers.Layer):
  """Expands a 2D padding mask into a 3D self-attention mask.

  The layer is called with a two-element sequence:

  inputs[0]: from_tensor: 2D or 3D Tensor of shape
    [batch_size, from_seq_length, ...]. Only its shape and dtype are used.
  inputs[1]: to_mask: int32 Tensor of shape [batch_size, to_seq_length].

  Returns:
    Tensor of shape [batch_size, from_seq_length, to_seq_length] with the same
    dtype as from_tensor.
  """

  def call(self, inputs):
    """Builds the attention mask from `[from_tensor, to_mask]`."""
    source, key_mask = inputs[0], inputs[1]
    mask_dtype = source.dtype

    source_dims = tf_utils.get_shape_list(source, expected_rank=[2, 3])
    batch, from_len = source_dims[0], source_dims[1]
    to_len = tf_utils.get_shape_list(key_mask, expected_rank=2)[1]

    # `row_mask` = [batch_size, 1, to_seq_length]
    row_mask = tf.reshape(key_mask, [batch, 1, to_len])
    row_mask = tf.cast(row_mask, dtype=mask_dtype)

    # `from_tensor` is not assumed to be a mask (although it could be). Only
    # attending *to* padding tokens matters, not attending *from* them, so the
    # "from" side is a tensor of all ones.
    #
    # `column_ones` = [batch_size, from_seq_length, 1]
    column_ones = tf.ones(shape=[batch, from_len, 1], dtype=mask_dtype)

    # Broadcasting along the two singleton dimensions yields the full mask.
    return column_ones * row_mask

official/nlp/modeling/networks/albert_transformer_encoder.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from tensorflow.python.keras.engine import network # pylint: disable=g-direct-tensorflow-import
2525
from official.modeling import activations
2626
from official.nlp.modeling import layers
27-
from official.nlp.modeling.networks import transformer_encoder
2827

2928

3029
@tf.keras.utils.register_keras_serializable(package='Text')
@@ -159,7 +158,7 @@ def __init__(self,
159158
embeddings = tf.cast(embeddings, tf.float16)
160159

161160
data = embeddings
162-
attention_mask = transformer_encoder.MakeAttentionMaskLayer()([data, mask])
161+
attention_mask = layers.SelfAttentionMask()([data, mask])
163162
shared_layer = layers.Transformer(
164163
num_attention_heads=num_attention_heads,
165164
intermediate_size=intermediate_size,
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Transformer-based text encoder network."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
# from __future__ import google_type_annotations
20+
from __future__ import print_function
21+
22+
import inspect
23+
import tensorflow as tf
24+
25+
from tensorflow.python.keras.engine import network # pylint: disable=g-direct-tensorflow-import
26+
from official.nlp.modeling import layers
27+
28+
29+
@tf.keras.utils.register_keras_serializable(package='Text')
class EncoderScaffold(network.Network):
  """Bi-directional Transformer-based encoder network scaffold.

  This network allows users to flexibly implement an encoder similar to the one
  described in "BERT: Pre-training of Deep Bidirectional Transformers for
  Language Understanding" (https://arxiv.org/abs/1810.04805).

  In this network, users can choose to provide a custom embedding subnetwork
  (which will replace the standard embedding logic) and/or a custom hidden layer
  class (which will replace the Transformer instantiation in the encoder). For
  each of these custom injection points, users can pass either a class or a
  class instance. If a class is passed, that class will be instantiated using
  the 'embedding_cfg' or 'hidden_cfg' argument, respectively; if an instance
  is passed, that instance will be invoked. (In the case of hidden_cls, the
  instance will be invoked 'num_hidden_instances' times.)

  If the hidden_cls is not overridden, a default transformer layer will be
  instantiated.

  The network outputs a list `[sequence_output, cls_output]`, where
  `cls_output` is a tanh-activated dense projection of the first token of
  `sequence_output`.

  Attributes:
    num_output_classes: The output size of the classification layer.
    classification_layer_initializer: The initializer for the classification
      layer.
    classification_layer_dtype: The dtype for the classification layer.
    embedding_cls: The class or instance to use to embed the input data. This
      class or instance defines the inputs to this encoder. If embedding_cls is
      not set, a default embedding network (from the original BERT paper) will
      be created.
    embedding_cfg: A dict of kwargs to pass to the embedding_cls, if it needs to
      be instantiated. If embedding_cls is not set, a config dict must be
      passed to 'embedding_cfg' with the following values:
      "vocab_size": The size of the token vocabulary.
      "type_vocab_size": The size of the type vocabulary.
      "hidden_size": The hidden size for this encoder.
      "max_seq_length": The maximum sequence length for this encoder.
      "seq_length": The sequence length for this encoder.
      "initializer": The initializer for the embedding portion of this encoder.
      "dropout_rate": The dropout rate to apply before the encoding layers.
      "dtype": (Optional): The dtype of the embedding layers.
    embedding_data: A reference to the embedding weights that will be used to
      train the masked language model, if necessary. This is optional, and only
      needed if (1) you are overriding embedding_cls and (2) are doing standard
      pretraining.
    num_hidden_instances: The number of times to instantiate and/or invoke the
      hidden_cls.
    hidden_cls: The class or instance to encode the input data. If hidden_cls is
      not set, a KerasBERT transformer layer will be used as the encoder class.
    hidden_cfg: A dict of kwargs to pass to the hidden_cls, if it needs to be
      instantiated. If hidden_cls is not set, a config dict must be passed to
      'hidden_cfg' with the following values:
        "num_attention_heads": The number of attention heads. The hidden size
          must be divisible by num_attention_heads.
        "intermediate_size": The intermediate size of the transformer.
        "intermediate_activation": The activation to apply in the transformer.
        "dropout_rate": The overall dropout rate for the transformer layers.
        "attention_dropout_rate": The dropout rate for the attention layers.
        "kernel_initializer": The initializer for the transformer layers.
        "dtype": The dtype of the transformer.

  Raises:
    ValueError: If embedding_cls is not set and no embedding_cfg is given.
  """

  def __init__(
      self,
      num_output_classes,
      classification_layer_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=0.02),
      classification_layer_dtype=tf.float32,
      embedding_cls=None,
      embedding_cfg=None,
      embedding_data=None,
      num_hidden_instances=1,
      hidden_cls=layers.Transformer,
      hidden_cfg=None,
      **kwargs):
    # Attributes are assigned before the base-class constructor runs (it needs
    # the finished graph), so attribute tracking is switched off first.
    self._self_setattr_tracking = False
    self._hidden_cls = hidden_cls
    self._hidden_cfg = hidden_cfg
    self._num_hidden_instances = num_hidden_instances
    self._num_output_classes = num_output_classes
    self._classification_layer_initializer = classification_layer_initializer
    # Kept so that get_config() can round-trip the dtype.
    self._classification_layer_dtype = classification_layer_dtype
    self._embedding_cls = embedding_cls
    self._embedding_cfg = embedding_cfg
    self._embedding_data = embedding_data
    self._kwargs = kwargs

    if embedding_cls:
      if inspect.isclass(embedding_cls):
        # 'embedding_cfg' is a dict of kwargs, so it is unpacked exactly like
        # 'hidden_cfg' is for the hidden layers.
        if embedding_cfg:
          self._embedding_network = embedding_cls(**embedding_cfg)
        else:
          self._embedding_network = embedding_cls()
      else:
        self._embedding_network = embedding_cls
      # A custom embedding network defines the inputs of the whole encoder and
      # must return an (embeddings, mask) pair.
      inputs = self._embedding_network.inputs
      embeddings, mask = self._embedding_network(inputs)
    else:
      if embedding_cfg is None:
        raise ValueError(
            "'embedding_cfg' must be provided when 'embedding_cls' is not set.")
      self._embedding_network = None
      seq_length = embedding_cfg['seq_length']
      word_ids = tf.keras.layers.Input(
          shape=(seq_length,), dtype=tf.int32, name='input_word_ids')
      mask = tf.keras.layers.Input(
          shape=(seq_length,), dtype=tf.int32, name='input_mask')
      type_ids = tf.keras.layers.Input(
          shape=(seq_length,), dtype=tf.int32, name='input_type_ids')
      inputs = [word_ids, mask, type_ids]

      self._embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=embedding_cfg['vocab_size'],
          embedding_width=embedding_cfg['hidden_size'],
          initializer=embedding_cfg['initializer'],
          name='word_embeddings')

      word_embeddings = self._embedding_layer(word_ids)

      # Always uses dynamic slicing for simplicity.
      self._position_embedding_layer = layers.PositionEmbedding(
          initializer=embedding_cfg['initializer'],
          use_dynamic_slicing=True,
          max_sequence_length=embedding_cfg['max_seq_length'])
      position_embeddings = self._position_embedding_layer(word_embeddings)

      type_embeddings = (
          layers.OnDeviceEmbedding(
              vocab_size=embedding_cfg['type_vocab_size'],
              embedding_width=embedding_cfg['hidden_size'],
              initializer=embedding_cfg['initializer'],
              use_one_hot=True,
              name='type_embeddings')(type_ids))

      embeddings = tf.keras.layers.Add()(
          [word_embeddings, position_embeddings, type_embeddings])
      embeddings = (
          tf.keras.layers.LayerNormalization(
              name='embeddings/layer_norm',
              axis=-1,
              epsilon=1e-12,
              dtype=tf.float32)(embeddings))
      embeddings = (
          tf.keras.layers.Dropout(
              rate=embedding_cfg['dropout_rate'], dtype=tf.float32)(embeddings))

      # Layer norm and dropout run in float32; cast afterwards if requested.
      if embedding_cfg.get('dtype') == 'float16':
        embeddings = tf.cast(embeddings, tf.float16)

    attention_mask = layers.SelfAttentionMask()([embeddings, mask])
    data = embeddings

    # A class yields a fresh layer per step; an instance is re-invoked (and so
    # shares its weights) on every step.
    instantiate_hidden = inspect.isclass(hidden_cls)
    hidden_kwargs = hidden_cfg or {}
    for _ in range(num_hidden_instances):
      if instantiate_hidden:
        layer = hidden_cls(**hidden_kwargs)
      else:
        layer = hidden_cls
      data = layer([data, attention_mask])

    # Take the first token of the sequence output and drop the length-1 axis.
    first_token_tensor = (
        tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(data)
    )
    cls_output = tf.keras.layers.Dense(
        units=num_output_classes,
        activation='tanh',
        kernel_initializer=classification_layer_initializer,
        dtype=classification_layer_dtype,
        name='cls_transform')(
            first_token_tensor)

    super(EncoderScaffold, self).__init__(
        inputs=inputs, outputs=[data, cls_output], **kwargs)

  def get_config(self):
    """Returns the constructor arguments needed to rebuild this network.

    A hidden layer class is stored by its registered name under
    'hidden_cls_string'; a hidden layer instance is stored as-is under
    'hidden_cls'. 'embedding_data' is not part of the config.
    """
    config_dict = {
        'num_hidden_instances':
            self._num_hidden_instances,
        'num_output_classes':
            self._num_output_classes,
        'classification_layer_initializer':
            self._classification_layer_initializer,
        # Stored by name so the config stays JSON-friendly.
        'classification_layer_dtype':
            tf.as_dtype(self._classification_layer_dtype).name,
        'embedding_cls':
            self._embedding_network,
        'embedding_cfg':
            self._embedding_cfg,
        'hidden_cfg':
            self._hidden_cfg,
    }
    if inspect.isclass(self._hidden_cls):
      config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name(
          self._hidden_cls)
    else:
      config_dict['hidden_cls'] = self._hidden_cls

    config_dict.update(self._kwargs)
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Builds an EncoderScaffold from the output of get_config().

    Args:
      config: A config dict as produced by get_config(). It is not modified.
      custom_objects: Optional mapping used to resolve 'hidden_cls_string' to
        a class.

    Returns:
      A new EncoderScaffold instance.
    """
    # Work on a shallow copy so the caller's dict is left untouched.
    config = dict(config)
    hidden_cls_string = config.pop('hidden_cls_string', None)
    if hidden_cls_string is not None:
      config['hidden_cls'] = tf.keras.utils.get_registered_object(
          hidden_cls_string, custom_objects=custom_objects)
    return cls(**config)

  def get_embedding_table(self):
    """Returns the embedding weights used by this encoder.

    Returns:
      The word embedding layer's `embeddings` when the default embedding logic
      is used, otherwise the 'embedding_data' passed to the constructor.

    Raises:
      RuntimeError: If a custom embedding network is in use and no
        'embedding_data' reference is available.
    """
    if self._embedding_network is None:
      # In this case, we don't have a custom embedding network and can return
      # the standard embedding data.
      return self._embedding_layer.embeddings

    if self._embedding_data is None:
      raise RuntimeError(('The EncoderScaffold %s does not have a reference '
                          'to the embedding data. This is required when you '
                          'pass a custom embedding network to the scaffold. '
                          'It is also possible that you are trying to get '
                          'embedding data from an embedding scaffold with a '
                          'custom embedding network where the scaffold has '
                          'been serialized and deserialized. Unfortunately, '
                          'accessing custom embedding references after '
                          'serialization is not yet supported.') % self.name)
    return self._embedding_data

0 commit comments

Comments
 (0)