|
20 | 20 | See benchmark_pipeline_utils.py for step-by-step instructions.
|
21 | 21 | """
|
22 | 22 |
|
| 23 | +import importlib |
23 | 24 | from dataclasses import dataclass, field
|
24 |
| -from typing import Dict, List, Optional, Type, Union |
| 25 | +from typing import Any, Dict, List, Optional, Type |
25 | 26 |
|
26 | 27 | import torch
|
27 | 28 | from fbgemm_gpu.split_embedding_configs import EmbOptimType
|
28 | 29 | from torch import nn
|
29 | 30 | from torchrec.distributed.benchmark.benchmark_pipeline_utils import (
|
30 | 31 | BaseModelConfig,
|
31 | 32 | create_model_config,
|
32 |
| - DeepFMConfig, |
33 |
| - DLRMConfig, |
34 | 33 | generate_data,
|
35 | 34 | generate_pipeline,
|
36 |
| - TestSparseNNConfig, |
37 |
| - TestTowerCollectionSparseNNConfig, |
38 |
| - TestTowerSparseNNConfig, |
39 | 35 | )
|
40 | 36 | from torchrec.distributed.benchmark.benchmark_utils import (
|
41 | 37 | benchmark_func,
|
| 38 | + benchmark_module, |
42 | 39 | BenchmarkResult,
|
43 | 40 | cmd_conf,
|
44 | 41 | CPUMemoryStats,
|
|
62 | 59 | from torchrec.modules.embedding_configs import EmbeddingBagConfig
|
63 | 60 |
|
64 | 61 |
|
@dataclass
class UnifiedBenchmarkConfig:
    """
    Unified configuration for both pipeline and module benchmarking.

    Attributes:
        benchmark_type: which benchmark to run, either "pipeline" (the
            default) or "module".
        module_path: dotted import path of the module that defines the class
            to benchmark, e.g. "torchrec.models.deepfm".
        module_class: name of the class inside ``module_path``, e.g.
            "SimpleDeepFMNNWrapper".
        module_kwargs: extra keyword arguments for the module constructor;
            every instance gets its own empty dict by default.
    """

    # Selects the benchmark flavour: "pipeline" or "module".
    benchmark_type: str = "pipeline"

    # The remaining options are specific to module benchmarking.
    module_path: str = ""
    module_class: str = ""
    module_kwargs: Dict[str, Any] = field(default_factory=dict)
| 72 | + |
| 73 | + |
65 | 74 | @dataclass
|
66 | 75 | class RunOptions:
|
67 | 76 | """
|
@@ -201,57 +210,160 @@ class ModelSelectionConfig:
|
201 | 210 | over_arch_layer_sizes: List[int] = field(default_factory=lambda: [5, 1])
|
202 | 211 |
|
203 | 212 |
|
204 |
| -@cmd_conf |
205 |
| -def main( |
206 |
| - run_option: RunOptions, |
def dynamic_import_module(module_path: str, module_class: str) -> Type[nn.Module]:
    """
    Dynamically import a module class from a given path.

    Args:
        module_path: dotted import path of the module,
            e.g. "torchrec.models.deepfm".
        module_class: name of the attribute to fetch from the imported module,
            e.g. "SimpleDeepFMNNWrapper".

    Returns:
        The attribute named ``module_class`` of the imported module. It is
        returned as-is; no check is made that it is an ``nn.Module`` subclass.

    Raises:
        RuntimeError: if the module cannot be imported, ``module_path`` is
            empty, or the module has no attribute ``module_class``. The
            underlying exception is attached as ``__cause__``.
    """
    try:
        module = importlib.import_module(module_path)
        return getattr(module, module_class)
    except (ImportError, AttributeError, ValueError) as e:
        # ValueError is what importlib raises for an empty module name, which
        # is the default value of the config field feeding this function.
        raise RuntimeError(
            f"Failed to import {module_class} from {module_path}: {e}"
        ) from e
| 220 | + |
| 221 | + |
def create_module_instance(
    unified_config: UnifiedBenchmarkConfig,
    tables: List[EmbeddingBagConfig],
    weighted_tables: List[EmbeddingBagConfig],
    table_config: EmbeddingTablesConfig,
) -> nn.Module:
    """
    Create a module instance based on the unified config.

    The class named by ``unified_config.module_path`` /
    ``unified_config.module_class`` is imported and instantiated. A few class
    names get built-in default constructor arguments; any other class goes
    through a best-effort generic instantiation.

    In every case ``unified_config.module_kwargs`` is merged *over* the
    defaults, so a key present in ``module_kwargs`` replaces the default
    instead of raising "got multiple values for keyword argument".

    Args:
        unified_config: names the class to build and carries extra kwargs.
        tables: unweighted embedding table configs.
        weighted_tables: weighted embedding table configs; only offered to
            classes that take the generic instantiation path.
        table_config: currently unused; kept so the signature stays stable.

    Returns:
        The constructed module.

    Raises:
        RuntimeError: propagated from ``dynamic_import_module`` when the class
            cannot be imported.
        TypeError: if the generic path exhausts all fallbacks.
    """
    ModuleClass = dynamic_import_module(
        unified_config.module_path, unified_config.module_class
    )
    class_name = unified_config.module_class
    overrides = unified_config.module_kwargs

    # Handle common module instantiation patterns
    if class_name in ("SimpleDeepFMNNWrapper", "DLRMWrapper"):
        from torchrec.modules.embedding_modules import EmbeddingBagCollection

        ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))
        if class_name == "SimpleDeepFMNNWrapper":
            defaults = {
                "embedding_bag_collection": ebc,
                "num_dense_features": 10,
            }
        else:
            defaults = {
                "embedding_bag_collection": ebc,
                "dense_in_features": 10,
                "dense_arch_layer_sizes": [20, 128],
                "over_arch_layer_sizes": [5, 1],
            }
        return ModuleClass(**{**defaults, **overrides})

    if class_name == "EmbeddingBagCollection":
        return ModuleClass(**{"tables": tables, **overrides})

    # Generic instantiation: offer progressively fewer table arguments.
    # NOTE: a TypeError raised *inside* a constructor is indistinguishable
    # from a signature mismatch here and also triggers the next fallback.
    try:
        return ModuleClass(
            **{"tables": tables, "weighted_tables": weighted_tables, **overrides}
        )
    except TypeError:
        # Fallback to just tables
        try:
            return ModuleClass(**{"tables": tables, **overrides})
        except TypeError:
            # Fallback to no embedding tables
            return ModuleClass(**overrides)
