Skip to content

Commit 2d16421

Browse files
Merge pull request #10449 from miguelCalado:vgg
PiperOrigin-RevId: 421667465
2 parents c9a7e0b + 3e7fe8a commit 2d16421

File tree

8 files changed

+421
-7
lines changed

8 files changed

+421
-7
lines changed

official/legacy/image_classification/README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,20 @@ python3 classifier_trainer.py \
152152
--config_file=configs/examples/resnet/imagenet/tpu.yaml
153153
```
154154

155+
### VGG-16
156+
157+
#### On GPU:
158+
```bash
159+
python3 classifier_trainer.py \
160+
--mode=train_and_eval \
161+
--model_type=vgg \
162+
--dataset=imagenet \
163+
--model_dir=$MODEL_DIR \
164+
--data_dir=$DATA_DIR \
165+
--config_file=configs/examples/vgg/imagenet/gpu.yaml \
166+
--params_override='runtime.num_gpus=$NUM_GPUS'
167+
```
168+
155169
### EfficientNet
156170
**Note: EfficientNet development is a work in progress.**
157171
#### On GPU:

official/legacy/image_classification/classifier_trainer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from official.legacy.image_classification.efficientnet import efficientnet_model
3333
from official.legacy.image_classification.resnet import common
3434
from official.legacy.image_classification.resnet import resnet_model
35+
from official.legacy.image_classification.vgg import vgg_model
3536
from official.modeling import hyperparams
3637
from official.modeling import performance
3738
from official.utils import hyperparams_flags
@@ -43,6 +44,7 @@ def get_models() -> Mapping[str, tf.keras.Model]:
4344
return {
4445
'efficientnet': efficientnet_model.EfficientNet.from_name,
4546
'resnet': resnet_model.resnet50,
47+
'vgg': vgg_model.vgg16,
4648
}
4749

4850

official/legacy/image_classification/classifier_trainer_test.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@
1515
# Lint as: python3
1616
"""Unit tests for the classifier trainer models."""
1717

18-
from __future__ import absolute_import
19-
from __future__ import division
20-
from __future__ import print_function
21-
2218
import functools
2319
import json
2420

@@ -53,6 +49,7 @@ def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]:
5349
model=[
5450
'efficientnet',
5551
'resnet',
52+
'vgg',
5653
],
5754
dataset=[
5855
'imagenet',
@@ -149,6 +146,7 @@ def test_end_to_end_train_and_eval(self, distribution, model, dataset):
149146
model=[
150147
'efficientnet',
151148
'resnet',
149+
'vgg',
152150
],
153151
dataset='imagenet',
154152
dtype='float16',
@@ -193,6 +191,7 @@ def test_gpu_train(self, distribution, model, dataset, dtype):
193191
model=[
194192
'efficientnet',
195193
'resnet',
194+
'vgg',
196195
],
197196
dataset='imagenet',
198197
dtype='bfloat16',

official/legacy/image_classification/configs/configs.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,14 @@
1414

1515
# Lint as: python3
1616
"""Configuration utils for image classification experiments."""
17-
from __future__ import absolute_import
18-
from __future__ import division
19-
from __future__ import print_function
2017

2118
import dataclasses
2219

2320
from official.legacy.image_classification import dataset_factory
2421
from official.legacy.image_classification.configs import base_configs
2522
from official.legacy.image_classification.efficientnet import efficientnet_config
2623
from official.legacy.image_classification.resnet import resnet_config
24+
from official.legacy.image_classification.vgg import vgg_config
2725

2826

2927
@dataclasses.dataclass
@@ -92,12 +90,38 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
9290
model: base_configs.ModelConfig = resnet_config.ResNetModelConfig()
9391

9492

93+
@dataclasses.dataclass
94+
class VGGImagenetConfig(base_configs.ExperimentConfig):
95+
"""Base configuration to train vgg-16 on ImageNet."""
96+
export: base_configs.ExportConfig = base_configs.ExportConfig()
97+
runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
98+
train_dataset: dataset_factory.DatasetConfig = dataset_factory.ImageNetConfig(
99+
split='train', one_hot=False, mean_subtract=True, standardize=True)
100+
validation_dataset: dataset_factory.DatasetConfig = dataset_factory.ImageNetConfig(
101+
split='validation', one_hot=False, mean_subtract=True, standardize=True)
102+
train: base_configs.TrainConfig = base_configs.TrainConfig(
103+
resume_checkpoint=True,
104+
epochs=90,
105+
steps=None,
106+
callbacks=base_configs.CallbacksConfig(
107+
enable_checkpoint_and_export=True, enable_tensorboard=True),
108+
metrics=['accuracy', 'top_5'],
109+
time_history=base_configs.TimeHistoryConfig(log_steps=100),
110+
tensorboard=base_configs.TensorBoardConfig(
111+
track_lr=True, write_model_weights=False),
112+
set_epoch_loop=False)
113+
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
114+
epochs_between_evals=1, steps=None)
115+
model: base_configs.ModelConfig = vgg_config.VGGModelConfig()
116+
117+
95118
def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig:
96119
"""Given model and dataset names, return the ExperimentConfig."""
97120
dataset_model_config_map = {
98121
'imagenet': {
99122
'efficientnet': EfficientNetImageNetConfig(),
100123
'resnet': ResNetImagenetConfig(),
124+
'vgg': VGGImagenetConfig(),
101125
}
102126
}
103127
try:
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Training configuration for VGG-16 trained on ImageNet on GPUs.
2+
# Reaches > 72.8% within 90 epochs.
3+
# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
4+
runtime:
5+
distribution_strategy: 'mirrored'
6+
num_gpus: 1
7+
batchnorm_spatial_persistent: true
8+
train_dataset:
9+
name: 'imagenet2012'
10+
data_dir: null
11+
builder: 'records'
12+
split: 'train'
13+
image_size: 224
14+
num_classes: 1000
15+
num_examples: 1281167
16+
batch_size: 128
17+
use_per_replica_batch_size: true
18+
dtype: 'float32'
19+
mean_subtract: true
20+
standardize: true
21+
validation_dataset:
22+
name: 'imagenet2012'
23+
data_dir: null
24+
builder: 'records'
25+
split: 'validation'
26+
image_size: 224
27+
num_classes: 1000
28+
num_examples: 50000
29+
batch_size: 128
30+
use_per_replica_batch_size: true
31+
dtype: 'float32'
32+
mean_subtract: true
33+
standardize: true
34+
model:
35+
name: 'vgg'
36+
optimizer:
37+
name: 'momentum'
38+
momentum: 0.9
39+
epsilon: 0.001
40+
loss:
41+
label_smoothing: 0.0
42+
train:
43+
resume_checkpoint: true
44+
epochs: 90
45+
evaluation:
46+
epochs_between_evals: 1
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Lint as: python3
16+
"""Configuration definitions for VGG losses, learning rates, and optimizers."""
17+
18+
import dataclasses
19+
from official.legacy.image_classification.configs import base_configs
20+
from official.modeling.hyperparams import base_config
21+
22+
23+
@dataclasses.dataclass
24+
class VGGModelConfig(base_configs.ModelConfig):
25+
"""Configuration for the VGG model."""
26+
name: str = 'VGG'
27+
num_classes: int = 1000
28+
model_params: base_config.Config = dataclasses.field(default_factory=lambda: { # pylint:disable=g-long-lambda
29+
'num_classes': 1000,
30+
'batch_size': None,
31+
'use_l2_regularizer': True
32+
})
33+
loss: base_configs.LossConfig = base_configs.LossConfig(
34+
name='sparse_categorical_crossentropy')
35+
optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
36+
name='momentum', epsilon=0.001, momentum=0.9, moving_average_decay=None)
37+
learning_rate: base_configs.LearningRateConfig = (
38+
base_configs.LearningRateConfig(
39+
name='stepwise',
40+
initial_lr=0.01,
41+
examples_per_epoch=1281167,
42+
boundaries=[30, 60],
43+
warmup_epochs=0,
44+
scale_by_batch_size=1. / 256.,
45+
multipliers=[0.01 / 256, 0.001 / 256, 0.0001 / 256]))

0 commit comments

Comments
 (0)