@@ -535,6 +535,7 @@ class HvdAllToAllEmbedding(BasicEmbedding):
535
535
536
536
def __init__ (self ,
537
537
with_unique = True ,
538
+ with_secondary_unique = True ,
538
539
mpi_size = None ,
539
540
batch_size = None ,
540
541
* args ,
@@ -547,6 +548,7 @@ def __init__(self,
547
548
)
548
549
self .hvd = hvd
549
550
self .with_unique = with_unique
551
+ self .with_secondary_unique = with_secondary_unique
550
552
self .batch_size = batch_size
551
553
if mpi_size is None :
552
554
self ._mpi_size = self .hvd .size ()
@@ -605,7 +607,14 @@ def __alltoall_embedding_lookup__(self, ids):
605
607
reloc_ids , remote_sizes , gather_indices = self .__relocate_dense_feature__ (
606
608
ids , batch_size = batch_size_runtime )
607
609
608
- lookup_result = de .shadow_ops .embedding_lookup (self .shadow , reloc_ids )
610
+ if self .with_secondary_unique :
611
+ with tf .name_scope (self .name + "/EmbeddingWithUnique" ):
612
+ reloc_unique_ids , reloc_unique_idx = tf .unique (reloc_ids )
613
+ reloc_unique_embeddings = de .shadow_ops .embedding_lookup (
614
+ self .shadow , reloc_unique_ids )
615
+ lookup_result = tf .gather (reloc_unique_embeddings , reloc_unique_idx )
616
+ else :
617
+ lookup_result = de .shadow_ops .embedding_lookup (self .shadow , reloc_ids )
609
618
lookup_result , _ = self .hvd .alltoall (lookup_result , splits = remote_sizes )
610
619
611
620
recover_shape = tf .concat ((input_shape , (self .embedding_size ,)), axis = 0 )
0 commit comments