Skip to content

Commit 577d4bd

Browse files
joshuuuasufacebook-github-bot
authored andcommitted
Fix return results of forward method of DecoupleEmbeddingColletion (#2861)
Summary: Pull Request resolved: #2861 The decouple_di_pass will replace ManagedCollisionEmbeddingCollection with DecoupleEmbeddingColletion. The forward method of the former returns a tuple of the embedding results and the features inself (Tuple[Dict[str, KJT], KJT]), but the latter only returns the embedding results (Dict[str, KJT]). This caused incompatibility issues in following TGIF transform passes, namely QuantizationPass. Reviewed By: kausv Differential Revision: D71412636 fbshipit-source-id: e1489422ace0e3d3cb88e0aa040b5c1c09ed65ad
1 parent 67ffc20 commit 577d4bd

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

torchrec/distributed/quant_embedding.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1189,7 +1189,9 @@ def __init__(
11891189
def forward(
11901190
self,
11911191
features: KeyedJaggedTensor,
1192-
) -> torch.Tensor:
1192+
) -> Tuple[
1193+
Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
1194+
]:
11931195
remapped_kjt = self._mcc_remapper(features)
11941196
return self._ec_lookup(remapped_kjt)
11951197

torchrec/quant/embedding_modules.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,15 +1091,16 @@ def to(
10911091
)
10921092
return self
10931093

1094+
# pyre-ignore
10941095
def forward(
10951096
self,
10961097
features: KeyedJaggedTensor,
1097-
) -> Dict[str, JaggedTensor]:
1098+
) -> Tuple[
1099+
Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
1100+
]:
10981101
features = self._managed_collision_collection(features)
10991102

1100-
# mcec expects Tuple return type
1101-
# pyre-ignore
1102-
return (super().forward(features),)
1103+
return (super().forward(features), features)
11031104

11041105
def _get_name(self) -> str:
11051106
return "QuantManagedCollisionEmbeddingCollection"

0 commit comments

Comments
 (0)