Skip to content

Commit 4129326

Browse files
Internal change
PiperOrigin-RevId: 273371605
1 parent fa5b66e commit 4129326

File tree

8 files changed

+213
-4
lines changed

8 files changed

+213
-4
lines changed

official/resnet/ctl/ctl_imagenet_main.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21-
from absl import app as absl_app
21+
from absl import app
2222
from absl import flags
2323
from absl import logging
2424
import tensorflow as tf
@@ -181,6 +181,12 @@ def run(flags_obj):
181181
enable_eager=flags_obj.enable_eager,
182182
enable_xla=flags_obj.enable_xla)
183183

184+
dtype = flags_core.get_tf_dtype(flags_obj)
185+
if dtype == tf.bfloat16:
186+
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
187+
'mixed_bfloat16')
188+
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
189+
184190
# TODO(anj-s): Set data_format without using Keras.
185191
data_format = flags_obj.data_format
186192
if data_format is None:
@@ -375,4 +381,4 @@ def main(_):
375381
common.define_keras_flags()
376382
ctl_common.define_ctl_flags()
377383
flags.adopt_module_key_flags(ctl_common)
378-
absl_app.run(main)
384+
app.run(main)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Test the ResNet model with ImageNet data using CTL."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
from tempfile import mkdtemp
22+
import tensorflow as tf
23+
24+
from tensorflow.python.platform import googletest
25+
from official.resnet.ctl import ctl_common
26+
from official.resnet.ctl import ctl_imagenet_main
27+
from official.vision.image_classification import imagenet_preprocessing
28+
from official.vision.image_classification import common
29+
from official.utils.misc import keras_utils
30+
from official.utils.testing import integration
31+
32+
33+
class CtlImagenetTest(googletest.TestCase):
34+
"""Unit tests for Keras ResNet with ImageNet using CTL."""
35+
36+
_extra_flags = [
37+
'-batch_size', '4',
38+
'-train_steps', '4',
39+
'-use_synthetic_data', 'true'
40+
]
41+
_tempdir = None
42+
43+
def get_temp_dir(self):
44+
if not self._tempdir:
45+
self._tempdir = mkdtemp(dir=googletest.GetTempDir())
46+
return self._tempdir
47+
48+
@classmethod
49+
def setUpClass(cls): # pylint: disable=invalid-name
50+
super(CtlImagenetTest, cls).setUpClass()
51+
common.define_keras_flags()
52+
ctl_common.define_ctl_flags()
53+
54+
def setUp(self):
55+
super(CtlImagenetTest, self).setUp()
56+
if not keras_utils.is_v2_0():
57+
tf.compat.v1.enable_v2_behavior()
58+
imagenet_preprocessing.NUM_IMAGES['validation'] = 4
59+
60+
def tearDown(self):
61+
super(CtlImagenetTest, self).tearDown()
62+
tf.io.gfile.rmtree(self.get_temp_dir())
63+
64+
def test_end_to_end_tpu(self):
65+
"""Test Keras model with TPU distribution strategy."""
66+
67+
extra_flags = [
68+
'-distribution_strategy', 'tpu',
69+
'-model_dir', 'ctl_imagenet_tpu_dist_strat',
70+
'-data_format', 'channels_last',
71+
'-use_tf_function', 'true',
72+
'-single_l2_loss_op', 'true',
73+
]
74+
extra_flags = extra_flags + self._extra_flags
75+
76+
integration.run_synthetic(
77+
main=ctl_imagenet_main.run,
78+
tmp_root=self.get_temp_dir(),
79+
extra_flags=extra_flags
80+
)
81+
82+
def test_end_to_end_tpu_bf16(self):
83+
"""Test Keras model with TPU and bfloat16 activation."""
84+
85+
extra_flags = [
86+
'-distribution_strategy', 'tpu',
87+
'-model_dir', 'ctl_imagenet_tpu_dist_strat_bf16',
88+
'-data_format', 'channels_last',
89+
'-use_tf_function', 'true',
90+
'-single_l2_loss_op', 'true',
91+
'-dtype', 'bf16',
92+
]
93+
extra_flags = extra_flags + self._extra_flags
94+
95+
integration.run_synthetic(
96+
main=ctl_imagenet_main.run,
97+
tmp_root=self.get_temp_dir(),
98+
extra_flags=extra_flags
99+
)
100+
101+
102+
if __name__ == '__main__':
103+
googletest.main()

