|
34 | 34 | from tensorflow.python.framework import ops
|
35 | 35 | from tensorflow.python.eager import tape
|
36 | 36 | from tensorflow.python.ops.variables import VariableAggregation
|
| 37 | +from tensorflow.python.platform import tf_logging |
37 | 38 | try: # The data_structures has been moved to the new package in tf 2.11
|
38 | 39 | from tensorflow.python.trackable import data_structures
|
39 | 40 | except:
|
40 | 41 | from tensorflow.python.training.tracking import data_structures
|
41 | 42 |
|
42 | 43 | from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_ops import DistributedVariableWrapper, TrainableWrapperDistributedPolicy
|
43 | 44 | from tensorflow_recommenders_addons.dynamic_embedding.python.ops.dynamic_embedding_variable import make_partition
|
| 45 | +from tensorflow_recommenders_addons.dynamic_embedding.python.ops.tf_save_restore_patch import de_fs_saveable_class_names |
44 | 46 |
|
45 | 47 |
|
46 | 48 | def _choose_reduce_method(combiner, sparse=False, segmented=False):
|
@@ -245,6 +247,7 @@ def __init__(self,
|
245 | 247 | TrainableWrapperDistributedPolicy(VariableAggregation.NONE))
|
246 | 248 | else:
|
247 | 249 | self.shadow = self.shadow_impl.as_list()[0]
|
| 250 | + self.params._created_in_class = self # To facilitate access to the primitive class through params |
248 | 251 | super(Embedding, self).__init__(name=name,
|
249 | 252 | trainable=trainable,
|
250 | 253 | dtype=value_dtype)
|
@@ -550,7 +553,11 @@ def __init__(self,
|
550 | 553 | else:
|
551 | 554 | self._mpi_size = mpi_size
|
552 | 555 | super(HvdAllToAllEmbedding, self).__init__(*args, **kwargs)
|
553 |
| - self.params._created_in_class = self |
| 556 | + if type(self.params.saveable).__name__ not in de_fs_saveable_class_names: |
| 557 | + tf_logging.warning( |
| 558 | + "Please use FileSystemSaver in KVCreator when use HvdAllToAllEmbedding. " |
| 559 | + "It will allow TFRA save and restore KV files when Embedding tensor parallel in distributed training. " |
| 560 | + ) |
554 | 561 |
|
555 | 562 | def __relocate_dense_feature__(self, ids, batch_size=None):
|
556 | 563 | """
|
|
0 commit comments