Skip to content

Commit a6b2e8f

Browse files
SSYernarfacebook-github-bot
authored andcommitted
Integrate DeepFM to the benchmarking framework (#3169)
Summary: Pull Request resolved: #3169 Integrate DeepFM model to the benchmarking framework by adding the model generation Reviewed By: aliafzal Differential Revision: D77898618 fbshipit-source-id: 30f19f756b074df32e2984bb802917b85c4f1381
1 parent 042c7c0 commit a6b2e8f

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

torchrec/distributed/benchmark/benchmark_pipeline_utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
TrainPipelineSemiSync,
4949
)
5050
from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
51+
from torchrec.models.deepfm import SimpleDeepFMNNWrapper
5152
from torchrec.models.dlrm import DLRMWrapper
5253
from torchrec.modules.embedding_configs import EmbeddingBagConfig
5354
from torchrec.modules.embedding_modules import EmbeddingBagCollection
@@ -188,8 +189,16 @@ def generate_model(
188189
weighted_tables: List[EmbeddingBagConfig],
189190
dense_device: torch.device,
190191
) -> nn.Module:
191-
# TODO: Implement DeepFM model generation
192-
raise NotImplementedError("DeepFM model generation not yet implemented")
192+
# DeepFM only uses unweighted tables
193+
ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))
194+
195+
# Create and return SimpleDeepFMNN model
196+
return SimpleDeepFMNNWrapper(
197+
num_dense_features=self.num_float_features,
198+
embedding_bag_collection=ebc,
199+
hidden_layer_size=self.hidden_layer_size,
200+
deep_fm_dimension=self.deep_fm_dimension,
201+
)
193202

194203

195204
@dataclass

0 commit comments

Comments
 (0)