21
21
from official .vision .beta .modeling .layers import nn_layers
22
22
from official .vision .beta .projects .vit .modeling import nn_blocks
23
23
24
+
24
25
layers = tf .keras .layers
25
26
26
27
VIT_SPECS = {
@@ -121,6 +122,7 @@ def __init__(self,
121
122
inputs_positions = None ,
122
123
init_stochastic_depth_rate = 0.0 ,
123
124
kernel_initializer = 'glorot_uniform' ,
125
+ add_pos_embed = True ,
124
126
** kwargs ):
125
127
super ().__init__ (** kwargs )
126
128
self ._num_layers = num_layers
@@ -132,11 +134,13 @@ def __init__(self,
132
134
self ._inputs_positions = inputs_positions
133
135
self ._init_stochastic_depth_rate = init_stochastic_depth_rate
134
136
self ._kernel_initializer = kernel_initializer
137
+ self ._add_pos_embed = add_pos_embed
135
138
136
139
def build (self , input_shape ):
137
- self ._pos_embed = AddPositionEmbs (
138
- posemb_init = tf .keras .initializers .RandomNormal (stddev = 0.02 ),
139
- name = 'posembed_input' )
140
+ if self ._add_pos_embed :
141
+ self ._pos_embed = AddPositionEmbs (
142
+ posemb_init = tf .keras .initializers .RandomNormal (stddev = 0.02 ),
143
+ name = 'posembed_input' )
140
144
self ._dropout = layers .Dropout (rate = self ._dropout_rate )
141
145
142
146
self ._encoder_layers = []
@@ -160,7 +164,9 @@ def build(self, input_shape):
160
164
super ().build (input_shape )
161
165
162
166
def call (self , inputs , training = None ):
163
- x = self ._pos_embed (inputs , inputs_positions = self ._inputs_positions )
167
+ x = inputs
168
+ if self ._add_pos_embed :
169
+ x = self ._pos_embed (x , inputs_positions = self ._inputs_positions )
164
170
x = self ._dropout (x , training = training )
165
171
166
172
for encoder_layer in self ._encoder_layers :
0 commit comments