Commit ee6288b

Develop an API to get hooks for elastic optimizers (#2510)
* Develop an API to get hooks for the elastic optimizer
* Add annotation
1 parent 4bc8a2a commit ee6288b

1 file changed: 14 additions, 1 deletion

elasticai_api/tensorflow/optimizer.py

Lines changed: 14 additions & 1 deletion
@@ -17,6 +17,8 @@
 import tensorflow as tf
 from horovod.tensorflow import _LegacyOptimizer
 
+optimizer_instances: list = []
+
 
 def complement_value_from_env_if_none(
     original_value, key, clz, default_value=None
@@ -27,6 +29,14 @@ def complement_value_from_env_if_none(
     return clz(os.environ.get(key, default_value))
 
 
+def get_adjust_backward_passes_hooks():
+    hooks = []
+    global optimizer_instances
+    for opt in optimizer_instances:
+        hooks.append(AdjustBackwardPassesPerStepHook(opt))
+    return hooks
+
+
 class AdjustBackwardPassesPerStepHook(tf.train.SessionRunHook):
     """
     Hooks that adjusts `backward_passer_per_step` according to
@@ -636,7 +646,7 @@ def DistributedOptimizer(
                 hvd_max_size, "WORKER_NUM", int, 1
             )
             global_batch_count_per_step = hvd_max_size * backward_passes_per_step
-            return _DistributedOptimizer(
+            opt = _DistributedOptimizer(
                 optimizer=optimizer,
                 name=name,
                 use_locking=use_locking,
@@ -651,6 +661,9 @@ def DistributedOptimizer(
                 num_groups=num_groups,
                 global_batch_count_per_step=global_batch_count_per_step,
            )
+            global optimizer_instance
+            optimizer_instances.append(opt)
+            return opt
         elif isinstance(optimizer, tf.keras.optimizers.Optimizer):
             raise ValueError(
                 "fixed_global_batch_size == True is not supported yet with Keras"