official/transformer/model/beam_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def inf(dtype):
3535
Returns:
3636
A very large value.
3737
"""
38-
if dtype == "float32":
38+
if dtype == "float32" or dtype == "bfloat16":
3939
return 1e7
4040
elif dtype == "float16":
4141
# Disable no-member lint error, as the linter thinks np.float16 does not

official/transformer/v2/transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def get_config(self):
386386

387387
def call(self, x, epsilon=1e-6):
388388
input_dtype = x.dtype
389-
if input_dtype == tf.float16:
389+
if input_dtype == tf.float16 or input_dtype == tf.bfloat16:
390390
x = tf.cast(x, tf.float32)
391391
mean = tf.reduce_mean(x, axis=[-1], keepdims=True)
392392
variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keepdims=True)

official/transformer/v2/transformer_main.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,11 @@ def __init__(self, flags_obj):
171171
"mixed_float16", loss_scale=loss_scale)
172172
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
173173

174+
if params["dtype"] == tf.bfloat16:
175+
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
176+
"mixed_bfloat16")
177+
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
178+
174179
self.distribution_strategy = distribution_utils.get_distribution_strategy(
175180
distribution_strategy=flags_obj.distribution_strategy,
176181
num_gpus=num_gpus,

official/utils/flags/_performance.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
# Map string to TensorFlow dtype
3030
DTYPE_MAP = {
3131
"fp16": tf.float16,
32+
"bf16": tf.bfloat16,
3233
"fp32": tf.float32,
3334
}
3435

official/vision/image_classification/resnet_imagenet_main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ def run(flags_obj):
6767
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
6868
if not keras_utils.is_v2_0():
6969
raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
70+
elif dtype == tf.bfloat16:
71+
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
72+
'mixed_bfloat16')
73+
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
7074

7175
data_format = flags_obj.data_format
7276
if data_format is None:
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Test the keras ResNet model with ImageNet data on TPU."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import tensorflow as tf
22+
23+
from official.utils.misc import keras_utils
24+
from official.utils.testing import integration
25+
from official.vision.image_classification import imagenet_preprocessing
26+
from official.vision.image_classification import resnet_imagenet_main
27+
28+
29+
class KerasImagenetTest(tf.test.TestCase):
30+
"""Unit tests for Keras ResNet with ImageNet."""
31+
32+
_extra_flags = [
33+
"-batch_size", "4",
34+
"-train_steps", "1",
35+
"-use_synthetic_data", "true"
36+
]
37+
_tempdir = None
38+
39+
@classmethod
40+
def setUpClass(cls): # pylint: disable=invalid-name
41+
super(KerasImagenetTest, cls).setUpClass()
42+
resnet_imagenet_main.define_imagenet_keras_flags()
43+
44+
def setUp(self):
45+
super(KerasImagenetTest, self).setUp()
46+
imagenet_preprocessing.NUM_IMAGES["validation"] = 4
47+
48+
def tearDown(self):
49+
super(KerasImagenetTest, self).tearDown()
50+
tf.io.gfile.rmtree(self.get_temp_dir())
51+
52+
def test_end_to_end_tpu(self):
53+
"""Test Keras model with TPU distribution strategy."""
54+
config = keras_utils.get_config_proto_v1()
55+
tf.compat.v1.enable_eager_execution(config=config)
56+
57+
extra_flags = [
58+
"-distribution_strategy", "tpu",
59+
"-data_format", "channels_last",
60+
]
61+
extra_flags = extra_flags + self._extra_flags
62+
63+
integration.run_synthetic(
64+
main=resnet_imagenet_main.run,
65+
tmp_root=self.get_temp_dir(),
66+
extra_flags=extra_flags
67+
)
68+
69+
def test_end_to_end_tpu_bf16(self):
70+
"""Test Keras model with TPU and bfloat16 activation."""
71+
config = keras_utils.get_config_proto_v1()
72+
tf.compat.v1.enable_eager_execution(config=config)
73+
74+
extra_flags = [
75+
"-distribution_strategy", "tpu",
76+
"-data_format", "channels_last",
77+
"-dtype", "bf16",
78+
]
79+
extra_flags = extra_flags + self._extra_flags
80+
81+
integration.run_synthetic(
82+
main=resnet_imagenet_main.run,
83+
tmp_root=self.get_temp_dir(),
84+
extra_flags=extra_flags
85+
)
86+
87+
88+
if __name__ == "__main__":
89+
tf.compat.v1.enable_v2_behavior()
90+
tf.test.main()

0 commit comments

Comments
 (0)