Commit c60499b

Implement the Vision TransformerScaffold, which is a subclass of the NLP TransformerScaffold.
PiperOrigin-RevId: 480969429
1 parent: ad48062

2 files changed: +679, -0 lines

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based Scaffold TransformerEncoder block for vision models.
16+
17+
This implementation is subclassed from NLP TransformerScaffold to support
18+
customized `attention_layer` and `feedforward_layer`. In addition, this
19+
implementation has a few features to better support vision use cases:
20+
1. `stochastic_depth_drop_rate` to supress model overfitting.
21+
2. `return_attention_scores`, optionally returns the attention output.
22+
3. `ffn_has_residual_connection`, clearly define whether feedforward network has
23+
residual connection or not to avoid ambiguity.
24+
"""
from typing import List, Optional, Tuple, Union

import gin
import tensorflow as tf

from official.nlp import modeling
from official.vision.modeling.layers.nn_layers import StochasticDepth


@tf.keras.utils.register_keras_serializable(package="Vision")
@gin.configurable
class TransformerScaffold(modeling.layers.TransformerScaffold):
  """TransformerScaffold layer for vision applications.

  This layer is a subclass of the NLP TransformerScaffold.

  Attributes:
    stochastic_depth_drop_rate: Drop rate for the stochastic-depth residual
      connections.
    return_attention_scores: Whether to also return the attention scores as a
      second output.
    ffn_has_residual_connection: Whether the feedforward network has an
      internal residual connection and layer norm. If False, the residual
      connection and the layer norm op are called inside TransformerScaffold.
  """

  def __init__(self,
               *args,
               stochastic_depth_drop_rate: float = 0.0,
               return_attention_scores: bool = False,
               ffn_has_residual_connection: bool = False,
               **kwargs):
    """Initializes TransformerScaffold."""
    super().__init__(*args, **kwargs)
    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
    self._return_attention_scores = return_attention_scores
    self._ffn_has_residual_connection = ffn_has_residual_connection

  def build(self, input_shape: Union[tf.TensorShape, List[int]]):
    if self._stochastic_depth_drop_rate:
      self._stochastic_depth = StochasticDepth(self._stochastic_depth_drop_rate)
    else:
      # No stochastic depth: use an identity op with the same call signature.
      self._stochastic_depth = lambda x, *args, **kwargs: tf.identity(x)

    super().build(input_shape)

  def get_config(self):
    config = {"stochastic_depth_drop_rate": self._stochastic_depth_drop_rate,
              "return_attention_scores": self._return_attention_scores,
              "ffn_has_residual_connection": self._ffn_has_residual_connection}
    base_config = super().get_config()
    base_config.update(config)
    return base_config

  def call(
      self,
      inputs: Union[tf.Tensor, List[tf.Tensor], Tuple[tf.Tensor, ...]],
      training: Optional[bool] = None
  ) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
    """Transformer self-attention encoder block call.

    Returns:
      The output tensor, plus the attention scores when
      `return_attention_scores` is True.
    """
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 2:
        input_tensor, attention_mask = inputs
        key_value = None
      elif len(inputs) == 3:
        input_tensor, key_value, attention_mask = inputs
      else:
        raise ValueError("Unexpected inputs to %s with length %d." %
                         (self.__class__, len(inputs)))
    else:
      input_tensor, key_value, attention_mask = (inputs, None, None)

    if key_value is None:
      key_value = input_tensor

    if self._norm_first:
      # Pre-norm: normalize before attention and keep the raw input around
      # for the residual connection.
      source_tensor = input_tensor
      input_tensor = self._attention_layer_norm(input_tensor, training=training)

    attention_layer_output = self._attention_layer(
        query=input_tensor,
        value=key_value,
        attention_mask=attention_mask,
        training=training,
        return_attention_scores=self._return_attention_scores)
    if isinstance(attention_layer_output, tuple):
      # `attention_layer_output` contains two tensors when
      # `return_attention_scores` is True.
      attention_output, attention_scores = attention_layer_output
    else:
      attention_output = attention_layer_output
    attention_output = self._attention_dropout(attention_output,
                                               training=training)

    if self._norm_first:
      source_attention_output = source_tensor + self._stochastic_depth(
          attention_output, training=training)
      attention_output = self._output_layer_norm(source_attention_output,
                                                 training=training)
    else:
      # Post-norm: residual add followed by layer norm.
      attention_output = self._attention_layer_norm(
          input_tensor +
          self._stochastic_depth(attention_output, training=training),
          training=training)

    if self._feedforward_block is None:
      intermediate_output = self._intermediate_dense(attention_output)
      intermediate_output = self._intermediate_activation_layer(
          intermediate_output)
      layer_output = self._output_dense(intermediate_output, training=training)
      layer_output = self._output_dropout(layer_output, training=training)
    else:
      layer_output = self._feedforward_block(attention_output,
                                             training=training)

    # During mixed precision training, the layer norm output is always fp32
    # for now. Cast `layer_output` to fp32 for the subsequent add.
    layer_output = tf.cast(layer_output, tf.float32)

    if self._norm_first:
      if self._ffn_has_residual_connection:
        raise ValueError(
            "In the case of `norm_first`, the residual connection should be "
            "done in the TransformerScaffold call function, not in the FFN's "
            "call function.")
      output = source_attention_output + self._stochastic_depth(
          layer_output, training=training)
    else:
      if self._ffn_has_residual_connection:
        # The FFN block has already applied its own residual connection and
        # layer norm, so only stochastic depth is applied here.
        output = self._stochastic_depth(layer_output, training=training)
      else:
        output = self._output_layer_norm(
            attention_output + self._stochastic_depth(
                layer_output, training=training))

    if self._return_attention_scores:
      return output, attention_scores
    else:
      return output
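
Below is a minimal usage sketch, not part of the commit. It assumes the layer is importable as `TransformerScaffold` from the module added here (the file path is not shown in this diff), and that the inherited constructor arguments (`num_attention_heads`, `inner_dim`, `inner_activation`, `norm_first`) match the NLP TransformerScaffold signature, which may vary across library versions.

import tensorflow as tf

# Hypothetical import; the diff does not show the new file's path.
# from official.vision.modeling.layers import TransformerScaffold

block = TransformerScaffold(
    num_attention_heads=8,           # inherited from the NLP layer (assumed name)
    inner_dim=2048,                  # FFN hidden width (assumed name)
    inner_activation="gelu",         # (assumed name)
    norm_first=True,                 # pre-norm, as in ViT-style models
    stochastic_depth_drop_rate=0.1,  # vision-specific
    return_attention_scores=True,    # vision-specific
    ffn_has_residual_connection=False)

features = tf.random.normal([2, 197, 768])  # [batch, tokens, hidden]
outputs, scores = block(features, training=True)

# The call also accepts [input_tensor, attention_mask] or
# [input_tensor, key_value, attention_mask]:
mask = tf.ones([2, 197, 197])
outputs, scores = block([features, mask], training=True)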

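Because `get_config` merges the three vision-specific fields into the parent config, a configured layer should round-trip through Keras serialization. A short sketch, reusing `block` from the example above; whether `from_config` works out of the box depends on how the parent class serializes its attention and feedforward sub-layers.

config = block.get_config()
assert config["stochastic_depth_drop_rate"] == 0.1
assert config["return_attention_scores"]

# The usual Keras reconstruction pattern.
restored = TransformerScaffold.from_config(config)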
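`StochasticDepth` is imported from `official.vision.modeling.layers.nn_layers`; its implementation is not shown in this diff. For reference, here is a minimal sketch of the standard stochastic-depth rule (Huang et al., 2016): during training the whole residual branch is zeroed with probability `drop_rate` and the survivors are rescaled so the expected activation is unchanged; at inference it is the identity. The real layer may differ in details.

import tensorflow as tf


class StochasticDepthSketch(tf.keras.layers.Layer):
  """Minimal stochastic-depth sketch for rank-3 [batch, tokens, hidden] inputs."""

  def __init__(self, drop_rate: float, **kwargs):
    super().__init__(**kwargs)
    self._drop_rate = drop_rate

  def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
    if not training or self._drop_rate == 0.0:
      return inputs
    keep_prob = 1.0 - self._drop_rate
    # One Bernoulli(keep_prob) sample per example: floor(keep_prob + U[0,1))
    # is 1 with probability keep_prob and 0 otherwise.
    batch_size = tf.shape(inputs)[0]
    random_tensor = keep_prob + tf.random.uniform(
        [batch_size, 1, 1], dtype=inputs.dtype)
    binary_mask = tf.floor(random_tensor)
    # Rescale survivors so the expected value matches the no-drop case.
    return inputs / keep_prob * binary_mask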