Skip to content

Commit 9058543

Browse files
chaoyan1037 authored and tensorflower-gardener committed
Internal change
PiperOrigin-RevId: 481733792
1 parent d309fff commit 9058543

File tree

11 files changed

+849
-26
lines changed

11 files changed

+849
-26
lines changed

official/projects/vit/configs/image_classification.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class ImageClassificationTask(cfg.TaskConfig):
7575
image_classification.ImageClassificationTask)
7676

7777

78-
@exp_factory.register_config_factory('deit_imagenet_pretrain')
78+
@exp_factory.register_config_factory('legacy_deit_imagenet_pretrain')
7979
def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
8080
"""Image classification on imagenet with vision transformer."""
8181
train_batch_size = 4096 # originally was 1024 but 4096 better for tpu v3-32
@@ -156,7 +156,7 @@ def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
156156
return config
157157

158158

159-
@exp_factory.register_config_factory('vit_imagenet_pretrain')
159+
@exp_factory.register_config_factory('legacy_vit_imagenet_pretrain')
160160
def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
161161
"""Image classification on imagenet with vision transformer."""
162162
train_batch_size = 4096
@@ -220,7 +220,7 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
220220
return config
221221

222222

223-
@exp_factory.register_config_factory('vit_imagenet_finetune')
223+
@exp_factory.register_config_factory('legacy_vit_imagenet_finetune')
224224
def image_classification_imagenet_vit_finetune() -> cfg.ExperimentConfig:
225225
"""Image classification on imagenet with vision transformer."""
226226
train_batch_size = 512