| 271 | + |
| 272 | + |
def run_module_benchmark(
    unified_config: UnifiedBenchmarkConfig,
    table_config: EmbeddingTablesConfig,
    run_option: RunOptions,
    *,
    num_float_features: int = 10,
    num_benchmarks: int = 5,
    batch_size: int = 2048,
    device_type: str = "cuda",
) -> BenchmarkResult:
    """
    Run module-level benchmarking.

    Generates the embedding tables described by ``table_config``, builds the
    module named in ``unified_config`` and hands both to ``benchmark_module``.

    Args:
        unified_config: names the module class to benchmark and its kwargs.
        table_config: supplies the feature counts and embedding dimension
            used to generate the tables.
        run_option: supplies ``sharding_type``, ``planner_type``,
            ``world_size`` and ``compute_kernel`` for the benchmark.
        num_float_features: forwarded to ``benchmark_module``.
            NOTE(review): the built-in DeepFM/DLRM defaults in
            ``create_module_instance`` also use 10 dense features; presumably
            the two must agree, so confirm before overriding.
        num_benchmarks: forwarded to ``benchmark_module``.
        batch_size: forwarded to ``benchmark_module``.
        device_type: forwarded to ``benchmark_module``.

    Returns:
        Whatever ``benchmark_module`` returns for this run.
    """
    tables, weighted_tables = generate_tables(
        num_unweighted_features=table_config.num_unweighted_features,
        num_weighted_features=table_config.num_weighted_features,
        embedding_feature_dim=table_config.embedding_feature_dim,
    )

    module = create_module_instance(
        unified_config, tables, weighted_tables, table_config
    )

    return benchmark_module(
        module=module,
        tables=tables,
        weighted_tables=weighted_tables,
        num_float_features=num_float_features,
        sharding_type=run_option.sharding_type,
        planner_type=run_option.planner_type,
        world_size=run_option.world_size,
        num_benchmarks=num_benchmarks,
        batch_size=batch_size,
        compute_kernel=run_option.compute_kernel,
        device_type=device_type,
    )
