Skip to content

Commit 6e2129f

Browse files
chaoyan1037tensorflower-gardener
authored andcommitted
Internal change
PiperOrigin-RevId: 481938528
1 parent 09b3f5a commit 6e2129f

File tree

1 file changed

+9
-10
lines changed
  • official/projects/vit/modeling

1 file changed

+9
-10
lines changed

official/projects/vit/modeling/vit.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from official.vision.modeling.backbones import factory
2727
from official.vision.modeling.layers import nn_layers
2828

29-
3029
layers = tf.keras.layers
3130

3231

@@ -67,9 +66,7 @@ def build(self, inputs_shape):
6766
self.pos_embedding = self.add_weight(
6867
'pos_embedding', pos_emb_shape, initializer=self.posemb_init)
6968

70-
def _interpolate(self,
71-
pos_embedding: tf.Tensor,
72-
from_shape: Tuple[int, int],
69+
def _interpolate(self, pos_embedding: tf.Tensor, from_shape: Tuple[int, int],
7370
to_shape: Tuple[int, int]) -> tf.Tensor:
7471
"""Interpolates the positional embeddings."""
7572
logging.info('Interpolating postional embedding from length: %d to %d',
@@ -84,9 +81,10 @@ def call(self, inputs, inputs_positions=None):
8481
pos_embedding = self.pos_embedding
8582
# inputs.shape is (batch_size, seq_len, emb_dim).
8683
if inputs.shape[1] != pos_embedding.shape[1]:
87-
pos_embedding = self._interpolate(pos_embedding,
88-
from_shape=self.posemb_origin_shape,
89-
to_shape=self.posemb_target_shape)
84+
pos_embedding = self._interpolate(
85+
pos_embedding,
86+
from_shape=self.posemb_origin_shape,
87+
to_shape=self.posemb_target_shape)
9088
pos_embedding = tf.cast(pos_embedding, inputs.dtype)
9189

9290
return inputs + pos_embedding
@@ -262,7 +260,8 @@ def __init__(self,
262260
class_name='TruncatedNormal', config=dict(stddev=.02)),
263261
init_stochastic_depth_rate=init_stochastic_depth_rate,
264262
pos_embed_origin_shape=pos_embed_shape,
265-
pos_embed_target_shape=pos_embed_target_shape)(x)
263+
pos_embed_target_shape=pos_embed_target_shape)(
264+
x)
266265

267266
if pooler == 'token':
268267
x = x[:, 0]
@@ -303,8 +302,8 @@ def build_vit(input_specs,
303302
del norm_activation_config
304303
backbone_type = backbone_config.type
305304
backbone_cfg = backbone_config.get()
306-
assert backbone_type == 'vit', (f'Inconsistent backbone type '
307-
f'{backbone_type}')
305+
assert backbone_type == 'legacy_vit', (f'Inconsistent backbone type '
306+
f'{backbone_type}')
308307
backbone_cfg.override(VIT_SPECS[backbone_cfg.model_name])
309308

310309
return VisionTransformer(

0 commit comments

Comments
 (0)