Skip to content

Commit 24b29d1

Browse files
MoFHekarhdong
authored andcommitted
[fix] checkpoint setting is inoperative when calling _gather_saveables_for_checkpoint.
1 parent 0bd19d9 commit 24b29d1

File tree

2 files changed

+24
-31
lines changed

2 files changed

+24
-31
lines changed

tensorflow_recommenders_addons/dynamic_embedding/python/ops/cuckoo_hashtable_ops.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
# ==============================================================================
1515
"""CuckooHash Lookup operations."""
1616
# pylint: disable=g-bad-name
17-
from __future__ import absolute_import
18-
from __future__ import division
19-
from __future__ import print_function
2017

2118
import copy
2219
import functools
@@ -127,11 +124,6 @@ def __init__(
127124
)
128125
if not context.executing_eagerly():
129126
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self.saveable)
130-
else:
131-
if shard_saveable_object_fn:
132-
self._saveable_fn = shard_saveable_object_fn
133-
else:
134-
self._saveable_fn = CuckooHashTable._Saveable
135127

136128
def _create_resource(self):
137129
# The table must be shared if checkpointing is requested for multi-worker
@@ -440,15 +432,18 @@ def _gather_saveables_for_checkpoint(self):
440432
# full_name helps to figure out the name-based Saver's name for this saveable.
441433
full_name = self._table_name
442434
self._new_obj_trackable = None # reset _new_obj_trackable when save again
443-
return {
444-
"table":
445-
functools.partial(
446-
self._saveable_fn,
447-
table=self,
448-
name=self._name,
449-
full_name=full_name,
450-
)
451-
}
435+
if self._checkpoint:
436+
return {
437+
"table":
438+
functools.partial(
439+
self._saveable_fn,
440+
table=self,
441+
name=self._name,
442+
full_name=full_name,
443+
)
444+
}
445+
else:
446+
return {}
452447

453448
class _Saveable(BaseSaverBuilder.SaveableObject):
454449
"""SaveableObject implementation for CuckooHashTable."""

tensorflow_recommenders_addons/dynamic_embedding/python/ops/redis_table_ops.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -254,11 +254,6 @@ def __init__(
254254
)
255255
if not context.executing_eagerly():
256256
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self.saveable)
257-
else:
258-
if shard_saveable_object_fn:
259-
self._saveable_fn = shard_saveable_object_fn
260-
else:
261-
self._saveable_fn = RedisTable._Saveable
262257

263258
def _create_resource(self):
264259
# The table must be shared if checkpointing is requested for multi-worker
@@ -579,15 +574,18 @@ def _gather_saveables_for_checkpoint(self):
579574
# full_name helps to figure out the name-based Saver's name for this saveable.
580575
full_name = self._table_name
581576
self._new_obj_trackable = None # reset _new_obj_trackable when save again
582-
return {
583-
"table":
584-
functools.partial(
585-
self._saveable_fn,
586-
table=self,
587-
name=self._name,
588-
full_name=full_name,
589-
)
590-
}
577+
if self._checkpoint:
578+
return {
579+
"table":
580+
functools.partial(
581+
self._saveable_fn,
582+
table=self,
583+
name=self._name,
584+
full_name=full_name,
585+
)
586+
}
587+
else:
588+
return {}
591589

592590
class _Saveable(BaseSaverBuilder.SaveableObject):
593591
"""SaveableObject implementation for RedisTable."""

0 commit comments

Comments
 (0)