official/projects/vit/modeling/vit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def __init__(self,
294294
super(VisionTransformer, self).__init__(inputs=inputs, outputs=endpoints)
295295

296296

297-
@factory.register_backbone_builder('vit')
297+
@factory.register_backbone_builder('legacy_vit')
298298
def build_vit(input_specs,
299299
backbone_config,
300300
norm_activation_config,

official/vision/configs/backbones.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,37 @@
1414

1515
"""Backbones configurations."""
1616
import dataclasses
17-
from typing import Optional, List
18-
19-
# Import libraries
17+
from typing import List, Optional, Tuple
2018

2119
from official.modeling import hyperparams
2220

2321

22+
@dataclasses.dataclass
23+
class Transformer(hyperparams.Config):
24+
"""Transformer config."""
25+
mlp_dim: int = 1
26+
num_heads: int = 1
27+
num_layers: int = 1
28+
attention_dropout_rate: float = 0.0
29+
dropout_rate: float = 0.1
30+
31+
32+
@dataclasses.dataclass
33+
class VisionTransformer(hyperparams.Config):
34+
"""VisionTransformer config."""
35+
model_name: str = 'vit-b16'
36+
# pylint: disable=line-too-long
37+
pooler: str = 'token' # 'token', 'gap' or 'none'. If set to 'token', an extra classification token is added to sequence.
38+
# pylint: enable=line-too-long
39+
representation_size: int = 0
40+
hidden_size: int = 1
41+
patch_size: int = 16
42+
transformer: Transformer = Transformer()
43+
init_stochastic_depth_rate: float = 0.0
44+
original_init: bool = True
45+
pos_embed_shape: Optional[Tuple[int, int]] = None
46+
47+
2448
@dataclasses.dataclass
2549
class ResNet(hyperparams.Config):
2650
"""ResNet config."""
@@ -120,6 +144,7 @@ class Backbone(hyperparams.OneOfConfig):
120144
spinenet_mobile: mobile spinenet backbone config.
121145
mobilenet: mobilenet backbone config.
122146
mobiledet: mobiledet backbone config.
147+
vit: vision transformer backbone config.
123148
"""
124149
type: Optional[str] = None
125150
resnet: ResNet = ResNet()
@@ -130,4 +155,4 @@ class Backbone(hyperparams.OneOfConfig):
130155
spinenet_mobile: SpineNetMobile = SpineNetMobile()
131156
mobilenet: MobileNet = MobileNet()
132157
mobiledet: MobileDet = MobileDet()
133-
158+
vit: VisionTransformer = VisionTransformer()

official/vision/configs/image_classification.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,3 +402,201 @@ def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig:
402402
])
403403

404404
return config
405+
406+
407+
@exp_factory.register_config_factory('deit_imagenet_pretrain')
408+
def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
409+
"""Image classification on imagenet with vision transformer."""
410+
train_batch_size = 4096 # originally was 1024 but 4096 better for tpu v3-32
411+
eval_batch_size = 4096 # originally was 1024 but 4096 better for tpu v3-32
412+
label_smoothing = 0.1
413+
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
414+
config = cfg.ExperimentConfig(
415+
task=ImageClassificationTask(
416+
model=ImageClassificationModel(
417+
num_classes=1001,
418+
input_size=[224, 224, 3],
419+
kernel_initializer='zeros',
420+
backbone=backbones.Backbone(
421+
type='vit',
422+
vit=backbones.VisionTransformer(
423+
model_name='vit-b16',
424+
representation_size=768,
425+
init_stochastic_depth_rate=0.1,
426+
original_init=False,
427+
transformer=backbones.Transformer(
428+
dropout_rate=0.0, attention_dropout_rate=0.0)))),
429+
losses=Losses(
430+
l2_weight_decay=0.0,
431+
label_smoothing=label_smoothing,
432+
one_hot=False,
433+
soft_labels=True),
434+
train_data=DataConfig(
435+
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
436+
is_training=True,
437+
global_batch_size=train_batch_size,
438+
aug_type=common.Augmentation(
439+
type='randaug',
440+
randaug=common.RandAugment(
441+
magnitude=9, exclude_ops=['Cutout'])),
442+
mixup_and_cutmix=common.MixupAndCutmix(
443+
label_smoothing=label_smoothing)),
444+
validation_data=DataConfig(
445+
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
446+
is_training=False,
447+
global_batch_size=eval_batch_size)),
448+
trainer=cfg.TrainerConfig(
449+
steps_per_loop=steps_per_epoch,
450+
summary_interval=steps_per_epoch,
451+
checkpoint_interval=steps_per_epoch,
452+
train_steps=300 * steps_per_epoch,
453+
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
454+
validation_interval=steps_per_epoch,
455+
optimizer_config=optimization.OptimizationConfig({
456+
'optimizer': {
457+
'type': 'adamw',
458+
'adamw': {
459+
'weight_decay_rate': 0.05,
460+
'include_in_weight_decay': r'.*(kernel|weight):0$',
461+
'gradient_clip_norm': 0.0
462+
}
463+
},
464+
'learning_rate': {
465+
'type': 'cosine',
466+
'cosine': {
467+
'initial_learning_rate': 0.0005 * train_batch_size / 512,
468+
'decay_steps': 300 * steps_per_epoch,
469+
}
470+
},
471+
'warmup': {
472+
'type': 'linear',
473+
'linear': {
474+
'warmup_steps': 5 * steps_per_epoch,
475+
'warmup_learning_rate': 0
476+
}
477+
}
478+
})),
479+
restrictions=[
480+
'task.train_data.is_training != None',
481+
'task.validation_data.is_training != None'
482+
])
483+
484+
return config
485+
486+
487+
@exp_factory.register_config_factory('vit_imagenet_pretrain')
488+
def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
489+
"""Image classification on imagenet with vision transformer."""
490+
train_batch_size = 4096
491+
eval_batch_size = 4096
492+
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
493+
config = cfg.ExperimentConfig(
494+
task=ImageClassificationTask(
495+
model=ImageClassificationModel(
496+
num_classes=1001,
497+
input_size=[224, 224, 3],
498+
kernel_initializer='zeros',
499+
backbone=backbones.Backbone(
500+
type='vit',
501+
vit=backbones.VisionTransformer(
502+
model_name='vit-b16', representation_size=768))),
503+
losses=Losses(l2_weight_decay=0.0),
504+
train_data=DataConfig(
505+
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
506+
is_training=True,
507+
global_batch_size=train_batch_size),
508+
validation_data=DataConfig(
509+
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
510+
is_training=False,
511+
global_batch_size=eval_batch_size)),
512+
trainer=cfg.TrainerConfig(
513+
steps_per_loop=steps_per_epoch,
514+
summary_interval=steps_per_epoch,
515+
checkpoint_interval=steps_per_epoch,
516+
train_steps=300 * steps_per_epoch,
517+
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
518+
validation_interval=steps_per_epoch,
519+
optimizer_config=optimization.OptimizationConfig({
520+
'optimizer': {
521+
'type': 'adamw',
522+
'adamw': {
523+
'weight_decay_rate': 0.3,
524+
'include_in_weight_decay': r'.*(kernel|weight):0$',
525+
'gradient_clip_norm': 0.0
526+
}
527+
},
528+
'learning_rate': {
529+
'type': 'cosine',
530+
'cosine': {
531+
'initial_learning_rate': 0.003 * train_batch_size / 4096,
532+
'decay_steps': 300 * steps_per_epoch,
533+
}
534+
},
535+
'warmup': {
536+
'type': 'linear',
537+
'linear': {
538+
'warmup_steps': 10000,
539+
'warmup_learning_rate': 0
540+
}
541+
}
542+
})),
543+
restrictions=[
544+
'task.train_data.is_training != None',
545+
'task.validation_data.is_training != None'
546+
])
547+
548+
return config
549+
550+
551+
@exp_factory.register_config_factory('vit_imagenet_finetune')
552+
def image_classification_imagenet_vit_finetune() -> cfg.ExperimentConfig:
553+
"""Image classification on imagenet with vision transformer."""
554+
train_batch_size = 512
555+
eval_batch_size = 512
556+
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
557+
config = cfg.ExperimentConfig(
558+
task=ImageClassificationTask(
559+
model=ImageClassificationModel(
560+
num_classes=1001,
561+
input_size=[384, 384, 3],
562+
backbone=backbones.Backbone(
563+
type='vit',
564+
vit=backbones.VisionTransformer(model_name='vit-b16'))),
565+
losses=Losses(l2_weight_decay=0.0),
566+
train_data=DataConfig(
567+
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
568+
is_training=True,
569+
global_batch_size=train_batch_size),
570+
validation_data=DataConfig(
571+
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
572+
is_training=False,
573+
global_batch_size=eval_batch_size)),
574+
trainer=cfg.TrainerConfig(
575+
steps_per_loop=steps_per_epoch,
576+
summary_interval=steps_per_epoch,
577+
checkpoint_interval=steps_per_epoch,
578+
train_steps=20000,
579+
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
580+
validation_interval=steps_per_epoch,
581+
optimizer_config=optimization.OptimizationConfig({
582+
'optimizer': {
583+
'type': 'sgd',
584+
'sgd': {
585+
'momentum': 0.9,
586+
'global_clipnorm': 1.0,
587+
}
588+
},
589+
'learning_rate': {
590+
'type': 'cosine',
591+
'cosine': {
592+
'initial_learning_rate': 0.003,
593+
'decay_steps': 20000,
594+
}
595+
}
596+
})),
597+
restrictions=[
598+
'task.train_data.is_training != None',
599+
'task.validation_data.is_training != None'
600+
])
601+
602+
return config

official/vision/configs/image_classification_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ class ImageClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):
2929
('resnet_imagenet',),
3030
('resnet_rs_imagenet',),
3131
('revnet_imagenet',),
32-
('mobilenet_imagenet'),
32+
('mobilenet_imagenet',),
33+
('deit_imagenet_pretrain',),
34+
('vit_imagenet_pretrain',),
35+
('vit_imagenet_finetune',),
3336
)
3437
def test_image_classification_configs(self, config_name):
3538
config = exp_factory.get_exp_config(config_name)

official/vision/modeling/backbones/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@
2323
from official.vision.modeling.backbones.revnet import RevNet
2424
from official.vision.modeling.backbones.spinenet import SpineNet
2525
from official.vision.modeling.backbones.spinenet_mobile import SpineNetMobile
26+
from official.vision.modeling.backbones.vit import VisionTransformer

0 commit comments

Comments (0)