File tree Expand file tree Collapse file tree 2 files changed +24
-31
lines changed
tensorflow_recommenders_addons/dynamic_embedding/python/ops Expand file tree Collapse file tree 2 files changed +24
-31
lines changed Original file line number Diff line number Diff line change 14
14
# ==============================================================================
15
15
"""CuckooHash Lookup operations."""
16
16
# pylint: disable=g-bad-name
17
- from __future__ import absolute_import
18
- from __future__ import division
19
- from __future__ import print_function
20
17
21
18
import copy
22
19
import functools
@@ -127,11 +124,6 @@ def __init__(
127
124
)
128
125
if not context .executing_eagerly ():
129
126
ops .add_to_collection (ops .GraphKeys .SAVEABLE_OBJECTS , self .saveable )
130
- else :
131
- if shard_saveable_object_fn :
132
- self ._saveable_fn = shard_saveable_object_fn
133
- else :
134
- self ._saveable_fn = CuckooHashTable ._Saveable
135
127
136
128
def _create_resource (self ):
137
129
# The table must be shared if checkpointing is requested for multi-worker
@@ -440,15 +432,18 @@ def _gather_saveables_for_checkpoint(self):
440
432
# full_name helps to figure out the name-based Saver's name for this saveable.
441
433
full_name = self ._table_name
442
434
self ._new_obj_trackable = None # reset _new_obj_trackable when save again
443
- return {
444
- "table" :
445
- functools .partial (
446
- self ._saveable_fn ,
447
- table = self ,
448
- name = self ._name ,
449
- full_name = full_name ,
450
- )
451
- }
435
+ if self ._checkpoint :
436
+ return {
437
+ "table" :
438
+ functools .partial (
439
+ self ._saveable_fn ,
440
+ table = self ,
441
+ name = self ._name ,
442
+ full_name = full_name ,
443
+ )
444
+ }
445
+ else :
446
+ return {}
452
447
453
448
class _Saveable (BaseSaverBuilder .SaveableObject ):
454
449
"""SaveableObject implementation for CuckooHashTable."""
Original file line number Diff line number Diff line change @@ -254,11 +254,6 @@ def __init__(
254
254
)
255
255
if not context .executing_eagerly ():
256
256
ops .add_to_collection (ops .GraphKeys .SAVEABLE_OBJECTS , self .saveable )
257
- else :
258
- if shard_saveable_object_fn :
259
- self ._saveable_fn = shard_saveable_object_fn
260
- else :
261
- self ._saveable_fn = RedisTable ._Saveable
262
257
263
258
def _create_resource (self ):
264
259
# The table must be shared if checkpointing is requested for multi-worker
@@ -579,15 +574,18 @@ def _gather_saveables_for_checkpoint(self):
579
574
# full_name helps to figure out the name-based Saver's name for this saveable.
580
575
full_name = self ._table_name
581
576
self ._new_obj_trackable = None # reset _new_obj_trackable when save again
582
- return {
583
- "table" :
584
- functools .partial (
585
- self ._saveable_fn ,
586
- table = self ,
587
- name = self ._name ,
588
- full_name = full_name ,
589
- )
590
- }
577
+ if self ._checkpoint :
578
+ return {
579
+ "table" :
580
+ functools .partial (
581
+ self ._saveable_fn ,
582
+ table = self ,
583
+ name = self ._name ,
584
+ full_name = full_name ,
585
+ )
586
+ }
587
+ else :
588
+ return {}
591
589
592
590
class _Saveable (BaseSaverBuilder .SaveableObject ):
593
591
"""SaveableObject implementation for RedisTable."""
You can’t perform that action at this time.
0 commit comments