Skip to content

Commit 937a530

Browse files
authored
Keras model benchmark (#4476)
* Add callbacks * Add readme * update readme * fix some comments * Address all comments * Update docstrings * Add method docstrings * Update callbacks * Add comments on global_step initialization * Some updates * Address comments
1 parent 7d0fcd0 commit 937a530

File tree

5 files changed

+406
-0
lines changed

5 files changed

+406
-0
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Keras Application Models Benchmark
2+
## Overview
3+
This provides a single scaffold to benchmark the Keras built-in application [models](https://keras.io/applications/). All the models are for image classification applications, and include:
4+
5+
- Xception
6+
- VGG16
7+
- VGG19
8+
- ResNet50
9+
- InceptionV3
10+
- InceptionResNetV2
11+
- MobileNet
12+
- DenseNet
13+
- NASNet
14+
15+
## Dataset
16+
Synthetic dataset is used for the benchmark.
17+
18+
## Callbacks
19+
Two custom callbacks are provided for model benchmarking: ExamplesPerSecondCallback and LoggingMetricCallback. For each callback, `epoch_based` and `batch_based` options are available to set the benchmark level. Check [model_callbacks.py](model_callbacks.py) for more details.
20+
21+
## Running Code
22+
To benchmark a model, use `--model` to specify the model name, and issue the following command:
23+
```
24+
python benchmark_main.py --model=resnet
25+
```
26+
Arguments:
27+
* `--model`: Which model to be benchmarked. The model name is defined as the keys of `MODELS` in [benchmark_main.py](benchmark_main.py).
28+
* `--callbacks`: To specify a list of callbacks.
29+
30+
Use the `--help` or `-h` flag to get a full list of possible arguments.

official/keras_application_models/__init__.py

Whitespace-only changes.
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Benchmark on the keras built-in application models."""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
# pylint: disable=g-bad-import-order
21+
import numpy as np
22+
from absl import app as absl_app
23+
from absl import flags
24+
import tensorflow as tf
25+
# pylint: enable=g-bad-import-order
26+
27+
from official.keras_application_models import dataset
28+
from official.keras_application_models import model_callbacks
29+
from official.utils.flags import core as flags_core
30+
from official.utils.logs import logger
31+
32+
# Define a dictionary that maps model names to their model classes inside Keras
33+
MODELS = {
34+
"vgg16": tf.keras.applications.VGG16,
35+
"vgg19": tf.keras.applications.VGG19,
36+
"inceptionv3": tf.keras.applications.InceptionV3,
37+
"xception": tf.keras.applications.Xception,
38+
"resnet50": tf.keras.applications.ResNet50,
39+
"inceptionresnetv2": tf.keras.applications.InceptionResNetV2,
40+
"mobilenet": tf.keras.applications.MobileNet,
41+
"densenet121": tf.keras.applications.DenseNet121,
42+
"densenet169": tf.keras.applications.DenseNet169,
43+
"densenet201": tf.keras.applications.DenseNet201,
44+
# TODO(b/80431378)
45+
# "nasnetlarge": tf.keras.applications.NASNetLarge,
46+
# "nasnetmobile": tf.keras.applications.NASNetMobile,
47+
}
48+
49+
50+
def run_keras_model_benchmark(_):
51+
"""Run the benchmark on keras model."""
52+
# Ensure a valid model name was supplied via command line argument
53+
if FLAGS.model not in MODELS.keys():
54+
raise AssertionError("The --model command line argument should "
55+
"be a key in the `MODELS` dictionary.")
56+
57+
# Load the model
58+
tf.logging.info("Benchmark on {} model...".format(FLAGS.model))
59+
keras_model = MODELS[FLAGS.model]
60+
model = keras_model(weights=None)
61+
62+
# Get dataset
63+
dataset_name = "ImageNet"
64+
if FLAGS.use_synthetic_data:
65+
tf.logging.info("Using synthetic dataset...")
66+
dataset_name += "_Synthetic"
67+
train_num_images = FLAGS.batch_size
68+
val_num_images = FLAGS.batch_size
69+
train_dataset = dataset.generate_synthetic_input_dataset(
70+
FLAGS.model, train_num_images)
71+
val_dataset = dataset.generate_synthetic_input_dataset(
72+
FLAGS.model, val_num_images)
73+
else:
74+
raise ValueError("Only synthetic dataset is supported!")
75+
76+
# If run with multiple GPUs
77+
num_gpus = flags_core.get_num_gpus(FLAGS)
78+
if num_gpus > 0:
79+
model = tf.keras.utils.multi_gpu_model(model, gpus=num_gpus)
80+
81+
# Configure the model
82+
model.compile(loss="categorical_crossentropy",
83+
optimizer="sgd",
84+
metrics=["accuracy"])
85+
86+
# Create benchmark logger for benchmark logging
87+
run_params = {
88+
"batch_size": FLAGS.batch_size,
89+
"synthetic_data": FLAGS.use_synthetic_data,
90+
"train_epochs": FLAGS.train_epochs
91+
}
92+
93+
benchmark_logger = logger.get_benchmark_logger()
94+
benchmark_logger.log_run_info(
95+
model_name=FLAGS.model,
96+
dataset_name=dataset_name,
97+
run_params=run_params,
98+
test_id=FLAGS.benchmark_test_id)
99+
100+
# Create callbacks that log metric values about the training and evaluation
101+
callbacks = model_callbacks.get_model_callbacks(
102+
FLAGS.callbacks,
103+
batch_size=FLAGS.batch_size,
104+
metric_logger=benchmark_logger)
105+
# Train and evaluate the model
106+
history = model.fit(
107+
train_dataset,
108+
epochs=FLAGS.train_epochs,
109+
callbacks=callbacks,
110+
validation_data=val_dataset,
111+
steps_per_epoch=int(np.ceil(train_num_images / FLAGS.batch_size)),
112+
validation_steps=int(np.ceil(val_num_images / FLAGS.batch_size))
113+
)
114+
115+
tf.logging.info("Logging the evaluation results...")
116+
for epoch in range(FLAGS.train_epochs):
117+
eval_results = {
118+
"accuracy": history.history["val_acc"][epoch],
119+
"loss": history.history["val_loss"][epoch],
120+
tf.GraphKeys.GLOBAL_STEP: (epoch + 1) * np.ceil(
121+
train_num_images/FLAGS.batch_size)
122+
}
123+
benchmark_logger.log_evaluation_result(eval_results)
124+
125+
# Clear the session explicitly to avoid session delete error
126+
tf.keras.backend.clear_session()
127+
128+
129+
def define_keras_benchmark_flags():
130+
"""Add flags for keras built-in application models."""
131+
flags_core.define_base(hooks=False)
132+
flags_core.define_performance()
133+
flags_core.define_image()
134+
flags_core.define_benchmark()
135+
flags.adopt_module_key_flags(flags_core)
136+
137+
flags_core.set_defaults(
138+
data_format="channels_last",
139+
use_synthetic_data=True,
140+
batch_size=32,
141+
train_epochs=2)
142+
143+
flags.DEFINE_enum(
144+
name="model", default=None,
145+
enum_values=MODELS.keys(), case_sensitive=False,
146+
help=flags_core.help_wrap(
147+
"Model to be benchmarked."))
148+
149+
flags.DEFINE_list(
150+
name="callbacks",
151+
default=["ExamplesPerSecondCallback", "LoggingMetricCallback"],
152+
help=flags_core.help_wrap(
153+
"A list of (case insensitive) strings to specify the names of "
154+
"callbacks. For example: `--callbacks ExamplesPerSecondCallback,"
155+
"LoggingMetricCallback`"))
156+
157+
158+
def main(_):
159+
with logger.benchmark_context(FLAGS):
160+
run_keras_model_benchmark(FLAGS)
161+
162+
if __name__ == "__main__":
163+
tf.logging.set_verbosity(tf.logging.INFO)
164+
define_keras_benchmark_flags()
165+
FLAGS = flags.FLAGS
166+
absl_app.run(main)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Prepare dataset for keras model benchmark."""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import tensorflow as tf
21+
22+
# Default values for dataset.
23+
_NUM_CHANNELS = 3
24+
_NUM_CLASSES = 1000
25+
26+
27+
def _get_default_image_size(model):
28+
"""Provide default image size for each model."""
29+
image_size = (224, 224)
30+
if model in ["inception", "xception", "inceptionresnet"]:
31+
image_size = (299, 299)
32+
elif model in ["nasnetlarge"]:
33+
image_size = (331, 331)
34+
return image_size
35+
36+
37+
def generate_synthetic_input_dataset(model, num_imgs):
38+
"""Generate synthetic dataset."""
39+
image_size = _get_default_image_size(model)
40+
input_shape = (num_imgs,) + image_size + (_NUM_CHANNELS,)
41+
42+
images = tf.zeros(input_shape, dtype=tf.float32)
43+
labels = tf.zeros((num_imgs, _NUM_CLASSES), dtype=tf.float32)
44+
45+
return tf.data.Dataset.from_tensors((images, labels)).repeat()

0 commit comments

Comments
 (0)