
Commit 1bc4786

SSYernar authored and facebook-github-bot committed

Unified benchmark runner with YAML config

Summary:
- Implemented a unified benchmark runner with YAML configuration.
- Created the main benchmark executable unified_benchmark_runner.py.
- Added example YAML config files for both module-level and pipeline-level benchmarks.

Differential Revision: D79701597
1 parent 8e3054e commit 1bc4786

5 files changed: +356 −47 lines changed


torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 156 additions & 44 deletions
@@ -20,25 +20,22 @@
 See benchmark_pipeline_utils.py for step-by-step instructions.
 """
 
+import importlib
 from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Type
 
 import torch
 from fbgemm_gpu.split_embedding_configs import EmbOptimType
 from torch import nn
 from torchrec.distributed.benchmark.benchmark_pipeline_utils import (
     BaseModelConfig,
     create_model_config,
-    DeepFMConfig,
-    DLRMConfig,
     generate_data,
     generate_pipeline,
-    TestSparseNNConfig,
-    TestTowerCollectionSparseNNConfig,
-    TestTowerSparseNNConfig,
 )
 from torchrec.distributed.benchmark.benchmark_utils import (
     benchmark_func,
+    benchmark_module,
     BenchmarkResult,
     cmd_conf,
     CPUMemoryStats,
@@ -62,6 +59,18 @@
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 
 
+@dataclass
+class UnifiedBenchmarkConfig:
+    """Unified configuration for both pipeline and module benchmarking."""
+
+    benchmark_type: str = "pipeline"  # "pipeline" or "module"
+
+    # Module benchmarking specific options
+    module_path: str = ""  # e.g., "torchrec.models.deepfm"
+    module_class: str = ""  # e.g., "SimpleDeepFMNNWrapper"
+    module_kwargs: Dict[str, Any] = field(default_factory=dict)
+
+
 @dataclass
 class RunOptions:
     """
@@ -201,57 +210,160 @@ class ModelSelectionConfig:
     over_arch_layer_sizes: List[int] = field(default_factory=lambda: [5, 1])
 
 
-@cmd_conf
-def main(
-    run_option: RunOptions,
+def dynamic_import_module(module_path: str, module_class: str) -> Type[nn.Module]:
+    """Dynamically import a module class from a given path."""
+    try:
+        module = importlib.import_module(module_path)
+        return getattr(module, module_class)
+    except (ImportError, AttributeError) as e:
+        raise RuntimeError(f"Failed to import {module_class} from {module_path}: {e}")
+
+
+def create_module_instance(
+    unified_config: UnifiedBenchmarkConfig,
+    tables: List[EmbeddingBagConfig],
+    weighted_tables: List[EmbeddingBagConfig],
     table_config: EmbeddingTablesConfig,
-    model_selection: ModelSelectionConfig,
-    pipeline_config: PipelineConfig,
-    model_config: Optional[BaseModelConfig] = None,
-) -> None:
+) -> nn.Module:
+    """Create a module instance based on the unified config."""
+    ModuleClass = dynamic_import_module(
+        unified_config.module_path, unified_config.module_class
+    )
+
+    # Handle common module instantiation patterns
+    if unified_config.module_class == "SimpleDeepFMNNWrapper":
+        from torchrec.modules.embedding_modules import EmbeddingBagCollection
+
+        ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))
+        return ModuleClass(
+            embedding_bag_collection=ebc,
+            num_dense_features=10,  # Default value, can be overridden via module_kwargs
+            **unified_config.module_kwargs,
+        )
+    elif unified_config.module_class == "DLRMWrapper":
+        from torchrec.modules.embedding_modules import EmbeddingBagCollection
+
+        ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))
+        return ModuleClass(
+            embedding_bag_collection=ebc,
+            dense_in_features=10,  # Default value, can be overridden via module_kwargs
+            dense_arch_layer_sizes=[20, 128],  # Default value
+            over_arch_layer_sizes=[5, 1],  # Default value
+            **unified_config.module_kwargs,
+        )
+    elif unified_config.module_class == "EmbeddingBagCollection":
+        return ModuleClass(tables=tables, **unified_config.module_kwargs)
+    else:
+        # Generic instantiation - try with tables and weighted_tables
+        try:
+            return ModuleClass(
+                tables=tables,
+                weighted_tables=weighted_tables,
+                **unified_config.module_kwargs,
+            )
+        except TypeError:
+            # Fallback to just tables
+            try:
+                return ModuleClass(tables=tables, **unified_config.module_kwargs)
+            except TypeError:
+                # Fallback to no embedding tables
+                return ModuleClass(**unified_config.module_kwargs)
+
+
+def run_module_benchmark(
+    unified_config: UnifiedBenchmarkConfig,
+    table_config: EmbeddingTablesConfig,
+    run_option: RunOptions,
+) -> BenchmarkResult:
+    """Run module-level benchmarking."""
     tables, weighted_tables = generate_tables(
         num_unweighted_features=table_config.num_unweighted_features,
         num_weighted_features=table_config.num_weighted_features,
        embedding_feature_dim=table_config.embedding_feature_dim,
     )
 
-    if model_config is None:
-        model_config = create_model_config(
-            model_name=model_selection.model_name,
-            batch_size=model_selection.batch_size,
-            batch_sizes=model_selection.batch_sizes,
-            num_float_features=model_selection.num_float_features,
-            feature_pooling_avg=model_selection.feature_pooling_avg,
-            use_offsets=model_selection.use_offsets,
-            dev_str=model_selection.dev_str,
-            long_kjt_indices=model_selection.long_kjt_indices,
-            long_kjt_offsets=model_selection.long_kjt_offsets,
-            long_kjt_lengths=model_selection.long_kjt_lengths,
-            pin_memory=model_selection.pin_memory,
-            embedding_groups=model_selection.embedding_groups,
-            feature_processor_modules=model_selection.feature_processor_modules,
-            max_feature_lengths=model_selection.max_feature_lengths,
-            over_arch_clazz=model_selection.over_arch_clazz,
-            postproc_module=model_selection.postproc_module,
-            zch=model_selection.zch,
-            hidden_layer_size=model_selection.hidden_layer_size,
-            deep_fm_dimension=model_selection.deep_fm_dimension,
-            dense_arch_layer_sizes=model_selection.dense_arch_layer_sizes,
-            over_arch_layer_sizes=model_selection.over_arch_layer_sizes,
-        )
+    module = create_module_instance(
+        unified_config, tables, weighted_tables, table_config
+    )
 
-    # launch trainers
-    run_multi_process_func(
-        func=runner,
-        world_size=run_option.world_size,
+    return benchmark_module(
+        module=module,
         tables=tables,
         weighted_tables=weighted_tables,
-        run_option=run_option,
-        model_config=model_config,
-        pipeline_config=pipeline_config,
+        num_float_features=10,  # Default value
+        sharding_type=run_option.sharding_type,
+        planner_type=run_option.planner_type,
+        world_size=run_option.world_size,
+        num_benchmarks=5,  # Default value
+        batch_size=2048,  # Default value
+        compute_kernel=run_option.compute_kernel,
+        device_type="cuda",
     )
 
 
+@cmd_conf
+def main(
+    run_option: RunOptions,
+    table_config: EmbeddingTablesConfig,
+    model_selection: ModelSelectionConfig,
+    pipeline_config: PipelineConfig,
+    unified_config: UnifiedBenchmarkConfig,
+    model_config: Optional[BaseModelConfig] = None,
+) -> None:
+    # Route to appropriate benchmark type based on unified config
+    if unified_config.benchmark_type == "module":
+        print("Running module-level benchmark...")
+        result = run_module_benchmark(unified_config, table_config, run_option)
+        print(f"Module benchmark completed: {result}")
+    elif unified_config.benchmark_type == "pipeline":
+        print("Running pipeline-level benchmark...")
+        tables, weighted_tables = generate_tables(
+            num_unweighted_features=table_config.num_unweighted_features,
+            num_weighted_features=table_config.num_weighted_features,
+            embedding_feature_dim=table_config.embedding_feature_dim,
+        )
+
+        if model_config is None:
+            model_config = create_model_config(
+                model_name=model_selection.model_name,
+                batch_size=model_selection.batch_size,
+                batch_sizes=model_selection.batch_sizes,
+                num_float_features=model_selection.num_float_features,
+                feature_pooling_avg=model_selection.feature_pooling_avg,
+                use_offsets=model_selection.use_offsets,
+                dev_str=model_selection.dev_str,
+                long_kjt_indices=model_selection.long_kjt_indices,
+                long_kjt_offsets=model_selection.long_kjt_offsets,
+                long_kjt_lengths=model_selection.long_kjt_lengths,
+                pin_memory=model_selection.pin_memory,
+                embedding_groups=model_selection.embedding_groups,
+                feature_processor_modules=model_selection.feature_processor_modules,
+                max_feature_lengths=model_selection.max_feature_lengths,
+                over_arch_clazz=model_selection.over_arch_clazz,
+                postproc_module=model_selection.postproc_module,
+                zch=model_selection.zch,
+                hidden_layer_size=model_selection.hidden_layer_size,
+                deep_fm_dimension=model_selection.deep_fm_dimension,
+                dense_arch_layer_sizes=model_selection.dense_arch_layer_sizes,
+                over_arch_layer_sizes=model_selection.over_arch_layer_sizes,
+            )
+
+        # launch trainers
+        run_multi_process_func(
+            func=runner,
+            world_size=run_option.world_size,
+            tables=tables,
+            weighted_tables=weighted_tables,
+            run_option=run_option,
+            model_config=model_config,
+            pipeline_config=pipeline_config,
+        )
+    else:
+        raise ValueError(
+            f"Unknown benchmark_type: {unified_config.benchmark_type}. Must be 'module' or 'pipeline'"
+        )
+
+
 def run_pipeline(
     run_option: RunOptions,
     table_config: EmbeddingTablesConfig,
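
The `dynamic_import_module` helper above is a thin wrapper around Python's standard `importlib`, which resolves a dotted module path at runtime so the runner never needs hard-coded imports for the module under benchmark. A minimal standalone sketch of the same pattern (the helper name `load_class` and the `torch.nn.Linear` target are illustrative, not part of this diff):

```python
import importlib
from typing import Type

import torch.nn as nn


def load_class(module_path: str, class_name: str) -> Type[nn.Module]:
    # importlib.import_module resolves the dotted path at runtime,
    # so the caller can name any importable class in a config file.
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


# Stand-in target; a YAML config would name e.g. "torchrec.models.deepfm".
LinearCls = load_class("torch.nn", "Linear")
layer = LinearCls(in_features=8, out_features=4)
```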

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 24 additions & 3 deletions
@@ -22,7 +22,7 @@
 import resource
 import time
 import timeit
-from dataclasses import dataclass, fields, is_dataclass, MISSING
+from dataclasses import dataclass, field, fields, is_dataclass, MISSING
 from enum import Enum
 from typing import (
     Any,
@@ -140,6 +140,25 @@ def __str__(self) -> str:
         return f"Rank {self.rank}: CPU Memory Peak RSS: {self.peak_rss_mbs/1000:.2f} GB"
 
 
+@dataclass
+class ModuleBenchmarkConfig:
+    """Configuration for module-level benchmarking."""
+
+    module_path: str = ""  # e.g., "torchrec.models.deepfm"
+    module_class: str = ""  # e.g., "SimpleDeepFMNNWrapper"
+    module_kwargs: Dict[str, Any] = field(
+        default_factory=dict
+    )  # Additional kwargs for module instantiation
+    num_float_features: int = 0
+    sharding_type: ShardingType = ShardingType.TABLE_WISE
+    planner_type: str = "embedding"
+    world_size: int = 2
+    num_benchmarks: int = 5
+    batch_size: int = 2048
+    compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED
+    device_type: str = "cuda"
+
+
 @dataclass
 class BenchmarkResult:
     "Class for holding results of benchmark runs"
@@ -728,8 +747,9 @@ def _init_module_and_run_benchmark(
     def _func_to_benchmark(
         model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor]
     ) -> None:
-        for bench_input in bench_inputs:
-            model(bench_input)
+        with torch.inference_mode():
+            for bench_input in bench_inputs:
+                model(bench_input)
 
     name = f"{sharding_type.value}-{planner_type}"
 
@@ -828,6 +848,7 @@ def benchmark_module(
     if weighted_tables is None:
         weighted_tables = []
 
+    # Use multiprocessing for distributed benchmarking (always assume train mode)
     res = multi_process_benchmark(
         callable=_init_module_and_run_benchmark,
         module=module,
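
The `torch.inference_mode()` change above makes the timed loop skip autograd bookkeeping entirely, so the measurement reflects forward-pass compute rather than graph construction. A minimal sketch of the same pattern outside the benchmark harness (the model and shapes here are placeholders):

```python
import torch
import torch.nn as nn

model = nn.Linear(128, 1)
inputs = [torch.randn(2048, 128) for _ in range(5)]

# Under inference_mode, no autograd graph is built for these tensors,
# which is both faster and lower-memory than a plain forward pass.
with torch.inference_mode():
    for batch in inputs:
        out = model(batch)
        assert not out.requires_grad  # no grad graph was recorded
```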
module_benchmark_config.yaml

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+# Example YAML configuration for module-level benchmarking
+# Usage: python -m torchrec.distributed.benchmark.unified_benchmark_runner --yaml_config=module_benchmark_config.yaml
+
+UnifiedBenchmarkConfig:
+  benchmark_type: "module"
+  module_path: "torchrec.models.deepfm"
+  module_class: "SimpleDeepFMNNWrapper"
+  module_kwargs:
+    hidden_layer_size: 20
+    deep_fm_dimension: 5
+
+RunOptions:
+  world_size: 2
+  num_batches: 10
+  sharding_type: "table_wise"
+  compute_kernel: "fused"
+  input_type: "kjt"
+  planner_type: "embedding"
+  dense_optimizer: "SGD"
+  dense_lr: 0.1
+  sparse_optimizer: "EXACT_ADAGRAD"
+  sparse_lr: 0.1
+
+EmbeddingTablesConfig:
+  num_unweighted_features: 100
+  num_weighted_features: 100
+  embedding_feature_dim: 128
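
Each top-level key in this YAML names one of the config dataclasses that `main` accepts. The `cmd_conf` parsing itself is not shown in this diff; a rough sketch of how such a loader could map the keys onto dataclasses, assuming PyYAML (the loading code below is an assumption, not the runner's actual implementation):

```python
import yaml

from torchrec.distributed.benchmark.benchmark_train_pipeline import (
    UnifiedBenchmarkConfig,
)

# Sketch only: read the YAML and feed each top-level section to the
# dataclass of the same name. The real cmd_conf wiring may differ.
with open("module_benchmark_config.yaml") as f:
    raw = yaml.safe_load(f)

unified = UnifiedBenchmarkConfig(**raw["UnifiedBenchmarkConfig"])
print(unified.benchmark_type, unified.module_class)
```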
pipeline_benchmark_config.yaml

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+# Example YAML configuration for pipeline-level benchmarking
+# Usage: python -m torchrec.distributed.benchmark.unified_benchmark_runner --yaml_config=pipeline_benchmark_config.yaml
+
+UnifiedBenchmarkConfig:
+  benchmark_type: "pipeline"
+
+PipelineConfig:
+  pipeline: "sparse"
+  emb_lookup_stream: "data_dist"
+  apply_jit: false
+
+RunOptions:
+  world_size: 2
+  num_batches: 10
+  sharding_type: "table_wise"
+  compute_kernel: "fused"
+  input_type: "kjt"
+  planner_type: "embedding"
+  dense_optimizer: "SGD"
+  dense_lr: 0.1
+  sparse_optimizer: "EXACT_ADAGRAD"
+  sparse_lr: 0.1
+
+ModelSelectionConfig:
+  model_name: "test_sparse_nn"
+  batch_size: 8192
+  num_float_features: 10
+  feature_pooling_avg: 10
+  use_offsets: false
+  long_kjt_indices: true
+  long_kjt_offsets: true
+  long_kjt_lengths: true
+  pin_memory: true
+
+EmbeddingTablesConfig:
+  num_unweighted_features: 100
+  num_weighted_features: 100
+  embedding_feature_dim: 128
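
The `EmbeddingTablesConfig` section feeds `generate_tables`, which builds one embedding table per feature. As a rough illustration of the kind of `EmbeddingBagConfig` list it produces (the row count and naming scheme below are assumptions for illustration; the actual helper may size and name tables differently):

```python
from torchrec.modules.embedding_configs import EmbeddingBagConfig

EMBEDDING_FEATURE_DIM = 128    # embedding_feature_dim above
NUM_UNWEIGHTED_FEATURES = 100  # num_unweighted_features above

# Illustrative sketch: one table per unweighted feature.
tables = [
    EmbeddingBagConfig(
        num_embeddings=10_000,  # assumed row count, not taken from the config
        embedding_dim=EMBEDDING_FEATURE_DIM,
        name=f"table_{i}",
        feature_names=[f"feature_{i}"],
    )
    for i in range(NUM_UNWEIGHTED_FEATURES)
]
```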
