# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based Scaffold TransformerEncoder block for vision models.

This implementation is subclassed from the NLP TransformerScaffold to support
a customized `attention_layer` and `feedforward_layer`. In addition, it has a
few features that better support vision use cases:
1. `stochastic_depth_drop_rate` to suppress model overfitting.
2. `return_attention_scores` to optionally return the attention scores along
   with the block output.
3. `ffn_has_residual_connection` to state explicitly whether the feedforward
   network applies its own residual connection, avoiding ambiguity.
"""
from typing import List, Optional, Tuple, Union

import gin
import tensorflow as tf

from official.nlp import modeling
from official.vision.modeling.layers.nn_layers import StochasticDepth


@tf.keras.utils.register_keras_serializable(package="Vision")
@gin.configurable
class TransformerScaffold(modeling.layers.TransformerScaffold):
  """TransformerScaffold layer for vision applications.

  This layer is a subclass of the NLP TransformerScaffold.

  Attributes:
    stochastic_depth_drop_rate: Drop rate for stochastic depth on the residual
      connections.
    return_attention_scores: Whether to also return the attention scores from
      the attention layer.
    ffn_has_residual_connection: Whether the feedforward network has an
      internal residual connection and layer norm. If False, the residual
      connection and the layer norm op are applied inside TransformerScaffold.
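
  Example (a minimal usage sketch; the base-class constructor arguments shown
  here, such as `num_attention_heads`, `inner_dim` and `inner_activation`,
  follow the NLP TransformerScaffold signature and may differ across
  `official.nlp` versions):

    # Base-class arguments are illustrative; see the NLP TransformerScaffold
    # for the exact signature and defaults.
    block = TransformerScaffold(
        num_attention_heads=8,
        inner_dim=2048,
        inner_activation="gelu",
        stochastic_depth_drop_rate=0.1,
        return_attention_scores=True)
    features = tf.ones((2, 196, 512))  # (batch_size, num_tokens, hidden_size).
    outputs, scores = block(features)  # Scores returned because requested.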
  """

  def __init__(self,
               *args,
               stochastic_depth_drop_rate: float = 0.0,
               return_attention_scores: bool = False,
               ffn_has_residual_connection: bool = False,
               **kwargs):
    """Initializes TransformerScaffold."""
    super().__init__(*args, **kwargs)
    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
    self._return_attention_scores = return_attention_scores
    self._ffn_has_residual_connection = ffn_has_residual_connection

  def build(self, input_shape: Union[tf.TensorShape, List[int]]):
    if self._stochastic_depth_drop_rate:
      self._stochastic_depth = StochasticDepth(
          self._stochastic_depth_drop_rate)
    else:
      self._stochastic_depth = lambda x, *args, **kwargs: tf.identity(x)

    super().build(input_shape)

  def get_config(self):
    config = {"stochastic_depth_drop_rate": self._stochastic_depth_drop_rate,
              "return_attention_scores": self._return_attention_scores,
              "ffn_has_residual_connection": self._ffn_has_residual_connection}
    base_config = super().get_config()
    base_config.update(config)
    return base_config

  def call(
      self,
      inputs: tf.Tensor,
      training: Optional[bool] = None
  ) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
    """Transformer self-attention encoder block call.

    Accepts a single input tensor, an `[input_tensor, attention_mask]` pair,
    or an `[input_tensor, key_value, attention_mask]` triple.
    """
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 2:
        input_tensor, attention_mask = inputs
        key_value = None
      elif len(inputs) == 3:
        input_tensor, key_value, attention_mask = inputs
      else:
        raise ValueError("Unexpected inputs to %s with length %d." %
                         (self.__class__, len(inputs)))
    else:
      input_tensor, key_value, attention_mask = (inputs, None, None)

    if key_value is None:
      key_value = input_tensor

    if self._norm_first:
      source_tensor = input_tensor
      input_tensor = self._attention_layer_norm(input_tensor, training=training)

    attention_layer_output = self._attention_layer(
        query=input_tensor,
        value=key_value,
        attention_mask=attention_mask,
        training=training,
        return_attention_scores=self._return_attention_scores)
    if isinstance(attention_layer_output, tuple):
      # `attention_layer_output` contains two tensors when
      # `return_attention_scores` is True.
      attention_output, attention_scores = attention_layer_output
    else:
      attention_output = attention_layer_output
    attention_output = self._attention_dropout(attention_output,
                                               training=training)

    if self._norm_first:
      source_attention_output = source_tensor + self._stochastic_depth(
          attention_output, training=training)
      attention_output = self._output_layer_norm(source_attention_output,
                                                 training=training)
    else:
      attention_output = self._attention_layer_norm(
          input_tensor +
          self._stochastic_depth(attention_output, training=training),
          training=training)

    if self._feedforward_block is None:
      intermediate_output = self._intermediate_dense(attention_output)
      intermediate_output = self._intermediate_activation_layer(
          intermediate_output)
      layer_output = self._output_dense(intermediate_output, training=training)
      layer_output = self._output_dropout(layer_output, training=training)
    else:
      layer_output = self._feedforward_block(attention_output,
                                             training=training)

    # During mixed precision training, the layer norm output is always fp32
    # for now. Cast to fp32 for the subsequent add.
    layer_output = tf.cast(layer_output, tf.float32)

    if self._norm_first:
      if self._ffn_has_residual_connection:
        raise ValueError(
            "In the case of `norm_first`, the residual connection should be "
            "done in the TransformerScaffold call function, not in the FFN's "
            "call function.")
      output = source_attention_output + self._stochastic_depth(
          layer_output, training=training)
    else:
      if self._ffn_has_residual_connection:
        output = self._stochastic_depth(layer_output, training=training)
      else:
        output = self._output_layer_norm(
            attention_output + self._stochastic_depth(
                layer_output, training=training))

    if self._return_attention_scores:
      return output, attention_scores
    else:
      return output