Skip to content

Commit 042c7c0

Browse files
SSYernar authored and facebook-github-bot committed
Integrate DLRM to the benchmarking framework (#3168)
Summary: Pull Request resolved: #3168 Integrate DLRM to the benchmarking framework and set the default arch layer sizes. The last dense arch layer size should be the same as `embedding_feature_dim` Reviewed By: aliafzal Differential Revision: D77896635 fbshipit-source-id: c6557d5e2d2ef2322a48f7daa78d98b6a9607a95
1 parent 568d1f8 commit 042c7c0

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

torchrec/distributed/benchmark/benchmark_pipeline_utils.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@
4848
TrainPipelineSemiSync,
4949
)
5050
from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
51+
from torchrec.models.dlrm import DLRMWrapper
5152
from torchrec.modules.embedding_configs import EmbeddingBagConfig
53+
from torchrec.modules.embedding_modules import EmbeddingBagCollection
5254

5355

5456
@dataclass
@@ -203,8 +205,16 @@ def generate_model(
203205
weighted_tables: List[EmbeddingBagConfig],
204206
dense_device: torch.device,
205207
) -> nn.Module:
206-
# TODO: Implement DLRM model generation
207-
raise NotImplementedError("DLRM model generation not yet implemented")
208+
# DLRM only uses unweighted tables
209+
ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))
210+
211+
return DLRMWrapper(
212+
embedding_bag_collection=ebc,
213+
dense_in_features=self.num_float_features,
214+
dense_arch_layer_sizes=self.dense_arch_layer_sizes,
215+
over_arch_layer_sizes=self.over_arch_layer_sizes,
216+
dense_device=dense_device,
217+
)
208218

209219

210220
# pyre-ignore[2]: Missing parameter annotation

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,8 @@ class ModelSelectionConfig:
190190
deep_fm_dimension: int = 5
191191

192192
# DLRM specific config
193-
dense_arch_layer_sizes: List[int] = field(default_factory=lambda: [20, 10])
194-
over_arch_layer_sizes: List[int] = field(default_factory=lambda: [5, 3])
193+
dense_arch_layer_sizes: List[int] = field(default_factory=lambda: [20, 128])
194+
over_arch_layer_sizes: List[int] = field(default_factory=lambda: [5, 1])
195195

196196

197197
@cmd_conf

0 commit comments

Comments (0)