Skip to content

Commit 59966f2

Browse files
MoFHekarhdong
authored andcommitted
[fix] Embedding call didn't return a single worker call function when hvd.size() is 0.
1 parent 754e2e8 commit 59966f2

File tree

1 file changed

+8
-1
lines changed
  • tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers

1 file changed

+8
-1
lines changed

tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@
3434
from tensorflow.python.framework import ops
3535
from tensorflow.python.eager import tape
3636
from tensorflow.python.ops.variables import VariableAggregation
37+
from tensorflow.python.platform import tf_logging
3738
try: # The data_structures has been moved to the new package in tf 2.11
3839
from tensorflow.python.trackable import data_structures
3940
except:
4041
from tensorflow.python.training.tracking import data_structures
4142

4243
from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_ops import DistributedVariableWrapper, TrainableWrapperDistributedPolicy
4344
from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_variable import make_partition
45+
from tensorflow_recommenders_addons.dynamic_embedding.python.ops.tf_save_restore_patch import de_fs_saveable_class_names
4446

4547

4648
def _choose_reduce_method(combiner, sparse=False, segmented=False):
@@ -245,6 +247,7 @@ def __init__(self,
245247
TrainableWrapperDistributedPolicy(VariableAggregation.NONE))
246248
else:
247249
self.shadow = self.shadow_impl.as_list()[0]
250+
self.params._created_in_class = self # To facilitate access to the primitive class through params
248251
super(Embedding, self).__init__(name=name,
249252
trainable=trainable,
250253
dtype=value_dtype)
@@ -550,7 +553,11 @@ def __init__(self,
550553
else:
551554
self._mpi_size = mpi_size
552555
super(HvdAllToAllEmbedding, self).__init__(*args, **kwargs)
553-
self.params._created_in_class = self
556+
if type(self.params.saveable).__name__ not in de_fs_saveable_class_names:
557+
tf_logging.warning(
558+
"Please use FileSystemSaver in KVCreator when use HvdAllToAllEmbedding. "
559+
"It will allow TFRA save and restore KV files when Embedding tensor parallel in distributed training. "
560+
)
554561

555562
def __relocate_dense_feature__(self, ids, batch_size=None):
556563
"""

0 commit comments

Comments
 (0)