
Commit 011a580

Author: miguelCalado

Added: VGG-16 configurations and model
1 parent d58be67 commit 011a580

File tree: 7 files changed, 345 additions and 0 deletions


official/vision/image_classification/classifier_trainer.py

Lines changed: 2 additions & 0 deletions
@@ -36,13 +36,15 @@
 from official.vision.image_classification.efficientnet import efficientnet_model
 from official.vision.image_classification.resnet import common
 from official.vision.image_classification.resnet import resnet_model
+from official.vision.image_classification.vgg16 import vgg_model


 def get_models() -> Mapping[str, tf.keras.Model]:
   """Returns the mapping from model type name to Keras model."""
   return {
       'efficientnet': efficientnet_model.EfficientNet.from_name,
       'resnet': resnet_model.resnet50,
+      'vgg': vgg_model.vgg16,
   }

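For context (not part of the commit): the trainer looks up the model builder by name in this registry, so the new 'vgg' entry makes vgg_model.vgg16 selectable by model type. A minimal usage sketch, assuming the `official` package from the TensorFlow Model Garden is importable:

# Illustrative sketch only: resolve and build the VGG-16 model through the registry.
from official.vision.image_classification.classifier_trainer import get_models

builder = get_models()['vgg']        # -> vgg_model.vgg16
model = builder(num_classes=1000)    # builds the 224x224x3 VGG-16 with a softmax head
model.summary()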
official/vision/image_classification/classifier_trainer_test.py

Lines changed: 3 additions & 0 deletions
@@ -53,6 +53,7 @@ def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]:
       model=[
           'efficientnet',
           'resnet',
+          'vgg',
       ],
       dataset=[
           'imagenet',
@@ -149,6 +150,7 @@ def test_end_to_end_train_and_eval(self, distribution, model, dataset):
       model=[
           'efficientnet',
           'resnet',
+          'vgg',
       ],
       dataset='imagenet',
       dtype='float16',
@@ -193,6 +195,7 @@ def test_gpu_train(self, distribution, model, dataset, dtype):
       model=[
           'efficientnet',
           'resnet',
+          'vgg',
       ],
       dataset='imagenet',
       dtype='bfloat16',

official/vision/image_classification/configs/configs.py

Lines changed: 31 additions & 0 deletions
@@ -91,13 +91,44 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
       epochs_between_evals=1, steps=None)
   model: base_configs.ModelConfig = resnet_config.ResNetModelConfig()

+@dataclasses.dataclass
+class VGGImagenetConfig(base_configs.ExperimentConfig):
+  """Base configuration to train vgg-16 on ImageNet."""
+  export: base_configs.ExportConfig = base_configs.ExportConfig()
+  runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
+  train_dataset: dataset_factory.DatasetConfig = \
+      dataset_factory.ImageNetConfig(split='train',
+                                     one_hot=False,
+                                     mean_subtract=True,
+                                     standardize=True)
+  validation_dataset: dataset_factory.DatasetConfig = \
+      dataset_factory.ImageNetConfig(split='validation',
+                                     one_hot=False,
+                                     mean_subtract=True,
+                                     standardize=True)
+  train: base_configs.TrainConfig = base_configs.TrainConfig(
+      resume_checkpoint=True,
+      epochs=90,
+      steps=None,
+      callbacks=base_configs.CallbacksConfig(
+          enable_checkpoint_and_export=True, enable_tensorboard=True),
+      metrics=['accuracy', 'top_5'],
+      time_history=base_configs.TimeHistoryConfig(log_steps=100),
+      tensorboard=base_configs.TensorBoardConfig(
+          track_lr=True, write_model_weights=False),
+      set_epoch_loop=False)
+  evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
+      epochs_between_evals=1, steps=None)
+  model: base_configs.ModelConfig = vgg_config.VGGModelConfig()
+

 def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig:
   """Given model and dataset names, return the ExperimentConfig."""
   dataset_model_config_map = {
       'imagenet': {
           'efficientnet': EfficientNetImageNetConfig(),
           'resnet': ResNetImagenetConfig(),
+          'vgg': VGGImagenetConfig()
       }
   }
   try:
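For context (not part of the commit): with this mapping in place, get_config resolves the experiment configuration for the new model name. A small sketch, again assuming the `official` package is importable:

# Illustrative sketch only: fetch the VGG ImageNet experiment config by name.
from official.vision.image_classification.configs import configs

config = configs.get_config(model='vgg', dataset='imagenet')
print(type(config).__name__)          # VGGImagenetConfig
print(config.train.epochs)            # 90
print(config.model.optimizer.name)    # momentum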
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+# Training configuration for VGG-16 trained on ImageNet on GPUs.
+# Reaches > 72.8% within 90 epochs.
+# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
+runtime:
+  distribution_strategy: 'mirrored'
+  num_gpus: 1
+  batchnorm_spatial_persistent: True
+train_dataset:
+  name: 'imagenet2012'
+  data_dir: null
+  builder: 'records'
+  split: 'train'
+  image_size: 224
+  num_classes: 1000
+  num_examples: 1281167
+  batch_size: 128
+  use_per_replica_batch_size: True
+  dtype: 'float32'
+  mean_subtract: True
+  standardize: True
+validation_dataset:
+  name: 'imagenet2012'
+  data_dir: null
+  builder: 'records'
+  split: 'validation'
+  image_size: 224
+  num_classes: 1000
+  num_examples: 50000
+  batch_size: 128
+  use_per_replica_batch_size: True
+  dtype: 'float32'
+  mean_subtract: True
+  standardize: True
+model:
+  name: 'vgg'
+  optimizer:
+    name: 'momentum'
+    momentum: 0.9
+    epsilon: 0.001
+  loss:
+    label_smoothing: 0.0
+train:
+  resume_checkpoint: True
+  epochs: 90
+evaluation:
+  epochs_between_evals: 1
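Because use_per_replica_batch_size is True, the effective global batch size scales with the number of replicas. A quick sanity-check sketch (the file name below is a placeholder; the commit view does not show where this YAML lives):

# Illustrative sketch only: load the YAML and derive the global batch size.
import yaml

with open('vgg_imagenet_gpu.yaml') as f:   # hypothetical file name
    params = yaml.safe_load(f)

per_replica = params['train_dataset']['batch_size']    # 128
num_replicas = max(params['runtime']['num_gpus'], 1)   # 1 in this config
print('global batch size:', per_replica * num_replicas)
print('epochs:', params['train']['epochs'])             # 90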
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Configuration definitions for VGG losses, learning rates, and optimizers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import dataclasses
+
+from official.modeling.hyperparams import base_config
+from official.vision.image_classification.configs import base_configs
+
+
+@dataclasses.dataclass
+class VGGModelConfig(base_configs.ModelConfig):
+  """Configuration for the VGG model."""
+  name: str = 'VGG'
+  num_classes: int = 1000
+  model_params: base_config.Config = dataclasses.field(
+      default_factory=lambda: {
+          'num_classes': 1000,
+          'batch_size': None,
+          'use_l2_regularizer': True
+      })
+  loss: base_configs.LossConfig = base_configs.LossConfig(
+      name='sparse_categorical_crossentropy')
+  optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
+      name='momentum',
+      epsilon=0.001,
+      momentum=0.9,
+      moving_average_decay=None)
+  learning_rate: base_configs.LearningRateConfig = (
+      base_configs.LearningRateConfig(
+          name='stepwise',
+          initial_lr=0.01,
+          examples_per_epoch=1281167,
+          boundaries=[30, 60],
+          warmup_epochs=0,
+          scale_by_batch_size=1. / 128.,
+          multipliers=[0.01 / 256, 0.001 / 256, 0.0001 / 256]))
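The stepwise schedule stores per-example rates; assuming the usual Model Garden convention of multiplying each entry in multipliers by the global batch size (an assumption, not shown in this diff), the rate starts near 0.005 for a batch size of 128 and drops at epochs 30 and 60:

# Illustrative sketch only: effective stepwise learning rates under the assumed
# convention that multipliers are scaled by the global batch size.
batch_size = 128                                  # per-replica batch from the YAML, 1 GPU
boundaries = [30, 60]                             # epochs where the rate drops
multipliers = [0.01 / 256, 0.001 / 256, 0.0001 / 256]

effective_lrs = [m * batch_size for m in multipliers]
print(effective_lrs)   # [0.005, 0.0005, 5e-05]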
Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+layers = tf.keras.layers
+
+def _gen_l2_regularizer(use_l2_regularizer=True, l2_weight_decay=1e-4):
+  return tf.keras.regularizers.L2(
+      l2_weight_decay) if use_l2_regularizer else None
+
+def vgg16(num_classes,
+          batch_size=None,
+          use_l2_regularizer=True,
+          batch_norm_decay=0.9,
+          batch_norm_epsilon=1e-5):
+
+  input_shape = (224, 224, 3)
+  img_input = layers.Input(shape=input_shape, batch_size=batch_size)
+
+  x = img_input
+
+  if tf.keras.backend.image_data_format() == 'channels_first':
+    x = layers.Permute((3, 1, 2))(x)
+    bn_axis = 1
+  else:  # channels_last
+    bn_axis = 3
+
+  # Block 1
+  x = layers.Conv2D(64, (3, 3),
+                    padding='same',
+                    kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+                    name='block1_conv1')(x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
+      name='bn_conv1')(x)
+  x = layers.Activation('relu')(x)
+  x = layers.Conv2D(64, (3, 3),
+                    padding='same',
+                    kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+                    name='block1_conv2')(x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
+      name='bn_conv2')(x)
+  x = layers.Activation('relu')(x)
+  x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
+
+  # Block 2
+  x = layers.Conv2D(128, (3, 3),
+                    padding='same',
+                    kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+                    name='block2_conv1')(x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
+      name='bn_conv3')(x)
+  x = layers.Activation('relu')(x)
+  x = layers.Conv2D(128, (3, 3),
+                    padding='same',
+                    kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+                    name='block2_conv2')(x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
+      name='bn_conv4')(x)
+  x = layers.Activation('relu')(x)
+  x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)
+
+  # Block 3
+  x = layers.Conv2D(256, (3, 3),
+                    padding='same',
+                    kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+                    name='block3_conv1')(x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
+      name='bn_conv5')(x)
+  x = layers.Activation('relu')(x)
+  x = layers.Conv2D(256, (3, 3),
+                    padding='same',
+                    kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+                    name='block3_conv2')(x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
+      name='bn_conv6')(x)
+  x = layers.Activation('relu')(x)
+  x = layers.Conv2D(256, (3, 3),
+                    padding='same',
+                    kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+                    name='block3_conv3')(x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
+      name='bn_conv7')(x)
+  x = layers.Activation('relu')(x)
+  x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)
+
+  # Block 4
+  x = layers.Conv2D(512, (3, 3),
+                    padding='same',
+                    kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+                    name='block4_conv1')(x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
+      name='bn_conv8')(x)
+  x = layers.Activation('relu')(x)
+  x = layers.Conv2D(512, (3, 3),
+                    padding='same',
+                    kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+                    name='block4_conv2')(x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
+      name='bn_conv9')(x)
+  x = layers.Activation('relu')(x)
+  x = layers.Conv2D(512, (3, 3),
+                    padding='same',
+                    kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+                    name='block4_conv3')(x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
+      name='bn_conv10')(x)
+  x = layers.Activation('relu')(x)
+  x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)
+
+  # Block 5
+  x = layers.Conv2D(512, (3, 3),
+                    padding='same',
+                    kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+                    name='block5_conv1')(x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
+      name='bn_conv11')(x)
+  x = layers.Activation('relu')(x)
+  x = layers.Conv2D(512, (3, 3),
+                    padding='same',
+                    kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+                    name='block5_conv2')(x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
+      name='bn_conv12')(x)
+  x = layers.Activation('relu')(x)
+  x = layers.Conv2D(512, (3, 3),
+                    padding='same',
+                    kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+                    name='block5_conv3')(x)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=batch_norm_decay,
+      epsilon=batch_norm_epsilon,
+      name='bn_conv13')(x)
+  x = layers.Activation('relu')(x)
+  x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)
+
+  x = layers.Flatten(name='flatten')(x)
+  x = layers.Dense(4096,
+                   kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+                   name='fc1')(x)
+  x = layers.Activation('relu')(x)
+  x = layers.Dropout(0.5)(x)
+  x = layers.Dense(4096,
+                   kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+                   name='fc2')(x)
+  x = layers.Activation('relu')(x)
+  x = layers.Dropout(0.5)(x)
+  x = layers.Dense(num_classes,
+                   kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+                   name='fc1000')(x)
+
+  # A softmax that is followed by the model loss cannot be done in float16
+  # due to numeric issues, so we pass dtype='float32'.
+  x = layers.Activation('softmax', dtype='float32')(x)
+
+  # Create model.
+  return tf.keras.Model(img_input, x, name='vgg16')
+
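As a usage sketch (an illustration, not part of the commit), the builder can be exercised directly; the SGD settings below mirror the 'momentum' optimizer from the config:

# Illustrative sketch only: build and compile the VGG-16 defined above.
import tensorflow as tf
from official.vision.image_classification.vgg16 import vgg_model

model = vgg_model.vgg16(num_classes=1000)   # 224x224x3 inputs, float32 softmax head
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.005, momentum=0.9),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
model.summary()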