|
253 | 302 |
|
254 | 303 |
|
def _model_config_from_selection(
    model_selection: ModelSelectionConfig,
) -> BaseModelConfig:
    """Build a model config by forwarding the model-selection options."""
    return create_model_config(
        model_name=model_selection.model_name,
        batch_size=model_selection.batch_size,
        batch_sizes=model_selection.batch_sizes,
        num_float_features=model_selection.num_float_features,
        feature_pooling_avg=model_selection.feature_pooling_avg,
        use_offsets=model_selection.use_offsets,
        dev_str=model_selection.dev_str,
        long_kjt_indices=model_selection.long_kjt_indices,
        long_kjt_offsets=model_selection.long_kjt_offsets,
        long_kjt_lengths=model_selection.long_kjt_lengths,
        pin_memory=model_selection.pin_memory,
        embedding_groups=model_selection.embedding_groups,
        feature_processor_modules=model_selection.feature_processor_modules,
        max_feature_lengths=model_selection.max_feature_lengths,
        over_arch_clazz=model_selection.over_arch_clazz,
        postproc_module=model_selection.postproc_module,
        zch=model_selection.zch,
        hidden_layer_size=model_selection.hidden_layer_size,
        deep_fm_dimension=model_selection.deep_fm_dimension,
        dense_arch_layer_sizes=model_selection.dense_arch_layer_sizes,
        over_arch_layer_sizes=model_selection.over_arch_layer_sizes,
    )


@cmd_conf
def main(
    run_option: RunOptions,
    table_config: EmbeddingTablesConfig,
    model_selection: ModelSelectionConfig,
    pipeline_config: PipelineConfig,
    unified_config: UnifiedBenchmarkConfig,
    model_config: Optional[BaseModelConfig] = None,
) -> None:
    # Entry point: dispatch on unified_config.benchmark_type.
    #   "module"   -> run the module-level benchmark and print its result
    #   "pipeline" -> generate tables, build a model config if none was
    #                 given, and launch the multi-process trainers
    #   otherwise  -> ValueError
    if unified_config.benchmark_type == "module":
        print("Running module-level benchmark...")
        result = run_module_benchmark(unified_config, table_config, run_option)
        print(f"Module benchmark completed: {result}")
        return

    if unified_config.benchmark_type != "pipeline":
        raise ValueError(
            f"Unknown benchmark_type: {unified_config.benchmark_type}. Must be 'module' or 'pipeline'"
        )

    print("Running pipeline-level benchmark...")
    tables, weighted_tables = generate_tables(
        num_unweighted_features=table_config.num_unweighted_features,
        num_weighted_features=table_config.num_weighted_features,
        embedding_feature_dim=table_config.embedding_feature_dim,
    )

    # Only derive a config from the selection options when the caller did
    # not supply one explicitly.
    if model_config is None:
        model_config = _model_config_from_selection(model_selection)

    # launch trainers
    run_multi_process_func(
        func=runner,
        world_size=run_option.world_size,
        tables=tables,
        weighted_tables=weighted_tables,
        run_option=run_option,
        model_config=model_config,
        pipeline_config=pipeline_config,
    )
| 365 | + |
| 366 | + |
255 | 367 | def run_pipeline(
|
256 | 368 | run_option: RunOptions,
|
257 | 369 | table_config: EmbeddingTablesConfig,
|
|
0 commit comments