Skip to content

Commit 754e2e8

Browse files
MoFHekarhdong
authored andcommitted
[fix] Support TF CheckpointManager in 2.9.
User may create Checkpoint by passing keys "model", "optimizer" and so on. So root parameter in __init__ function may be None. We need to walk through all kwargs.
1 parent b078929 commit 754e2e8

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/horovod_sync_train_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def common_all_to_all_embedding_trainable_v2(self, base_opt, test_opt, name):
339339
bp_v2=False,
340340
kv_creator=kv_creator,
341341
name='all2all_emb')
342-
ckpt = de.train.DEHvdCheckpoint(new_base_model)
342+
ckpt = de.train.DEHvdCheckpoint(my_model=new_base_model)
343343
hvd.join() # Sync for avoiding files conflict
344344
ckpt.restore(tf.train.latest_checkpoint(save_dir + '/ckpt/'))
345345
new_a2aemb_size = new_base_model.layers[0].params.size()

tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -155,17 +155,21 @@ def _filter_de_hvd_a2a_tw(var):
155155
return False
156156
return True
157157

158-
if _filter_de_hvd_a2a_tw(self.root):
159-
func(var.params, de_dir)
160-
if hasattr(self.root, 'variables'):
161-
for var in self.root.variables:
162-
if _filter_de_hvd_a2a_tw(var):
163-
func(var.params, de_dir)
158+
def _handle_model_or_variable(obj):
159+
if _filter_de_hvd_a2a_tw(obj):
160+
func(var.params, de_dir)
161+
if hasattr(obj, 'variables'):
162+
_iter = obj.variables() if callable(obj.variables) else obj.variables
163+
for var in _iter:
164+
if _filter_de_hvd_a2a_tw(var):
165+
func(var.params, de_dir)
166+
167+
if hasattr(self, 'root'):
168+
_handle_model_or_variable(self.root)
164169
if len(self._tmp_var_key_set):
165-
for var_key in self._tmp_var_key_set:
166-
var = getattr(self, var_key)
167-
if _filter_de_hvd_a2a_tw(var):
168-
func(var.params, de_dir)
170+
for obj_key in self._tmp_var_key_set:
171+
obj_var = getattr(self, obj_key)
172+
_handle_model_or_variable(obj_var)
169173

170174
def _de_hvd_write_fs_func(self, file_prefix, tf_write_func):
171175

0 commit comments

Comments
 (0)