Skip to content

Commit 83ad7bd

Browse files
reedwm and tensorflower-gardener
authored and committed
Internal change
PiperOrigin-RevId: 381302776
1 parent 7ee6089 commit 83ad7bd

File tree

1 file changed

+11
-41
lines changed

1 file changed

+11
-41
lines changed

official/modeling/performance.py

Lines changed: 11 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,16 @@
1414

1515
"""Functions and classes related to training performance."""
1616

17-
from absl import logging
1817
import tensorflow as tf
1918

2019

2120
def configure_optimizer(optimizer,
2221
use_float16=False,
2322
use_graph_rewrite=False,
24-
loss_scale='dynamic',
25-
use_experimental_api=False):
23+
loss_scale='dynamic'):
2624
"""Configures optimizer object with performance options."""
27-
if use_experimental_api:
28-
logging.warning('Passing use_experimental_api=True is deprecated. The '
29-
'argument will be removed in the future.')
3025
if use_float16:
31-
# TODO(b/171936854): Move all methods to non-experimental api.
32-
if use_experimental_api:
33-
# Wraps optimizer with a LossScaleOptimizer. This is done automatically
34-
# in compile() with the "mixed_float16" policy, but since we do not call
35-
# compile(), we must wrap the optimizer manually.
36-
optimizer = (
37-
tf.keras.mixed_precision.experimental.LossScaleOptimizer(
38-
optimizer, loss_scale=loss_scale))
39-
elif loss_scale == 'dynamic':
26+
if loss_scale == 'dynamic':
4027
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
4128
else:
4229
# loss_scale is a number. We interpret that as a fixed loss scale.
@@ -52,34 +39,17 @@ def configure_optimizer(optimizer,
5239
return optimizer
5340

5441

def set_mixed_precision_policy(dtype, loss_scale=None):
  """Sets the global `tf.keras.mixed_precision.Policy`.

  Args:
    dtype: One of `tf.float16`, `tf.bfloat16`, or `tf.float32`. Selects the
      corresponding global policy ('mixed_float16', 'mixed_bfloat16', or
      'float32').
    loss_scale: Must be None. The argument exists for historical reasons and
      will be removed soon.

  Raises:
    ValueError: If `loss_scale` is not None, or if `dtype` is not one of the
      three supported dtypes.
  """
  # TODO(b/191894773): Remove loss_scale argument
  if loss_scale is not None:
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would silently disable this public-API check.
    raise ValueError(
        'The loss_scale argument must be None. The argument exists for '
        'historical reasons and will be removed soon.')
  if dtype == tf.float16:
    tf.keras.mixed_precision.set_global_policy('mixed_float16')
  elif dtype == tf.bfloat16:
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
  elif dtype == tf.float32:
    tf.keras.mixed_precision.set_global_policy('float32')
  else:
    raise ValueError('Unexpected dtype: %s' % dtype)

0 commit comments

Comments
 (0)