import tensorflow as tf
from horovod.tensorflow import _LegacyOptimizer

+optimizer_instances: list = []
+

def complement_value_from_env_if_none(
    original_value, key, clz, default_value=None
@@ -27,6 +29,14 @@ def complement_value_from_env_if_none(
    return clz(os.environ.get(key, default_value))


+def get_adjust_backward_passes_hooks():
+    hooks = []
+    global optimizer_instances
+    for opt in optimizer_instances:
+        hooks.append(AdjustBackwardPassesPerStepHook(opt))
+    return hooks
+
+
class AdjustBackwardPassesPerStepHook(tf.train.SessionRunHook):
    """
    Hook that adjusts `backward_passes_per_step` according to
@@ -636,7 +646,7 @@ def DistributedOptimizer(
            hvd_max_size, "WORKER_NUM", int, 1
        )
        global_batch_count_per_step = hvd_max_size * backward_passes_per_step
-        return _DistributedOptimizer(
+        opt = _DistributedOptimizer(
            optimizer=optimizer,
            name=name,
            use_locking=use_locking,
@@ -651,6 +661,9 @@ def DistributedOptimizer(
            num_groups=num_groups,
            global_batch_count_per_step=global_batch_count_per_step,
        )
+        global optimizer_instances
+        optimizer_instances.append(opt)
+        return opt
    elif isinstance(optimizer, tf.keras.optimizers.Optimizer):
        raise ValueError(
            "fixed_global_batch_size == True is not supported yet with Keras"