 
 """Functions and classes related to training performance."""
 
-from absl import logging
 import tensorflow as tf
 
 
 def configure_optimizer(optimizer,
                         use_float16=False,
                         use_graph_rewrite=False,
-                        loss_scale='dynamic',
-                        use_experimental_api=False):
+                        loss_scale='dynamic'):
   """Configures optimizer object with performance options."""
-  if use_experimental_api:
-    logging.warning('Passing use_experimental_api=True is deprecated. The '
-                    'argument will be removed in the future.')
   if use_float16:
-    # TODO(b/171936854): Move all methods to non-experimental api.
-    if use_experimental_api:
-      # Wraps optimizer with a LossScaleOptimizer. This is done automatically
-      # in compile() with the "mixed_float16" policy, but since we do not call
-      # compile(), we must wrap the optimizer manually.
-      optimizer = (
-          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-              optimizer, loss_scale=loss_scale))
-    elif loss_scale == 'dynamic':
+    if loss_scale == 'dynamic':
       optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
     else:
       # loss_scale is a number. We interpret that as a fixed loss scale.
@@ -52,34 +39,17 @@ def configure_optimizer(optimizer,
   return optimizer
 
 
-def set_mixed_precision_policy(dtype, loss_scale=None,
-                               use_experimental_api=False):
-  """Sets mix precision policy."""
-  if use_experimental_api:
-    logging.warning('Passing use_experimental_api=True is deprecated. The '
-                    'argument will be removed in the future.')
-  assert use_experimental_api or loss_scale is None, (
-      'loss_scale cannot be specified if use_experimental_api is False. If the '
-      'non-experimental API is used, specify the loss scaling configuration '
-      'when creating the LossScaleOptimizer instead.'
-  )
+def set_mixed_precision_policy(dtype, loss_scale=None):
+  """Sets the global `tf.keras.mixed_precision.Policy`."""
+  # TODO(b/191894773): Remove loss_scale argument
+  assert loss_scale is None, (
+      'The loss_scale argument must be None. The argument exists for '
+      'historical reasons and will be removed soon.')
   if dtype == tf.float16:
-    # TODO(b/171936854): Move all methods to non-experimental api.
-    if use_experimental_api:
-      policy = tf.keras.mixed_precision.experimental.Policy(
-          'mixed_float16', loss_scale=loss_scale)
-      tf.keras.mixed_precision.experimental.set_policy(policy)
-    else:
-      tf.keras.mixed_precision.set_global_policy('mixed_float16')
+    tf.keras.mixed_precision.set_global_policy('mixed_float16')
   elif dtype == tf.bfloat16:
-    if use_experimental_api:
-      tf.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')
-    else:
-      tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
+    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
   elif dtype == tf.float32:
-    if use_experimental_api:
-      tf.keras.mixed_precision.experimental.set_policy('float32')
-    else:
-      tf.keras.mixed_precision.set_global_policy('float32')
+    tf.keras.mixed_precision.set_global_policy('float32')
   else:
     raise ValueError('Unexpected dtype: %s' % dtype)
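
For context, here is a minimal sketch of how the two functions compose after this change. The module path `official.modeling.performance`, the model shape, and the learning rate are illustrative assumptions, not part of the commit:

import tensorflow as tf

from official.modeling import performance

# Set the global policy first so layers built afterwards compute in float16
# while keeping float32 variables.
performance.set_mixed_precision_policy(tf.float16)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, input_shape=(8,)),
    tf.keras.layers.Dense(1, dtype=tf.float32),  # keep the output float32
])

optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
# With use_experimental_api gone, the default loss_scale='dynamic' wraps the
# optimizer in tf.keras.mixed_precision.LossScaleOptimizer.
optimizer = performance.configure_optimizer(optimizer, use_float16=True)
# Per the "fixed loss scale" comment above, a number can be passed instead:
# optimizer = performance.configure_optimizer(
#     optimizer, use_float16=True, loss_scale=128)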
0 commit comments
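
Because configure_optimizer wraps the optimizer without calling compile(), a custom training loop must scale and unscale explicitly. A sketch under that assumption, reusing model and optimizer from above with a hypothetical loss_fn:

loss_fn = tf.keras.losses.MeanSquaredError()  # hypothetical choice

@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    loss = loss_fn(y, model(x, training=True))
    # Scale the loss up so small float16 gradients do not underflow to zero.
    scaled_loss = optimizer.get_scaled_loss(loss)
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  # Undo the scaling before the update; with dynamic scaling the optimizer
  # also skips steps whose gradients contain Inf/NaN and adjusts the scale.
  grads = optimizer.get_unscaled_gradients(scaled_grads)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss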