Skip to content

Commit 0dadbbc

Browse files
Internal change
PiperOrigin-RevId: 420497751
1 parent 993dbf5 commit 0dadbbc

File tree

1 file changed

+10
-4
lines changed
  • official/vision/beta/projects/vit/modeling

1 file changed

+10
-4
lines changed

official/vision/beta/projects/vit/modeling/vit.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from official.vision.beta.modeling.layers import nn_layers
2222
from official.vision.beta.projects.vit.modeling import nn_blocks
2323

24+
2425
layers = tf.keras.layers
2526

2627
VIT_SPECS = {
@@ -121,6 +122,7 @@ def __init__(self,
121122
inputs_positions=None,
122123
init_stochastic_depth_rate=0.0,
123124
kernel_initializer='glorot_uniform',
125+
add_pos_embed=True,
124126
**kwargs):
125127
super().__init__(**kwargs)
126128
self._num_layers = num_layers
@@ -132,11 +134,13 @@ def __init__(self,
132134
self._inputs_positions = inputs_positions
133135
self._init_stochastic_depth_rate = init_stochastic_depth_rate
134136
self._kernel_initializer = kernel_initializer
137+
self._add_pos_embed = add_pos_embed
135138

136139
def build(self, input_shape):
137-
self._pos_embed = AddPositionEmbs(
138-
posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
139-
name='posembed_input')
140+
if self._add_pos_embed:
141+
self._pos_embed = AddPositionEmbs(
142+
posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
143+
name='posembed_input')
140144
self._dropout = layers.Dropout(rate=self._dropout_rate)
141145

142146
self._encoder_layers = []
@@ -160,7 +164,9 @@ def build(self, input_shape):
160164
super().build(input_shape)
161165

162166
def call(self, inputs, training=None):
163-
x = self._pos_embed(inputs, inputs_positions=self._inputs_positions)
167+
x = inputs
168+
if self._add_pos_embed:
169+
x = self._pos_embed(x, inputs_positions=self._inputs_positions)
164170
x = self._dropout(x, training=training)
165171

166172
for encoder_layer in self._encoder_layers:

0 commit comments

Comments
 (0)