26
26
from official .vision .modeling .backbones import factory
27
27
from official .vision .modeling .layers import nn_layers
28
28
29
-
30
29
layers = tf .keras .layers
31
30
32
31
@@ -67,9 +66,7 @@ def build(self, inputs_shape):
67
66
self .pos_embedding = self .add_weight (
68
67
'pos_embedding' , pos_emb_shape , initializer = self .posemb_init )
69
68
70
- def _interpolate (self ,
71
- pos_embedding : tf .Tensor ,
72
- from_shape : Tuple [int , int ],
69
+ def _interpolate (self , pos_embedding : tf .Tensor , from_shape : Tuple [int , int ],
73
70
to_shape : Tuple [int , int ]) -> tf .Tensor :
74
71
"""Interpolates the positional embeddings."""
75
72
logging .info ('Interpolating postional embedding from length: %d to %d' ,
@@ -84,9 +81,10 @@ def call(self, inputs, inputs_positions=None):
84
81
pos_embedding = self .pos_embedding
85
82
# inputs.shape is (batch_size, seq_len, emb_dim).
86
83
if inputs .shape [1 ] != pos_embedding .shape [1 ]:
87
- pos_embedding = self ._interpolate (pos_embedding ,
88
- from_shape = self .posemb_origin_shape ,
89
- to_shape = self .posemb_target_shape )
84
+ pos_embedding = self ._interpolate (
85
+ pos_embedding ,
86
+ from_shape = self .posemb_origin_shape ,
87
+ to_shape = self .posemb_target_shape )
90
88
pos_embedding = tf .cast (pos_embedding , inputs .dtype )
91
89
92
90
return inputs + pos_embedding
@@ -262,7 +260,8 @@ def __init__(self,
262
260
class_name = 'TruncatedNormal' , config = dict (stddev = .02 )),
263
261
init_stochastic_depth_rate = init_stochastic_depth_rate ,
264
262
pos_embed_origin_shape = pos_embed_shape ,
265
- pos_embed_target_shape = pos_embed_target_shape )(x )
263
+ pos_embed_target_shape = pos_embed_target_shape )(
264
+ x )
266
265
267
266
if pooler == 'token' :
268
267
x = x [:, 0 ]
@@ -303,8 +302,8 @@ def build_vit(input_specs,
303
302
del norm_activation_config
304
303
backbone_type = backbone_config .type
305
304
backbone_cfg = backbone_config .get ()
306
- assert backbone_type == 'vit ' , (f'Inconsistent backbone type '
307
- f'{ backbone_type } ' )
305
+ assert backbone_type == 'legacy_vit ' , (f'Inconsistent backbone type '
306
+ f'{ backbone_type } ' )
308
307
backbone_cfg .override (VIT_SPECS [backbone_cfg .model_name ])
309
308
310
309
return VisionTransformer (
0 commit comments