Skip to content

Commit 587f579

Browse files
author
Taylor Robie
authored
Add reference data tests to official. (#3723)
* Add golden test util to streamline symbolic and numerical comparison to reference graphs, and apply golden tests to ResNet. update tests use more concise logic for path property delint add some comments delint address PR comments make resnet tests more concise, and supress warning test in py2 change resnet name template more shuffling of data dirs address PR comments and add tensorflow version info Remove subTest due to py2 switch from tf.__version__ to tf.VERSION, and include tf.GIT_VERSION supress lint error from json load unpack * address PR comments * address PR comments * delint
1 parent 1730eed commit 587f579

File tree

58 files changed

+686
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+686
-0
lines changed

official/resnet/layer_test.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Test that the definitions of ResNet layers haven't changed.
16+
17+
These tests will fail if either:
18+
a) The graph of a resnet layer changes and the change is significant enough
19+
that it can no longer load existing checkpoints.
20+
b) The numerical results produced by the layer change.
21+
22+
A warning will be issued if the graph changes, but the checkpoint still loads.
23+
24+
In the event that a layer change is intended, or the TensorFlow implementation
25+
of a layer changes (and thus changes the graph), regenerate using the command:
26+
27+
$ python3 layer_test.py -regen
28+
"""
29+
30+
from __future__ import absolute_import
31+
from __future__ import division
32+
from __future__ import print_function
33+
34+
import sys
35+
36+
import tensorflow as tf # pylint: disable=g-bad-import-order
37+
from official.resnet import resnet_model
38+
from official.utils.testing import reference_data
39+
40+
41+
DATA_FORMAT = "channels_last" # CPU instructions often preclude channels_first
42+
BATCH_SIZE = 32
43+
BLOCK_TESTS = [
44+
dict(bottleneck=True, projection=True, version=1, width=8, channels=4),
45+
dict(bottleneck=True, projection=True, version=2, width=8, channels=4),
46+
dict(bottleneck=True, projection=False, version=1, width=8, channels=4),
47+
dict(bottleneck=True, projection=False, version=2, width=8, channels=4),
48+
dict(bottleneck=False, projection=True, version=1, width=8, channels=4),
49+
dict(bottleneck=False, projection=True, version=2, width=8, channels=4),
50+
dict(bottleneck=False, projection=False, version=1, width=8, channels=4),
51+
dict(bottleneck=False, projection=False, version=2, width=8, channels=4),
52+
]
53+
54+
55+
class BaseTest(reference_data.BaseTest):
56+
"""Tests for core ResNet layers."""
57+
58+
@property
59+
def test_name(self):
60+
return "resnet"
61+
62+
def _batch_norm_ops(self, test=False):
63+
name = "batch_norm"
64+
65+
g = tf.Graph()
66+
with g.as_default():
67+
tf.set_random_seed(self.name_to_seed(name))
68+
input_tensor = tf.get_variable(
69+
"input_tensor", dtype=tf.float32,
70+
initializer=tf.random_uniform((32, 16, 16, 3), maxval=1)
71+
)
72+
layer = resnet_model.batch_norm(
73+
inputs=input_tensor, data_format=DATA_FORMAT, training=True)
74+
75+
self._save_or_test_ops(
76+
name=name, graph=g, ops_to_eval=[input_tensor, layer], test=test,
77+
correctness_function=self.default_correctness_function
78+
)
79+
80+
def make_projection(self, filters_out, strides, data_format):
81+
"""1D convolution with stride projector.
82+
83+
Args:
84+
filters_out: Number of filters in the projection.
85+
strides: Stride length for convolution.
86+
data_format: channels_first or channels_last
87+
88+
Returns:
89+
A CNN projector function with kernel_size 1.
90+
"""
91+
def projection_shortcut(inputs):
92+
return resnet_model.conv2d_fixed_padding(
93+
inputs=inputs, filters=filters_out, kernel_size=1, strides=strides,
94+
data_format=data_format)
95+
return projection_shortcut
96+
97+
def _resnet_block_ops(self, test, batch_size, bottleneck, projection,
98+
version, width, channels):
99+
"""Test whether resnet block construction has changed.
100+
101+
Args:
102+
test: Whether or not to run as a test case.
103+
batch_size: Number of points in the fake image. This is needed due to
104+
batch normalization.
105+
bottleneck: Whether or not to use bottleneck layers.
106+
projection: Whether or not to project the input.
107+
version: Which version of ResNet to test.
108+
width: The width of the fake image.
109+
channels: The number of channels in the fake image.
110+
"""
111+
112+
name = "batch-size-{}_{}{}_version-{}_width-{}_channels-{}".format(
113+
batch_size,
114+
"bottleneck" if bottleneck else "building",
115+
"_projection" if projection else "",
116+
version,
117+
width,
118+
channels
119+
)
120+
121+
if version == 1:
122+
block_fn = resnet_model._building_block_v1
123+
if bottleneck:
124+
block_fn = resnet_model._bottleneck_block_v1
125+
else:
126+
block_fn = resnet_model._building_block_v2
127+
if bottleneck:
128+
block_fn = resnet_model._bottleneck_block_v2
129+
130+
g = tf.Graph()
131+
with g.as_default():
132+
tf.set_random_seed(self.name_to_seed(name))
133+
strides = 1
134+
channels_out = channels
135+
projection_shortcut = None
136+
if projection:
137+
strides = 2
138+
channels_out *= strides
139+
projection_shortcut = self.make_projection(
140+
filters_out=channels_out, strides=strides, data_format=DATA_FORMAT)
141+
142+
filters = channels_out
143+
if bottleneck:
144+
filters = channels_out // 4
145+
146+
input_tensor = tf.get_variable(
147+
"input_tensor", dtype=tf.float32,
148+
initializer=tf.random_uniform((batch_size, width, width, channels),
149+
maxval=1)
150+
)
151+
152+
layer = block_fn(inputs=input_tensor, filters=filters, training=True,
153+
projection_shortcut=projection_shortcut, strides=strides,
154+
data_format=DATA_FORMAT)
155+
156+
self._save_or_test_ops(
157+
name=name, graph=g, ops_to_eval=[input_tensor, layer], test=test,
158+
correctness_function=self.default_correctness_function
159+
)
160+
161+
def test_batch_norm(self):
162+
self._batch_norm_ops(test=True)
163+
164+
def test_block_0(self):
165+
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[0])
166+
167+
def test_block_1(self):
168+
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[1])
169+
170+
def test_block_2(self):
171+
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[2])
172+
173+
def test_block_3(self):
174+
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[3])
175+
176+
def test_block_4(self):
177+
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[4])
178+
179+
def test_block_5(self):
180+
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[5])
181+
182+
def test_block_6(self):
183+
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[6])
184+
185+
def test_block_7(self):
186+
self._resnet_block_ops(test=True, batch_size=BATCH_SIZE, **BLOCK_TESTS[7])
187+
188+
def regenerate(self):
189+
"""Create reference data files for ResNet layer tests."""
190+
self._batch_norm_ops(test=False)
191+
for block_params in BLOCK_TESTS:
192+
self._resnet_block_ops(test=False, batch_size=BATCH_SIZE, **block_params)
193+
194+
195+
if __name__ == "__main__":
196+
reference_data.main(argv=sys.argv, test_class=BaseTest)

0 commit comments

Comments
 (0)