Skip to content

Commit 9515585

Browse files
MoFHeka authored and rhdong committed
[fix] The same id key may appear multiple times in one tensor before the real shadow lookup when using HvdAllToAllEmbedding.
1 parent 7a6bce1 commit 9515585

File tree

1 file changed

+10
-1
lines changed
  • tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers

1 file changed

+10
-1
lines changed

tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -535,6 +535,7 @@ class HvdAllToAllEmbedding(BasicEmbedding):
535535

536536
def __init__(self,
537537
with_unique=True,
538+
with_secondary_unique=True,
538539
mpi_size=None,
539540
batch_size=None,
540541
*args,
@@ -547,6 +548,7 @@ def __init__(self,
547548
)
548549
self.hvd = hvd
549550
self.with_unique = with_unique
551+
self.with_secondary_unique = with_secondary_unique
550552
self.batch_size = batch_size
551553
if mpi_size is None:
552554
self._mpi_size = self.hvd.size()
@@ -605,7 +607,14 @@ def __alltoall_embedding_lookup__(self, ids):
605607
reloc_ids, remote_sizes, gather_indices = self.__relocate_dense_feature__(
606608
ids, batch_size=batch_size_runtime)
607609

608-
lookup_result = de.shadow_ops.embedding_lookup(self.shadow, reloc_ids)
610+
if self.with_secondary_unique:
611+
with tf.name_scope(self.name + "/EmbeddingWithUnique"):
612+
reloc_unique_ids, reloc_unique_idx = tf.unique(reloc_ids)
613+
reloc_unique_embeddings = de.shadow_ops.embedding_lookup(
614+
self.shadow, reloc_unique_ids)
615+
lookup_result = tf.gather(reloc_unique_embeddings, reloc_unique_idx)
616+
else:
617+
lookup_result = de.shadow_ops.embedding_lookup(self.shadow, reloc_ids)
609618
lookup_result, _ = self.hvd.alltoall(lookup_result, splits=remote_sizes)
610619

611620
recover_shape = tf.concat((input_shape, (self.embedding_size,)), axis=0)

0 commit comments

Comments
 (0)