# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer-based text encoder network."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import inspect
import tensorflow as tf

from tensorflow.python.keras.engine import network  # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import layers


@tf.keras.utils.register_keras_serializable(package='Text')
class EncoderScaffold(network.Network):
  """Bi-directional Transformer-based encoder network scaffold.

  This network allows users to flexibly implement an encoder similar to the one
  described in "BERT: Pre-training of Deep Bidirectional Transformers for
  Language Understanding" (https://arxiv.org/abs/1810.04805).

  In this network, users can choose to provide a custom embedding subnetwork
  (which will replace the standard embedding logic) and/or a custom hidden
  layer class (which will replace the Transformer instantiation in the
  encoder). For each of these custom injection points, users can pass either a
  class or a class instance. If a class is passed, that class will be
  instantiated using the 'embedding_cfg' or 'hidden_cfg' argument,
  respectively; if an instance is passed, that instance will be invoked. (In
  the case of hidden_cls, the instance will be invoked
  'num_hidden_instances' times.)

  If the hidden_cls is not overridden, a default transformer layer will be
  instantiated.

  Attributes:
    num_output_classes: The output size of the classification layer.
    classification_layer_initializer: The initializer for the classification
      layer.
    classification_layer_dtype: The dtype for the classification layer.
    embedding_cls: The class or instance to use to embed the input data. This
      class or instance defines the inputs to this encoder. If embedding_cls is
      not set, a default embedding network (from the original BERT paper) will
      be created.
    embedding_cfg: A dict of kwargs to pass to the embedding_cls, if it needs
      to be instantiated. If embedding_cls is not set, a config dict must be
      passed to 'embedding_cfg' with the following values:
      "vocab_size": The size of the token vocabulary.
      "type_vocab_size": The size of the type vocabulary.
      "hidden_size": The hidden size for this encoder.
      "max_seq_length": The maximum sequence length for this encoder.
      "seq_length": The sequence length for this encoder.
      "initializer": The initializer for the embedding portion of this encoder.
      "dropout_rate": The dropout rate to apply before the encoding layers.
      "dtype": (Optional) The dtype of the embedding layers.
    embedding_data: A reference to the embedding weights that will be used to
      train the masked language model, if necessary. This is optional, and only
      needed if (1) you are overriding embedding_cls and (2) are doing standard
      pretraining.
    num_hidden_instances: The number of times to instantiate and/or invoke the
      hidden_cls.
    hidden_cls: The class or instance to encode the input data. If hidden_cls
      is not set, a KerasBERT transformer layer will be used as the encoder
      class.
    hidden_cfg: A dict of kwargs to pass to the hidden_cls, if it needs to be
      instantiated. If hidden_cls is not set, a config dict must be passed to
      'hidden_cfg' with the following values:
      "num_attention_heads": The number of attention heads. The hidden size
        must be divisible by num_attention_heads.
      "intermediate_size": The intermediate size of the transformer.
      "intermediate_activation": The activation to apply in the transformer.
      "dropout_rate": The overall dropout rate for the transformer layers.
      "attention_dropout_rate": The dropout rate for the attention layers.
      "kernel_initializer": The initializer for the transformer layers.
      "dtype": The dtype of the transformer.
  """

  def __init__(
      self,
      num_output_classes,
      classification_layer_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=0.02),
      classification_layer_dtype=tf.float32,
      embedding_cls=None,
      embedding_cfg=None,
      embedding_data=None,
      num_hidden_instances=1,
      hidden_cls=layers.Transformer,
      hidden_cfg=None,
      **kwargs):
    self._self_setattr_tracking = False
    self._hidden_cls = hidden_cls
    self._hidden_cfg = hidden_cfg
    self._num_hidden_instances = num_hidden_instances
    self._num_output_classes = num_output_classes
    self._classification_layer_initializer = classification_layer_initializer
    self._embedding_cls = embedding_cls
    self._embedding_cfg = embedding_cfg
    self._embedding_data = embedding_data
    self._kwargs = kwargs

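    # If a custom embedding subnetwork was provided, use it (instantiating it
    # from 'embedding_cfg' if a class was passed); its inputs become the inputs
    # of the scaffold. Otherwise, build the standard BERT-style inputs and
    # embeddings below.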
    if embedding_cls:
      if inspect.isclass(embedding_cls):
        self._embedding_network = embedding_cls(embedding_cfg)
      else:
        self._embedding_network = embedding_cls
      inputs = self._embedding_network.inputs
      embeddings, mask = self._embedding_network(inputs)
    else:
      self._embedding_network = None
      word_ids = tf.keras.layers.Input(
          shape=(embedding_cfg['seq_length'],),
          dtype=tf.int32,
          name='input_word_ids')
      mask = tf.keras.layers.Input(
          shape=(embedding_cfg['seq_length'],),
          dtype=tf.int32,
          name='input_mask')
      type_ids = tf.keras.layers.Input(
          shape=(embedding_cfg['seq_length'],),
          dtype=tf.int32,
          name='input_type_ids')
      inputs = [word_ids, mask, type_ids]

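      # Default BERT-style embeddings: word, position, and token-type
      # embeddings are summed, then layer-normalized and passed through
      # dropout before entering the transformer stack.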
      self._embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=embedding_cfg['vocab_size'],
          embedding_width=embedding_cfg['hidden_size'],
          initializer=embedding_cfg['initializer'],
          name='word_embeddings')

      word_embeddings = self._embedding_layer(word_ids)

      # Always uses dynamic slicing for simplicity.
      self._position_embedding_layer = layers.PositionEmbedding(
          initializer=embedding_cfg['initializer'],
          use_dynamic_slicing=True,
          max_sequence_length=embedding_cfg['max_seq_length'])
      position_embeddings = self._position_embedding_layer(word_embeddings)

      type_embeddings = (
          layers.OnDeviceEmbedding(
              vocab_size=embedding_cfg['type_vocab_size'],
              embedding_width=embedding_cfg['hidden_size'],
              initializer=embedding_cfg['initializer'],
              use_one_hot=True,
              name='type_embeddings')(type_ids))

      embeddings = tf.keras.layers.Add()(
          [word_embeddings, position_embeddings, type_embeddings])
      embeddings = (
          tf.keras.layers.LayerNormalization(
              name='embeddings/layer_norm',
              axis=-1,
              epsilon=1e-12,
              dtype=tf.float32)(embeddings))
      embeddings = (
          tf.keras.layers.Dropout(
              rate=embedding_cfg['dropout_rate'], dtype=tf.float32)(embeddings))

      if embedding_cfg.get('dtype') == 'float16':
        embeddings = tf.cast(embeddings, tf.float16)

    attention_mask = layers.SelfAttentionMask()([embeddings, mask])
    data = embeddings

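    # Stack the hidden (transformer) layers. If 'hidden_cls' is a class, a new
    # layer is instantiated from 'hidden_cfg' on each iteration; if it is an
    # instance, the same layer object is invoked 'num_hidden_instances' times.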
    for _ in range(num_hidden_instances):
      if inspect.isclass(hidden_cls):
        layer = self._hidden_cls(**hidden_cfg)
      else:
        layer = self._hidden_cls
      data = layer([data, attention_mask])

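    # Pool by taking the hidden state of the first ([CLS]) token and passing
    # it through a dense 'tanh' layer, as in the original BERT pooler.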
    first_token_tensor = (
        tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(data)
    )
    cls_output = tf.keras.layers.Dense(
        units=num_output_classes,
        activation='tanh',
        kernel_initializer=classification_layer_initializer,
        dtype=classification_layer_dtype,
        name='cls_transform')(
            first_token_tensor)

    super(EncoderScaffold, self).__init__(
        inputs=inputs, outputs=[data, cls_output], **kwargs)

  def get_config(self):
    config_dict = {
        'num_hidden_instances':
            self._num_hidden_instances,
        'num_output_classes':
            self._num_output_classes,
        'classification_layer_initializer':
            self._classification_layer_initializer,
        'embedding_cls':
            self._embedding_network,
        'embedding_cfg':
            self._embedding_cfg,
        'hidden_cfg':
            self._hidden_cfg,
    }
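    # A class passed as 'hidden_cls' is serialized by its registered name so
    # that from_config() can recover it; an instance is stored directly.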
    if inspect.isclass(self._hidden_cls):
      config_dict['hidden_cls_string'] = tf.keras.utils.get_registered_name(
          self._hidden_cls)
    else:
      config_dict['hidden_cls'] = self._hidden_cls

    config_dict.update(self._kwargs)
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
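    # Reverse the serialization done in get_config(): look up the registered
    # class for 'hidden_cls_string' and restore it as 'hidden_cls'.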
    if 'hidden_cls_string' in config:
      config['hidden_cls'] = tf.keras.utils.get_registered_object(
          config['hidden_cls_string'], custom_objects=custom_objects)
      del config['hidden_cls_string']
    return cls(**config)

  def get_embedding_table(self):
    if self._embedding_network is None:
      # In this case, we don't have a custom embedding network and can return
      # the standard embedding data.
      return self._embedding_layer.embeddings

    if self._embedding_data is None:
      raise RuntimeError(('The EncoderScaffold %s does not have a reference '
                          'to the embedding data. This is required when you '
                          'pass a custom embedding network to the scaffold. '
                          'It is also possible that you are trying to get '
                          'embedding data from an embedding scaffold with a '
                          'custom embedding network where the scaffold has '
                          'been serialized and deserialized. Unfortunately, '
                          'accessing custom embedding references after '
                          'serialization is not yet supported.') % self.name)
    else:
      return self._embedding_data