diff --git a/examples/zch/main.py b/examples/zch/main.py index 988bdade3..3d5768581 100644 --- a/examples/zch/main.py +++ b/examples/zch/main.py @@ -13,7 +13,6 @@ import torch from torchrec import EmbeddingConfig, KeyedJaggedTensor -from torchrec.distributed.benchmark.benchmark_utils import get_inputs from tqdm import tqdm from .sparse_arch import SparseArch diff --git a/torchrec/distributed/benchmark/benchmark_ebc.py b/torchrec/distributed/benchmark/benchmark_ebc.py new file mode 100644 index 000000000..225c24027 --- /dev/null +++ b/torchrec/distributed/benchmark/benchmark_ebc.py @@ -0,0 +1,646 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +# pyre-ignore-all-errors[16] + +""" +Benchmark utilities specifically for EmbeddingBagCollection (EBC) and QuantEmbeddingBagCollection (QEBC) modules. + +This module contains functions for benchmarking EBC and QEBC modules with different sharding strategies and compilation modes. +""" + +import contextlib +import copy +import gc +import logging +import time +from typing import ( + Any, + Callable, + ContextManager, + Dict, + List, + Optional, + Tuple, + TypeVar, + Union, +) + +import torch +from torch import multiprocessing as mp +from torchrec.distributed import DistributedModelParallel +from torchrec.distributed.embedding_types import ShardingType +from torchrec.distributed.global_settings import set_propogate_device +from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology +from torchrec.distributed.planner.enumerators import EmbeddingEnumerator +from torchrec.distributed.planner.shard_estimators import ( + EmbeddingPerfEstimator, + EmbeddingStorageEstimator, +) +from torchrec.distributed.shard import _shard_modules +from torchrec.distributed.test_utils.multi_process import MultiProcessContext +from torchrec.distributed.test_utils.test_model import ModelInput +from torchrec.distributed.types import DataType, ModuleSharder, ShardingEnv +from torchrec.fx import symbolic_trace +from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig +from torchrec.quant.embedding_modules import ( + EmbeddingBagCollection as QuantEmbeddingBagCollection, + EmbeddingCollection as QuantEmbeddingCollection, +) +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor + +# Import the shared types and utilities from benchmark_utils +from .benchmark_utils import ( + benchmark, + BenchmarkResult, + CompileMode, + multi_process_benchmark, +) + +logger: logging.Logger = logging.getLogger() + +T = TypeVar("T", bound=torch.nn.Module) + + +class ECWrapper(torch.nn.Module): + """ + Wrapper Module for benchmarking EC Modules + + Args: + module: module to benchmark + + Call Args: + input: KeyedJaggedTensor KJT input to module + + Returns: + output: KT output from module + + Example: + e1_config = EmbeddingConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] + ) + e2_config = EmbeddingConfig( + name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"] + ) + + ec = EmbeddingCollection(tables=[e1_config, e2_config]) + + features = KeyedJaggedTensor( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + ec.qconfig = torch.quantization.QConfig( + 
activation=torch.quantization.PlaceholderObserver.with_args( + dtype=torch.qint8 + ), + weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8), + ) + + qec = QuantEmbeddingCollection.from_float(ecc) + + wrapped_module = ECWrapper(qec) + quantized_embeddings = wrapped_module(features) + """ + + def __init__(self, module: torch.nn.Module) -> None: + super().__init__() + self._module = module + + def forward(self, input: KeyedJaggedTensor) -> Dict[str, JaggedTensor]: + """ + Args: + input (KeyedJaggedTensor): KJT of form [F X B X L]. + + Returns: + Dict[str, JaggedTensor] + """ + return self._module.forward(input) + + +class EBCWrapper(torch.nn.Module): + """ + Wrapper Module for benchmarking EBC Modules + + Args: + module: module to benchmark + + Call Args: + input: KeyedJaggedTensor KJT input to module + + Returns: + output: KT output from module + + Example: + table_0 = EmbeddingBagConfig( + name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] + ) + table_1 = EmbeddingBagConfig( + name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"] + ) + ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) + + features = KeyedJaggedTensor( + keys=["f1", "f2"], + values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), + ) + + ebc.qconfig = torch.quantization.QConfig( + activation=torch.quantization.PlaceholderObserver.with_args( + dtype=torch.qint8 + ), + weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8), + ) + + qebc = QuantEmbeddingBagCollection.from_float(ebc) + + wrapped_module = EBCWrapper(qebc) + quantized_embeddings = wrapped_module(features) + """ + + def __init__(self, module: torch.nn.Module) -> None: + super().__init__() + self._module = module + + def forward(self, input: KeyedJaggedTensor) -> KeyedTensor: + """ + Args: + input (KeyedJaggedTensor): KJT of form [F X B X L]. 
+ + Returns: + KeyedTensor + """ + return self._module.forward(input) + + +def default_func_to_benchmark( + model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor] +) -> None: + with torch.inference_mode(): + for bench_input in bench_inputs: + model(bench_input) + + +def get_tables( + table_sizes: List[Tuple[int, int]], + is_pooled: bool = True, + data_type: DataType = DataType.INT8, +) -> Union[List[EmbeddingBagConfig], List[EmbeddingConfig]]: + if is_pooled: + tables: List[EmbeddingBagConfig] = [ + EmbeddingBagConfig( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + data_type=data_type, + ) + for i, (num_embeddings, embedding_dim) in enumerate(table_sizes) + ] + else: + tables: List[EmbeddingConfig] = [ + EmbeddingConfig( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + data_type=data_type, + ) + for i, (num_embeddings, embedding_dim) in enumerate(table_sizes) + ] + + return tables + + +def get_inputs( + tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]], + batch_size: int, + world_size: int, + num_inputs: int, + train: bool, + pooling_configs: Optional[List[int]] = None, + variable_batch_embeddings: bool = False, +) -> List[List[KeyedJaggedTensor]]: + inputs_batch: List[List[KeyedJaggedTensor]] = [] + + if variable_batch_embeddings and not train: + raise RuntimeError("Variable batch size is only supported in training mode") + + for _ in range(num_inputs): + if variable_batch_embeddings: + _, model_input_by_rank = ModelInput.generate_variable_batch_input( + average_batch_size=batch_size, + world_size=world_size, + num_float_features=0, + tables=tables, + ) + else: + _, model_input_by_rank = ModelInput.generate( + batch_size=batch_size, + world_size=world_size, + num_float_features=0, + tables=tables, + weighted_tables=[], + tables_pooling=pooling_configs, + indices_dtype=torch.int32, + lengths_dtype=torch.int32, + ) + + if train: + sparse_features_by_rank = [ + model_input.idlist_features + for model_input in model_input_by_rank + if isinstance(model_input.idlist_features, KeyedJaggedTensor) + ] + inputs_batch.append(sparse_features_by_rank) + else: + sparse_features = model_input_by_rank[0].idlist_features + assert isinstance(sparse_features, KeyedJaggedTensor) + inputs_batch.append([sparse_features]) + + # Transpose if train, as inputs_by_rank is currently in [B X R] format + inputs_by_rank = list(zip(*inputs_batch)) + + return inputs_by_rank + + +def transform_module( + module: torch.nn.Module, + device: torch.device, + inputs: List[KeyedJaggedTensor], + sharder: ModuleSharder[T], + sharding_type: ShardingType, + compile_mode: CompileMode, + world_size: int, + batch_size: int, + # pyre-fixme[24]: Generic type `ContextManager` expects 1 type parameter. 
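+    # Note: ctx is expected to be a MultiProcessContext when benchmarking across
+    # ranks; otherwise a nullcontext is passed and a local ShardingEnv is used.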
+ ctx: ContextManager, + benchmark_unsharded_module: bool = False, +) -> torch.nn.Module: + def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module: + eager_module(inputs[0]) + graph_module = symbolic_trace( + eager_module, leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"] + ) + scripted_module = torch.jit.script(graph_module) + return scripted_module + + set_propogate_device(True) + + sharded_module = None + + if not benchmark_unsharded_module: + topology: Topology = Topology(world_size=world_size, compute_device=device.type) + planner = EmbeddingShardingPlanner( + topology=topology, + batch_size=batch_size, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology), + EmbeddingStorageEstimator(topology=topology), + ], + ), + ) + + # Don't want to modify the module outright + # Since module is on cpu, won't cause cuda oom. + copied_module = copy.deepcopy(module) + # pyre-ignore [6] + plan = planner.plan(copied_module, [sharder]) + + if isinstance(ctx, MultiProcessContext): + sharded_module = DistributedModelParallel( + copied_module, + # pyre-ignore[6] + env=ShardingEnv.from_process_group(ctx.pg), + plan=plan, + # pyre-ignore[6] + sharders=[sharder], + device=ctx.device, + ) + else: + env = ShardingEnv.from_local(world_size=topology.world_size, rank=0) + + sharded_module = _shard_modules( + module=copied_module, + # pyre-fixme[6]: For 2nd argument expected + # `Optional[List[ModuleSharder[Module]]]` but got + # `List[ModuleSharder[Variable[T (bound to Module)]]]`. + sharders=[sharder], + device=device, + plan=plan, + env=env, + ) + + if compile_mode == CompileMode.FX_SCRIPT: + return fx_script_module( + # pyre-fixme[6]: For 1st argument expected `Module` but got + # `Optional[Module]`. + sharded_module + if not benchmark_unsharded_module + else module + ) + else: + # pyre-fixme[7]: Expected `Module` but got `Optional[Module]`. + return sharded_module if not benchmark_unsharded_module else module + + +def benchmark_type_name(compile_mode: CompileMode, sharding_type: ShardingType) -> str: + if sharding_type == ShardingType.TABLE_WISE: + name = "tw-sharded" + elif sharding_type == ShardingType.ROW_WISE: + name = "rw-sharded" + elif sharding_type == ShardingType.COLUMN_WISE: + name = "cw-sharded" + else: + raise Exception(f"Unknown sharding type {sharding_type}") + + if compile_mode == CompileMode.EAGER: + name += "-eager" + elif compile_mode == CompileMode.FX_SCRIPT: + name += "-fxjit" + + return name + + +def init_module_and_run_benchmark( + module: torch.nn.Module, + sharder: ModuleSharder[T], + device: torch.device, + sharding_type: ShardingType, + compile_mode: CompileMode, + world_size: int, + batch_size: int, + warmup_inputs: List[List[KeyedJaggedTensor]], + bench_inputs: List[List[KeyedJaggedTensor]], + prof_inputs: List[List[KeyedJaggedTensor]], + tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]], + output_dir: str, + num_benchmarks: int, + # pyre-ignore[2] + func_to_benchmark: Any, + benchmark_func_kwargs: Optional[Dict[str, Any]], + rank: int = -1, + queue: Optional[mp.Queue] = None, + pooling_configs: Optional[List[int]] = None, + benchmark_unsharded_module: bool = False, +) -> BenchmarkResult: + """ + There are a couple of caveats here as to why the module has to be initialized + here: + 1. Device. To accurately track memory usage, when sharding modules the initial + placement of the module should be on CPU. 
This is to avoid double counting + memory allocations and also to prevent CUDA OOMs. + 2. Garbage Collector. Since torch.fx.GraphModule has circular references, + garbage collection us funky and can lead to ooms. Since this frame is + called by the loop through compile modes and sharding types, returning the + benchmark result will mean that the reference to module is lost instead of + existing in the loop + """ + + if rank >= 0: + warmup_inputs_cuda = [ + warmup_input.to(torch.device(f"{device.type}:{rank}")) + for warmup_input in warmup_inputs[rank] + ] + bench_inputs_cuda = [ + bench_input.to(torch.device(f"{device.type}:{rank}")) + for bench_input in bench_inputs[rank] + ] + prof_inputs_cuda = [ + prof_input.to(torch.device(f"{device.type}:{rank}")) + for prof_input in prof_inputs[rank] + ] + else: + warmup_inputs_cuda = [ + warmup_input.to(torch.device(f"{device.type}:0")) + for warmup_input in warmup_inputs[0] + ] + bench_inputs_cuda = [ + bench_input.to(torch.device(f"{device.type}:0")) + for bench_input in bench_inputs[0] + ] + prof_inputs_cuda = [ + prof_input.to(torch.device(f"{device.type}:0")) + for prof_input in prof_inputs[0] + ] + + with ( + MultiProcessContext(rank, world_size, "nccl", None) + if rank != -1 + else contextlib.nullcontext() + ) as ctx: + module = transform_module( + module=module, + device=device, + inputs=warmup_inputs_cuda, + sharder=sharder, + sharding_type=sharding_type, + compile_mode=compile_mode, + world_size=world_size, + batch_size=batch_size, + # pyre-ignore[6] + ctx=ctx, + benchmark_unsharded_module=benchmark_unsharded_module, + ) + + if benchmark_unsharded_module: + name = "unsharded" + compile_mode.name + else: + name = benchmark_type_name(compile_mode, sharding_type) + + res = benchmark( + name, + module, + warmup_inputs_cuda, + bench_inputs_cuda, + prof_inputs_cuda, + world_size=world_size, + output_dir=output_dir, + num_benchmarks=num_benchmarks, + func_to_benchmark=func_to_benchmark, + benchmark_func_kwargs=benchmark_func_kwargs, + rank=rank, + device_type=device.type, + benchmark_unsharded_module=benchmark_unsharded_module, + ) + + if queue is not None: + queue.put(res) + + while not queue.empty(): + time.sleep(1) + + return res + + +def benchmark_ebc_module( + module: torch.nn.Module, + sharder: ModuleSharder[T], + sharding_types: List[ShardingType], + compile_modes: List[CompileMode], + tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]], + warmup_iters: int = 20, + bench_iters: int = 500, + prof_iters: int = 20, + batch_size: int = 2048, + world_size: int = 2, + num_benchmarks: int = 5, + output_dir: str = "", + benchmark_unsharded: bool = False, + func_to_benchmark: Callable[..., None] = default_func_to_benchmark, + benchmark_func_kwargs: Optional[Dict[str, Any]] = None, + pooling_configs: Optional[List[int]] = None, + variable_batch_embeddings: bool = False, + device_type: str = "cuda", +) -> List[BenchmarkResult]: + """ + Benchmark EmbeddingBagCollection (EBC) and QuantEmbeddingBagCollection (QEBC) modules. 
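+
+    Example (illustrative sketch; assumes ``qebc`` is a QuantEmbeddingBagCollection
+    built from ``tables`` and ``sharder`` is a compatible inference sharder):
+
+        tables = get_tables([(100_000, 128)] * 4, is_pooled=True)
+        results = benchmark_ebc_module(
+            module=qebc,
+            sharder=sharder,
+            sharding_types=[ShardingType.TABLE_WISE],
+            compile_modes=[CompileMode.EAGER, CompileMode.FX_SCRIPT],
+            tables=tables,
+        )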
+ + Args: + module: EBC or QEBC module to be benchmarked + sharder: Module sharder for distributing the module + sharding_types: Sharding types to be benchmarked + compile_modes: Compilation modes to be benchmarked + tables: Embedding table configurations + warmup_iters: Number of iterations to run before profiling + bench_iters: Number of iterations to run during profiling + prof_iters: Number of iterations to run after profiling + batch_size: Batch size used in the model + world_size: World size used in distributed training + num_benchmarks: How many times to run over benchmark inputs for statistics + output_dir: Directory to output profiler outputs (traces, stacks) + benchmark_unsharded: Whether to benchmark unsharded version + func_to_benchmark: Custom function to benchmark, check out default_func_to_benchmark for default + benchmark_func_kwargs: Custom keyword arguments to pass to func_to_benchmark + pooling_configs: The pooling factor for the tables (Optional; if not set, we'll use 10 as default) + variable_batch_embeddings: Whether to use variable batch size embeddings + device_type: Device type to use for benchmarking + + Returns: + A list of BenchmarkResults + + Note: + This function is specifically designed for EmbeddingBagCollection (EBC) and + QuantEmbeddingBagCollection (QEBC) modules. It automatically detects the module + type and applies appropriate wrapping and training mode settings. + """ + + # logging.info(f"###### Benchmarking EBC/QEBC Module: {module} ######\n") + logging.info(f"Warmup iterations: {warmup_iters}") + logging.info(f"Benchmark iterations: {bench_iters}") + logging.info(f"Profile iterations: {prof_iters}") + logging.info(f"Batch Size: {batch_size}") + logging.info(f"World Size: {world_size}") + logging.info(f"Number of Benchmarks: {num_benchmarks}") + logging.info(f"Output Directory: {output_dir}") + + assert ( + num_benchmarks > 2 + ), "num_benchmarks needs to be greater than 2 for statistical analysis" + + # Determine training mode based on module type + if isinstance(module, QuantEmbeddingBagCollection) or isinstance( + module, QuantEmbeddingCollection + ): + train = False + else: + train = True + + benchmark_results: List[BenchmarkResult] = [] + + # Wrap the module appropriately based on table type + if isinstance(tables[0], EmbeddingBagConfig): + wrapped_module = EBCWrapper(module) + else: + wrapped_module = ECWrapper(module) + + num_inputs_to_gen: int = warmup_iters + bench_iters + prof_iters + inputs = get_inputs( + tables, + batch_size, + world_size, + num_inputs_to_gen, + train, + pooling_configs, + variable_batch_embeddings, + ) + + warmup_inputs = [rank_inputs[:warmup_iters] for rank_inputs in inputs] + bench_inputs = [ + rank_inputs[warmup_iters : (warmup_iters + bench_iters)] + for rank_inputs in inputs + ] + prof_inputs = [rank_inputs[-prof_iters:] for rank_inputs in inputs] + + for sharding_type in sharding_types if not benchmark_unsharded else ["Unsharded"]: + for compile_mode in compile_modes: + if not benchmark_unsharded: + # Test sharders should have a singular sharding_type + sharder._sharding_type = sharding_type.value + # pyre-ignore [6] + benchmark_type = benchmark_type_name(compile_mode, sharding_type) + else: + benchmark_type = "unsharded" + compile_mode.name + + logging.info( + f"\n\n###### Running EBC/QEBC Benchmark Type: {benchmark_type} ######\n" + ) + + if train: + res = multi_process_benchmark( + # pyre-ignore[6] + callable=init_module_and_run_benchmark, + module=wrapped_module, + sharder=sharder, + 
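+                    # each spawned rank moves its slice of warmup/bench/prof inputs
+                    # to f"{device_type}:{rank}" inside init_module_and_run_benchmark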
device=torch.device(device_type), + sharding_type=sharding_type, + compile_mode=compile_mode, + world_size=world_size, + batch_size=batch_size, + warmup_inputs=warmup_inputs, + bench_inputs=bench_inputs, + prof_inputs=prof_inputs, + tables=tables, + num_benchmarks=num_benchmarks, + output_dir=output_dir, + func_to_benchmark=func_to_benchmark, + benchmark_func_kwargs=benchmark_func_kwargs, + pooling_configs=pooling_configs, + ) + else: + res = init_module_and_run_benchmark( + module=wrapped_module, + sharder=sharder, + device=torch.device(device_type), + # pyre-ignore + sharding_type=sharding_type, + compile_mode=compile_mode, + world_size=world_size, + batch_size=batch_size, + warmup_inputs=warmup_inputs, + bench_inputs=bench_inputs, + prof_inputs=prof_inputs, + tables=tables, + num_benchmarks=num_benchmarks, + output_dir=output_dir, + func_to_benchmark=func_to_benchmark, + benchmark_func_kwargs=benchmark_func_kwargs, + pooling_configs=pooling_configs, + benchmark_unsharded_module=benchmark_unsharded, + ) + + gc.collect() + + benchmark_results.append(res) + + return benchmark_results diff --git a/torchrec/distributed/benchmark/benchmark_inference.py b/torchrec/distributed/benchmark/benchmark_inference.py index 09e9ba10b..02d8e962f 100644 --- a/torchrec/distributed/benchmark/benchmark_inference.py +++ b/torchrec/distributed/benchmark/benchmark_inference.py @@ -17,14 +17,16 @@ from typing import List, Tuple import torch +from torchrec.distributed.benchmark.benchmark_ebc import ( + benchmark_ebc_module, + get_tables, +) from torchrec.distributed.benchmark.benchmark_utils import ( - benchmark_module, BenchmarkResult, CompileMode, DLRM_NUM_EMBEDDINGS_PER_FEATURE, EMBEDDING_DIM, - get_tables, init_argparse_and_args, write_report, ) @@ -84,7 +86,7 @@ def benchmark_qec(args: argparse.Namespace, output_dir: str) -> List[BenchmarkRe if not argname.startswith("_") and argname not in IGNORE_ARGNAME } - return benchmark_module( + return benchmark_ebc_module( module=module, sharder=sharder, sharding_types=BENCH_SHARDING_TYPES, @@ -118,7 +120,7 @@ def benchmark_qebc(args: argparse.Namespace, output_dir: str) -> List[BenchmarkR if not argname.startswith("_") and argname not in IGNORE_ARGNAME } - return benchmark_module( + return benchmark_ebc_module( module=module, sharder=sharder, sharding_types=BENCH_SHARDING_TYPES, @@ -153,7 +155,7 @@ def benchmark_qec_unsharded( if not argname.startswith("_") and argname not in IGNORE_ARGNAME } - return benchmark_module( + return benchmark_ebc_module( module=module, sharder=sharder, sharding_types=[], @@ -190,7 +192,7 @@ def benchmark_qebc_unsharded( if not argname.startswith("_") and argname not in IGNORE_ARGNAME } - return benchmark_module( + return benchmark_ebc_module( module=module, sharder=sharder, sharding_types=[], diff --git a/torchrec/distributed/benchmark/benchmark_pipeline_utils.py b/torchrec/distributed/benchmark/benchmark_pipeline_utils.py index 9a1fa4647..dae5d8842 100644 --- a/torchrec/distributed/benchmark/benchmark_pipeline_utils.py +++ b/torchrec/distributed/benchmark/benchmark_pipeline_utils.py @@ -16,24 +16,14 @@ 3. 
Add model-specific params to ModelSelectionConfig and create_model_config's arguments in benchmark_train_pipeline.py """ -import copy from abc import ABC, abstractmethod from dataclasses import dataclass, fields -from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union +from typing import Dict, List, Optional, Type, Union import torch -import torch.distributed as dist -from torch import nn, optim -from torch.optim import Optimizer -from torchrec.distributed import DistributedModelParallel -from torchrec.distributed.embedding_types import EmbeddingComputeKernel -from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology -from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR -from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner -from torchrec.distributed.planner.types import ParameterConstraints +from torch import nn from torchrec.distributed.test_utils.test_input import ModelInput from torchrec.distributed.test_utils.test_model import ( - TestEBCSharder, TestSparseNN, TestTowerCollectionSparseNN, TestTowerSparseNN, @@ -47,7 +37,6 @@ PrefetchTrainPipelineSparseDist, TrainPipelineSemiSync, ) -from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType from torchrec.models.deepfm import SimpleDeepFMNNWrapper from torchrec.models.dlrm import DLRMWrapper from torchrec.modules.embedding_configs import EmbeddingBagConfig @@ -248,55 +237,6 @@ def create_model_config(model_name: str, **kwargs) -> BaseModelConfig: return model_class(**filtered_kwargs) -def generate_tables( - num_unweighted_features: int, - num_weighted_features: int, - embedding_feature_dim: int, -) -> Tuple[ - List[EmbeddingBagConfig], - List[EmbeddingBagConfig], -]: - """ - Generate embedding bag configurations for both unweighted and weighted features. - - This function creates two lists of EmbeddingBagConfig objects: - 1. Unweighted tables: Named as "table_{i}" with feature names "feature_{i}" - 2. Weighted tables: Named as "weighted_table_{i}" with feature names "weighted_feature_{i}" - - For both types, the number of embeddings scales with the feature index, - calculated as max(i + 1, 100) * 1000. - - Args: - num_unweighted_features (int): Number of unweighted features to generate. - num_weighted_features (int): Number of weighted features to generate. - embedding_feature_dim (int): Dimension of the embedding vectors. - - Returns: - Tuple[List[EmbeddingBagConfig], List[EmbeddingBagConfig]]: A tuple containing - two lists - the first for unweighted embedding tables and the second for - weighted embedding tables. 
- """ - tables = [ - EmbeddingBagConfig( - num_embeddings=max(i + 1, 100) * 1000, - embedding_dim=embedding_feature_dim, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - ) - for i in range(num_unweighted_features) - ] - weighted_tables = [ - EmbeddingBagConfig( - num_embeddings=max(i + 1, 100) * 1000, - embedding_dim=embedding_feature_dim, - name="weighted_table_" + str(i), - feature_names=["weighted_feature_" + str(i)], - ) - for i in range(num_weighted_features) - ] - return tables, weighted_tables - - def generate_pipeline( pipeline_type: str, emb_lookup_stream: str, @@ -371,156 +311,6 @@ def generate_pipeline( return Pipeline(model=model, optimizer=opt, device=device, apply_jit=apply_jit) -def generate_planner( - planner_type: str, - topology: Topology, - tables: Optional[List[EmbeddingBagConfig]], - weighted_tables: Optional[List[EmbeddingBagConfig]], - sharding_type: ShardingType, - compute_kernel: EmbeddingComputeKernel, - batch_sizes: List[int], - pooling_factors: Optional[List[float]], - num_poolings: Optional[List[float]], -) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]: - """ - Generate an embedding sharding planner based on the specified configuration. - - Args: - planner_type: Type of planner to use ("embedding" or "hetero") - topology: Network topology for distributed training - tables: List of unweighted embedding tables - weighted_tables: List of weighted embedding tables - sharding_type: Strategy for sharding embedding tables - compute_kernel: Compute kernel to use for embedding tables - batch_sizes: Sizes of each batch - pooling_factors: Pooling factors for each feature of the table - num_poolings: Number of poolings for each feature of the table - - Returns: - An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner - - Raises: - RuntimeError: If an unknown planner type is specified - """ - # Create parameter constraints for tables - constraints = {} - num_batches = len(batch_sizes) - - if pooling_factors is None: - pooling_factors = [POOLING_FACTOR] * num_batches - - if num_poolings is None: - num_poolings = [NUM_POOLINGS] * num_batches - - assert ( - len(pooling_factors) == num_batches and len(num_poolings) == num_batches - ), "The length of pooling_factors and num_poolings must match the number of batches." 
- - if tables is not None: - for table in tables: - constraints[table.name] = ParameterConstraints( - sharding_types=[sharding_type.value], - compute_kernels=[compute_kernel.value], - device_group="cuda", - pooling_factors=pooling_factors, - num_poolings=num_poolings, - batch_sizes=batch_sizes, - ) - - if weighted_tables is not None: - for table in weighted_tables: - constraints[table.name] = ParameterConstraints( - sharding_types=[sharding_type.value], - compute_kernels=[compute_kernel.value], - device_group="cuda", - pooling_factors=pooling_factors, - num_poolings=num_poolings, - batch_sizes=batch_sizes, - is_weighted=True, - ) - - if planner_type == "embedding": - return EmbeddingShardingPlanner( - topology=topology, - constraints=constraints if constraints else None, - ) - elif planner_type == "hetero": - topology_groups = {"cuda": topology} - return HeteroEmbeddingShardingPlanner( - topology_groups=topology_groups, - constraints=constraints if constraints else None, - ) - else: - raise RuntimeError(f"Unknown planner type: {planner_type}") - - -def generate_sharded_model_and_optimizer( - model: nn.Module, - sharding_type: str, - kernel_type: str, - pg: dist.ProcessGroup, - device: torch.device, - fused_params: Dict[str, Any], - dense_optimizer: str, - dense_lr: float, - dense_momentum: Optional[float], - dense_weight_decay: Optional[float], - planner: Optional[ - Union[ - EmbeddingShardingPlanner, - HeteroEmbeddingShardingPlanner, - ] - ] = None, -) -> Tuple[nn.Module, Optimizer]: - - sharder = TestEBCSharder( - sharding_type=sharding_type, - kernel_type=kernel_type, - fused_params=fused_params, - ) - sharders = [cast(ModuleSharder[nn.Module], sharder)] - - # Use planner if provided - plan = None - if planner is not None: - if pg is not None: - plan = planner.collective_plan(model, sharders, pg) - else: - plan = planner.plan(model, sharders) - - sharded_model = DistributedModelParallel( - module=copy.deepcopy(model), - env=ShardingEnv.from_process_group(pg), - init_data_parallel=True, - device=device, - sharders=sharders, - plan=plan, - ).to(device) - - # Get dense parameters - dense_params = [ - param - for name, param in sharded_model.named_parameters() - if "sparse" not in name - ] - - # Create optimizer based on the specified type - optimizer_class = getattr(optim, dense_optimizer) - - # Create optimizer with momentum and/or weight_decay if provided - optimizer_kwargs = {"lr": dense_lr} - - if dense_momentum is not None: - optimizer_kwargs["momentum"] = dense_momentum - - if dense_weight_decay is not None: - optimizer_kwargs["weight_decay"] = dense_weight_decay - - optimizer = optimizer_class(dense_params, **optimizer_kwargs) - - return sharded_model, optimizer - - def generate_data( tables: List[EmbeddingBagConfig], weighted_tables: List[EmbeddingBagConfig], diff --git a/torchrec/distributed/benchmark/benchmark_train.py b/torchrec/distributed/benchmark/benchmark_train.py index 15ea780f2..da4b51aa3 100644 --- a/torchrec/distributed/benchmark/benchmark_train.py +++ b/torchrec/distributed/benchmark/benchmark_train.py @@ -17,12 +17,14 @@ from typing import List, Optional, Tuple import torch +from torchrec.distributed.benchmark.benchmark_ebc import ( + benchmark_ebc_module, + get_tables, +) from torchrec.distributed.benchmark.benchmark_utils import ( - benchmark_module, BenchmarkResult, CompileMode, - get_tables, init_argparse_and_args, set_embedding_config, write_report, @@ -106,7 +108,7 @@ def benchmark_ebc( args_kwargs["variable_batch_embeddings"] = variable_batch_embeddings - 
return benchmark_module( + return benchmark_ebc_module( module=module, sharder=sharder, sharding_types=BENCH_SHARDING_TYPES, diff --git a/torchrec/distributed/benchmark/benchmark_train_pipeline.py b/torchrec/distributed/benchmark/benchmark_train_pipeline.py index 3010ba9d1..80f7fd0ab 100644 --- a/torchrec/distributed/benchmark/benchmark_train_pipeline.py +++ b/torchrec/distributed/benchmark/benchmark_train_pipeline.py @@ -20,8 +20,9 @@ See benchmark_pipeline_utils.py for step-by-step instructions. """ +import importlib from dataclasses import dataclass, field -from typing import Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type import torch from fbgemm_gpu.split_embedding_configs import EmbOptimType @@ -29,21 +30,19 @@ from torchrec.distributed.benchmark.benchmark_pipeline_utils import ( BaseModelConfig, create_model_config, - DeepFMConfig, - DLRMConfig, generate_data, generate_pipeline, - generate_planner, - generate_sharded_model_and_optimizer, - generate_tables, - TestSparseNNConfig, - TestTowerCollectionSparseNNConfig, - TestTowerSparseNNConfig, ) from torchrec.distributed.benchmark.benchmark_utils import ( benchmark_func, + benchmark_module, BenchmarkResult, cmd_conf, + CPUMemoryStats, + generate_planner, + generate_sharded_model_and_optimizer, + generate_tables, + GPUMemoryStats, ) from torchrec.distributed.comm import get_local_size from torchrec.distributed.embedding_types import EmbeddingComputeKernel @@ -60,6 +59,18 @@ from torchrec.modules.embedding_configs import EmbeddingBagConfig +@dataclass +class UnifiedBenchmarkConfig: + """Unified configuration for both pipeline and module benchmarking.""" + + benchmark_type: str = "pipeline" # "pipeline" or "module" + + # Module benchmarking specific options + module_path: str = "" # e.g., "torchrec.models.deepfm" + module_class: str = "" # e.g., "SimpleDeepFMNNWrapper" + module_kwargs: Dict[str, Any] = field(default_factory=dict) + + @dataclass class RunOptions: """ @@ -199,63 +210,166 @@ class ModelSelectionConfig: over_arch_layer_sizes: List[int] = field(default_factory=lambda: [5, 1]) -@cmd_conf -def main( - run_option: RunOptions, +def dynamic_import_module(module_path: str, module_class: str) -> Type[nn.Module]: + """Dynamically import a module class from a given path.""" + try: + module = importlib.import_module(module_path) + return getattr(module, module_class) + except (ImportError, AttributeError) as e: + raise RuntimeError(f"Failed to import {module_class} from {module_path}: {e}") + + +def create_module_instance( + unified_config: UnifiedBenchmarkConfig, + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], table_config: EmbeddingTablesConfig, - model_selection: ModelSelectionConfig, - pipeline_config: PipelineConfig, - model_config: Optional[BaseModelConfig] = None, -) -> None: +) -> nn.Module: + """Create a module instance based on the unified config.""" + ModuleClass = dynamic_import_module( + unified_config.module_path, unified_config.module_class + ) + + # Handle common module instantiation patterns + if unified_config.module_class == "SimpleDeepFMNNWrapper": + from torchrec.modules.embedding_modules import EmbeddingBagCollection + + ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta")) + return ModuleClass( + embedding_bag_collection=ebc, + num_dense_features=10, # Default value, can be overridden via module_kwargs + **unified_config.module_kwargs, + ) + elif unified_config.module_class == "DLRMWrapper": + from 
torchrec.modules.embedding_modules import EmbeddingBagCollection + + ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta")) + return ModuleClass( + embedding_bag_collection=ebc, + dense_in_features=10, # Default value, can be overridden via module_kwargs + dense_arch_layer_sizes=[20, 128], # Default value + over_arch_layer_sizes=[5, 1], # Default value + **unified_config.module_kwargs, + ) + elif unified_config.module_class == "EmbeddingBagCollection": + return ModuleClass(tables=tables, **unified_config.module_kwargs) + else: + # Generic instantiation - try with tables and weighted_tables + try: + return ModuleClass( + tables=tables, + weighted_tables=weighted_tables, + **unified_config.module_kwargs, + ) + except TypeError: + # Fallback to just tables + try: + return ModuleClass(tables=tables, **unified_config.module_kwargs) + except TypeError: + # Fallback to no embedding tables + return ModuleClass(**unified_config.module_kwargs) + + +def run_module_benchmark( + unified_config: UnifiedBenchmarkConfig, + table_config: EmbeddingTablesConfig, + run_option: RunOptions, +) -> BenchmarkResult: + """Run module-level benchmarking.""" tables, weighted_tables = generate_tables( num_unweighted_features=table_config.num_unweighted_features, num_weighted_features=table_config.num_weighted_features, embedding_feature_dim=table_config.embedding_feature_dim, ) - if model_config is None: - model_config = create_model_config( - model_name=model_selection.model_name, - batch_size=model_selection.batch_size, - batch_sizes=model_selection.batch_sizes, - num_float_features=model_selection.num_float_features, - feature_pooling_avg=model_selection.feature_pooling_avg, - use_offsets=model_selection.use_offsets, - dev_str=model_selection.dev_str, - long_kjt_indices=model_selection.long_kjt_indices, - long_kjt_offsets=model_selection.long_kjt_offsets, - long_kjt_lengths=model_selection.long_kjt_lengths, - pin_memory=model_selection.pin_memory, - embedding_groups=model_selection.embedding_groups, - feature_processor_modules=model_selection.feature_processor_modules, - max_feature_lengths=model_selection.max_feature_lengths, - over_arch_clazz=model_selection.over_arch_clazz, - postproc_module=model_selection.postproc_module, - zch=model_selection.zch, - hidden_layer_size=model_selection.hidden_layer_size, - deep_fm_dimension=model_selection.deep_fm_dimension, - dense_arch_layer_sizes=model_selection.dense_arch_layer_sizes, - over_arch_layer_sizes=model_selection.over_arch_layer_sizes, - ) + module = create_module_instance( + unified_config, tables, weighted_tables, table_config + ) - # launch trainers - run_multi_process_func( - func=runner, - world_size=run_option.world_size, + return benchmark_module( + module=module, tables=tables, weighted_tables=weighted_tables, - run_option=run_option, - model_config=model_config, - pipeline_config=pipeline_config, + num_float_features=10, # Default value + sharding_type=run_option.sharding_type, + planner_type=run_option.planner_type, + world_size=run_option.world_size, + num_benchmarks=5, # Default value + batch_size=2048, # Default value + compute_kernel=run_option.compute_kernel, + device_type="cuda", ) +@cmd_conf +def main( + run_option: RunOptions, + table_config: EmbeddingTablesConfig, + model_selection: ModelSelectionConfig, + pipeline_config: PipelineConfig, + unified_config: UnifiedBenchmarkConfig, + model_config: Optional[BaseModelConfig] = None, +) -> None: + # Route to appropriate benchmark type based on unified config + if 
unified_config.benchmark_type == "module": + print("Running module-level benchmark...") + result = run_module_benchmark(unified_config, table_config, run_option) + print(f"Module benchmark completed: {result}") + elif unified_config.benchmark_type == "pipeline": + print("Running pipeline-level benchmark...") + tables, weighted_tables = generate_tables( + num_unweighted_features=table_config.num_unweighted_features, + num_weighted_features=table_config.num_weighted_features, + embedding_feature_dim=table_config.embedding_feature_dim, + ) + + if model_config is None: + model_config = create_model_config( + model_name=model_selection.model_name, + batch_size=model_selection.batch_size, + batch_sizes=model_selection.batch_sizes, + num_float_features=model_selection.num_float_features, + feature_pooling_avg=model_selection.feature_pooling_avg, + use_offsets=model_selection.use_offsets, + dev_str=model_selection.dev_str, + long_kjt_indices=model_selection.long_kjt_indices, + long_kjt_offsets=model_selection.long_kjt_offsets, + long_kjt_lengths=model_selection.long_kjt_lengths, + pin_memory=model_selection.pin_memory, + embedding_groups=model_selection.embedding_groups, + feature_processor_modules=model_selection.feature_processor_modules, + max_feature_lengths=model_selection.max_feature_lengths, + over_arch_clazz=model_selection.over_arch_clazz, + postproc_module=model_selection.postproc_module, + zch=model_selection.zch, + hidden_layer_size=model_selection.hidden_layer_size, + deep_fm_dimension=model_selection.deep_fm_dimension, + dense_arch_layer_sizes=model_selection.dense_arch_layer_sizes, + over_arch_layer_sizes=model_selection.over_arch_layer_sizes, + ) + + # launch trainers + run_multi_process_func( + func=runner, + world_size=run_option.world_size, + tables=tables, + weighted_tables=weighted_tables, + run_option=run_option, + model_config=model_config, + pipeline_config=pipeline_config, + ) + else: + raise ValueError( + f"Unknown benchmark_type: {unified_config.benchmark_type}. 
Must be 'module' or 'pipeline'" + ) + + def run_pipeline( run_option: RunOptions, table_config: EmbeddingTablesConfig, pipeline_config: PipelineConfig, model_config: BaseModelConfig, -) -> List[BenchmarkResult]: +) -> BenchmarkResult: tables, weighted_tables = generate_tables( num_unweighted_features=table_config.num_unweighted_features, @@ -263,7 +377,7 @@ def run_pipeline( embedding_feature_dim=table_config.embedding_feature_dim, ) - return run_multi_process_func( + benchmark_res_per_rank = run_multi_process_func( func=runner, world_size=run_option.world_size, tables=tables, @@ -273,6 +387,28 @@ def run_pipeline( pipeline_config=pipeline_config, ) + # Combine results from all ranks into a single BenchmarkResult + # Use timing data from rank 0, combine memory stats from all ranks + world_size = run_option.world_size + + total_benchmark_res = BenchmarkResult( + short_name=benchmark_res_per_rank[0].short_name, + gpu_elapsed_time=benchmark_res_per_rank[0].gpu_elapsed_time, + cpu_elapsed_time=benchmark_res_per_rank[0].cpu_elapsed_time, + gpu_mem_stats=[GPUMemoryStats(rank, 0, 0, 0) for rank in range(world_size)], + cpu_mem_stats=[CPUMemoryStats(rank, 0) for rank in range(world_size)], + rank=0, + ) + + for res in benchmark_res_per_rank: + # Each rank's BenchmarkResult contains 1 GPU and 1 CPU memory measurement + if len(res.gpu_mem_stats) > 0: + total_benchmark_res.gpu_mem_stats[res.rank] = res.gpu_mem_stats[0] + if len(res.cpu_mem_stats) > 0: + total_benchmark_res.cpu_mem_stats[res.rank] = res.cpu_mem_stats[0] + + return total_benchmark_res + def runner( rank: int, diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py index a9f7a3864..72af64327 100644 --- a/torchrec/distributed/benchmark/benchmark_utils.py +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -12,8 +12,9 @@ import argparse import contextlib + +# Additional imports for the new benchmark_module function import copy -import gc import inspect import json import logging @@ -21,51 +22,41 @@ import resource import time import timeit -from dataclasses import dataclass, fields, is_dataclass, MISSING +from dataclasses import dataclass, field, fields, is_dataclass, MISSING from enum import Enum from typing import ( Any, Callable, - ContextManager, + cast, Dict, get_args, get_origin, List, Optional, - Set, Tuple, TypeVar, Union, ) -import click - import torch +import torch.distributed as dist import yaml -from torch import multiprocessing as mp +from fbgemm_gpu.split_embedding_configs import EmbOptimType +from torch import multiprocessing as mp, nn, optim from torch.autograd.profiler import record_function +from torch.optim import Optimizer from torchrec.distributed import DistributedModelParallel -from torchrec.distributed.embedding_types import ShardingType -from torchrec.distributed.global_settings import set_propogate_device - +from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology -from torchrec.distributed.planner.enumerators import EmbeddingEnumerator -from torchrec.distributed.planner.shard_estimators import ( - EmbeddingPerfEstimator, - EmbeddingStorageEstimator, -) -from torchrec.distributed.shard import _shard_modules +from torchrec.distributed.planner.constants import NUM_POOLINGS, POOLING_FACTOR +from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner +from torchrec.distributed.planner.types import ParameterConstraints from 
torchrec.distributed.test_utils.multi_process import MultiProcessContext -from torchrec.distributed.test_utils.test_model import ModelInput - -from torchrec.distributed.types import DataType, ModuleSharder, ShardingEnv -from torchrec.fx import symbolic_trace -from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig -from torchrec.quant.embedding_modules import ( - EmbeddingBagCollection as QuantEmbeddingBagCollection, - EmbeddingCollection as QuantEmbeddingCollection, -) -from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor +from torchrec.distributed.test_utils.test_input import ModelInput +from torchrec.distributed.test_utils.test_model import TestEBCSharder +from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType +from torchrec.modules.embedding_configs import EmbeddingBagConfig +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.test_utils import get_free_port logger: logging.Logger = logging.getLogger() @@ -149,6 +140,25 @@ def __str__(self) -> str: return f"Rank {self.rank}: CPU Memory Peak RSS: {self.peak_rss_mbs/1000:.2f} GB" +@dataclass +class ModuleBenchmarkConfig: + """Configuration for module-level benchmarking.""" + + module_path: str = "" # e.g., "torchrec.models.deepfm" + module_class: str = "" # e.g., "SimpleDeepFMNNWrapper" + module_kwargs: Dict[str, Any] = field( + default_factory=dict + ) # Additional kwargs for module instantiation + num_float_features: int = 0 + sharding_type: ShardingType = ShardingType.TABLE_WISE + planner_type: str = "embedding" + world_size: int = 2 + num_benchmarks: int = 5 + batch_size: int = 2048 + compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED + device_type: str = "cuda" + + @dataclass class BenchmarkResult: "Class for holding results of benchmark runs" @@ -243,216 +253,9 @@ def cpu_mem_percentile( ) -class ECWrapper(torch.nn.Module): - """ - Wrapper Module for benchmarking EC Modules - - Args: - module: module to benchmark - - Call Args: - input: KeyedJaggedTensor KJT input to module - - Returns: - output: KT output from module - - Example: - e1_config = EmbeddingConfig( - name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] - ) - e2_config = EmbeddingConfig( - name="t2", embedding_dim=3, num_embeddings=10, feature_names=["f2"] - ) - - ec = EmbeddingCollection(tables=[e1_config, e2_config]) - - features = KeyedJaggedTensor( - keys=["f1", "f2"], - values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), - offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), - ) - - ec.qconfig = torch.quantization.QConfig( - activation=torch.quantization.PlaceholderObserver.with_args( - dtype=torch.qint8 - ), - weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8), - ) - - qec = QuantEmbeddingCollection.from_float(ecc) - - wrapped_module = ECWrapper(qec) - quantized_embeddings = wrapped_module(features) - """ - - def __init__(self, module: torch.nn.Module) -> None: - super().__init__() - self._module = module - - def forward(self, input: KeyedJaggedTensor) -> Dict[str, JaggedTensor]: - """ - Args: - input (KeyedJaggedTensor): KJT of form [F X B X L]. 
- - Returns: - Dict[str, JaggedTensor] - """ - return self._module.forward(input) - - -class EBCWrapper(torch.nn.Module): - """ - Wrapper Module for benchmarking Modules - - Args: - module: module to benchmark - - Call Args: - input: KeyedJaggedTensor KJT input to module - - Returns: - output: KT output from module - - Example: - table_0 = EmbeddingBagConfig( - name="t1", embedding_dim=3, num_embeddings=10, feature_names=["f1"] - ) - table_1 = EmbeddingBagConfig( - name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"] - ) - ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config]) - - features = KeyedJaggedTensor( - keys=["f1", "f2"], - values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), - offsets=torch.tensor([0, 2, 2, 3, 4, 5, 8]), - ) - - ebc.qconfig = torch.quantization.QConfig( - activation=torch.quantization.PlaceholderObserver.with_args( - dtype=torch.qint8 - ), - weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8), - ) - - qebc = QuantEmbeddingBagCollection.from_float(ebc) - - wrapped_module = EBCWrapper(qebc) - quantized_embeddings = wrapped_module(features) - """ - - def __init__(self, module: torch.nn.Module) -> None: - super().__init__() - self._module = module - - def forward(self, input: KeyedJaggedTensor) -> KeyedTensor: - """ - Args: - input (KeyedJaggedTensor): KJT of form [F X B X L]. - - Returns: - KeyedTensor - """ - return self._module.forward(input) - - T = TypeVar("T", bound=torch.nn.Module) -def default_func_to_benchmark( - model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor] -) -> None: - with torch.inference_mode(): - for bench_input in bench_inputs: - model(bench_input) - - -def get_tables( - table_sizes: List[Tuple[int, int]], - is_pooled: bool = True, - data_type: DataType = DataType.INT8, -) -> Union[List[EmbeddingBagConfig], List[EmbeddingConfig]]: - if is_pooled: - tables: List[EmbeddingBagConfig] = [ - EmbeddingBagConfig( - num_embeddings=num_embeddings, - embedding_dim=embedding_dim, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - data_type=data_type, - ) - for i, (num_embeddings, embedding_dim) in enumerate(table_sizes) - ] - else: - tables: List[EmbeddingConfig] = [ - EmbeddingConfig( - num_embeddings=num_embeddings, - embedding_dim=embedding_dim, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - data_type=data_type, - ) - for i, (num_embeddings, embedding_dim) in enumerate(table_sizes) - ] - - return tables - - -def get_inputs( - tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]], - batch_size: int, - world_size: int, - num_inputs: int, - train: bool, - pooling_configs: Optional[List[int]] = None, - variable_batch_embeddings: bool = False, -) -> List[List[KeyedJaggedTensor]]: - inputs_batch: List[List[KeyedJaggedTensor]] = [] - - if variable_batch_embeddings and not train: - raise RuntimeError("Variable batch size is only supported in training mode") - - for _ in range(num_inputs): - if variable_batch_embeddings: - _, model_input_by_rank = ModelInput.generate_variable_batch_input( - average_batch_size=batch_size, - world_size=world_size, - num_float_features=0, - tables=tables, - ) - else: - _, model_input_by_rank = ModelInput.generate( - batch_size=batch_size, - world_size=world_size, - num_float_features=0, - tables=tables, - weighted_tables=[], - tables_pooling=pooling_configs, - indices_dtype=torch.int32, - lengths_dtype=torch.int32, - ) - - if train: - sparse_features_by_rank = [ - model_input.idlist_features - for model_input in model_input_by_rank - if 
isinstance(model_input.idlist_features, KeyedJaggedTensor) - ] - inputs_batch.append(sparse_features_by_rank) - else: - sparse_features = model_input_by_rank[0].idlist_features - assert isinstance(sparse_features, KeyedJaggedTensor) - inputs_batch.append([sparse_features]) - - # Transpose if train, as inputs_by_rank is currently in [B X R] format - inputs_by_rank = [ - [sparse_features for sparse_features in sparse_features_rank] - for sparse_features_rank in zip(*inputs_batch) - ] - - return inputs_by_rank - - def write_report( benchmark_results: List[BenchmarkResult], report_file: str, @@ -491,6 +294,70 @@ def write_report( logger.info(f"Report written to {report_file}:\n{report_str}") +def multi_process_benchmark( + callable: Callable[ + ..., + None, + ], + # pyre-ignore + **kwargs, +) -> BenchmarkResult: + + def setUp() -> None: + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = str("localhost") + os.environ["MASTER_PORT"] = str(get_free_port()) + + assert "world_size" in kwargs + world_size = kwargs["world_size"] + + setUp() + benchmark_res_per_rank = [] + # kineto has a known problem with fork-server: it'll hang + # when dumping the trace. Workaround with spawn + ctx = mp.get_context("spawn") + qq = ctx.SimpleQueue() + processes = [] + + for rank in range(world_size): + kwargs["rank"] = rank + kwargs["world_size"] = world_size + kwargs["queue"] = qq + p = ctx.Process( + target=callable, + kwargs=kwargs, + ) + p.start() + processes.append(p) + + for _ in range(world_size): + res = qq.get() + + benchmark_res_per_rank.append(res) + assert len(res.gpu_mem_stats) == 1 + assert len(res.cpu_mem_stats) == 1 + + for p in processes: + p.join() + assert 0 == p.exitcode + + total_benchmark_res = BenchmarkResult( + short_name=benchmark_res_per_rank[0].short_name, + gpu_elapsed_time=benchmark_res_per_rank[0].gpu_elapsed_time, + cpu_elapsed_time=benchmark_res_per_rank[0].cpu_elapsed_time, + gpu_mem_stats=[GPUMemoryStats(rank, 0, 0, 0) for rank in range(world_size)], + cpu_mem_stats=[CPUMemoryStats(rank, 0) for rank in range(world_size)], + rank=0, + ) + + for res in benchmark_res_per_rank: + # Each rank's BenchmarkResult contains 1 GPU and 1 CPU memory measurement + total_benchmark_res.gpu_mem_stats[res.rank] = res.gpu_mem_stats[0] + total_benchmark_res.cpu_mem_stats[res.rank] = res.cpu_mem_stats[0] + + return total_benchmark_res + + def set_embedding_config( embedding_config_json: str, ) -> Tuple[List[Tuple[int, int]], List[int]]: @@ -533,6 +400,473 @@ def set_embedding_config( return embedding_configs, pooling_configs +def generate_tables( + num_unweighted_features: int = 100, + num_weighted_features: int = 100, + embedding_feature_dim: int = 128, +) -> Tuple[ + List[EmbeddingBagConfig], + List[EmbeddingBagConfig], +]: + """ + Generate embedding bag configurations for both unweighted and weighted features. + + This function creates two lists of EmbeddingBagConfig objects: + 1. Unweighted tables: Named as "table_{i}" with feature names "feature_{i}" + 2. Weighted tables: Named as "weighted_table_{i}" with feature names "weighted_feature_{i}" + + For both types, the number of embeddings scales with the feature index, + calculated as max(i + 1, 100) * 1000. + + Args: + num_unweighted_features (int): Number of unweighted features to generate. + num_weighted_features (int): Number of weighted features to generate. + embedding_feature_dim (int): Dimension of the embedding vectors. 
+ + Returns: + Tuple[List[EmbeddingBagConfig], List[EmbeddingBagConfig]]: A tuple containing + two lists - the first for unweighted embedding tables and the second for + weighted embedding tables. + """ + tables = [ + EmbeddingBagConfig( + num_embeddings=max(i + 1, 100) * 1000, + embedding_dim=embedding_feature_dim, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(num_unweighted_features) + ] + weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=max(i + 1, 100) * 1000, + embedding_dim=embedding_feature_dim, + name="weighted_table_" + str(i), + feature_names=["weighted_feature_" + str(i)], + ) + for i in range(num_weighted_features) + ] + return tables, weighted_tables + + +def generate_planner( + planner_type: str, + topology: Topology, + tables: Optional[List[EmbeddingBagConfig]], + weighted_tables: Optional[List[EmbeddingBagConfig]], + sharding_type: ShardingType, + compute_kernel: EmbeddingComputeKernel, + batch_sizes: List[int], + pooling_factors: Optional[List[float]] = None, + num_poolings: Optional[List[float]] = None, +) -> Union[EmbeddingShardingPlanner, HeteroEmbeddingShardingPlanner]: + """ + Generate an embedding sharding planner based on the specified configuration. + + Args: + planner_type: Type of planner to use ("embedding" or "hetero") + topology: Network topology for distributed training + tables: List of unweighted embedding tables + weighted_tables: List of weighted embedding tables + sharding_type: Strategy for sharding embedding tables + compute_kernel: Compute kernel to use for embedding tables + batch_sizes: Sizes of each batch + pooling_factors: Pooling factors for each feature of the table + num_poolings: Number of poolings for each feature of the table + + Returns: + An instance of EmbeddingShardingPlanner or HeteroEmbeddingShardingPlanner + + Raises: + RuntimeError: If an unknown planner type is specified + """ + # Create parameter constraints for tables + constraints = {} + num_batches = len(batch_sizes) + + if pooling_factors is None: + pooling_factors = [POOLING_FACTOR] * num_batches + + if num_poolings is None: + num_poolings = [NUM_POOLINGS] * num_batches + + assert ( + len(pooling_factors) == num_batches and len(num_poolings) == num_batches + ), "The length of pooling_factors and num_poolings must match the number of batches." 
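+    # The same ParameterConstraints are applied to every unweighted table below;
+    # weighted tables get identical constraints with is_weighted=True.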
+ + if tables is not None: + for table in tables: + constraints[table.name] = ParameterConstraints( + sharding_types=[sharding_type.value], + compute_kernels=[compute_kernel.value], + device_group="cuda", + pooling_factors=pooling_factors, + num_poolings=num_poolings, + batch_sizes=batch_sizes, + ) + + if weighted_tables is not None: + for table in weighted_tables: + constraints[table.name] = ParameterConstraints( + sharding_types=[sharding_type.value], + compute_kernels=[compute_kernel.value], + device_group="cuda", + pooling_factors=pooling_factors, + num_poolings=num_poolings, + batch_sizes=batch_sizes, + is_weighted=True, + ) + + if planner_type == "embedding": + return EmbeddingShardingPlanner( + topology=topology, + constraints=constraints if constraints else None, + ) + elif planner_type == "hetero": + topology_groups = {"cuda": topology} + return HeteroEmbeddingShardingPlanner( + topology_groups=topology_groups, + constraints=constraints if constraints else None, + ) + else: + raise RuntimeError(f"Unknown planner type: {planner_type}") + + +def generate_sharded_model_and_optimizer( + model: nn.Module, + sharding_type: str, + kernel_type: str, + pg: dist.ProcessGroup, + device: torch.device, + fused_params: Dict[str, Any], + dense_optimizer: str = "SGD", + dense_lr: float = 0.1, + dense_momentum: Optional[float] = None, + dense_weight_decay: Optional[float] = None, + planner: Optional[ + Union[ + EmbeddingShardingPlanner, + HeteroEmbeddingShardingPlanner, + ] + ] = None, +) -> Tuple[nn.Module, Optimizer]: + """ + Generate a sharded model and optimizer for distributed training. + + Args: + model: The model to be sharded + sharding_type: Type of sharding strategy + kernel_type: Type of compute kernel + pg: Process group for distributed training + device: Device to place the model on + fused_params: Parameters for the fused optimizer + dense_optimizer: Optimizer type for dense parameters + dense_lr: Learning rate for dense parameters + dense_momentum: Momentum for dense parameters (optional) + dense_weight_decay: Weight decay for dense parameters (optional) + planner: Optional planner for sharding strategy + + Returns: + Tuple of sharded model and optimizer + """ + sharder = TestEBCSharder( + sharding_type=sharding_type, + kernel_type=kernel_type, + fused_params=fused_params, + ) + sharders = [cast(ModuleSharder[nn.Module], sharder)] + + # Use planner if provided + plan = None + if planner is not None: + if pg is not None: + plan = planner.collective_plan(model, sharders, pg) + else: + plan = planner.plan(model, sharders) + + sharded_model = DistributedModelParallel( + module=copy.deepcopy(model), + env=ShardingEnv.from_process_group(pg), + init_data_parallel=True, + device=device, + sharders=sharders, + plan=plan, + ).to(device) + + # Get dense parameters + dense_params = [ + param + for name, param in sharded_model.named_parameters() + if "sparse" not in name + ] + + # Create optimizer based on the specified type + optimizer_class = getattr(optim, dense_optimizer) + + # Create optimizer with momentum and/or weight_decay if provided + optimizer_kwargs = {"lr": dense_lr} + + if dense_momentum is not None: + optimizer_kwargs["momentum"] = dense_momentum + + if dense_weight_decay is not None: + optimizer_kwargs["weight_decay"] = dense_weight_decay + + optimizer = optimizer_class(dense_params, **optimizer_kwargs) + + return sharded_model, optimizer + + +def _init_module_and_run_benchmark( + module: torch.nn.Module, + sharding_type: ShardingType, + planner_type: str, + compute_kernel: 
EmbeddingComputeKernel, + tables: List[EmbeddingBagConfig], + weighted_tables: List[EmbeddingBagConfig], + batch_size: int, + num_benchmarks: int, + world_size: int, + num_float_features: int = 0, + rank: int = -1, + queue: Optional[mp.Queue] = None, + device_type: str = "cuda", + warmup_iters: int = 20, + bench_iters: int = 100, + prof_iters: int = 20, +) -> None: + """ + Initialize module and run benchmark for a single process. + + This is a simplified version of init_module_and_run_benchmark from benchmark_ebc.py + that doesn't handle compile modes and focuses on the core benchmarking functionality. + """ + from torchrec.distributed.comm import get_local_size + + # Generate input data + num_inputs_to_gen = warmup_iters + bench_iters + prof_iters + + batch_sizes = [batch_size] * num_inputs_to_gen + inputs_batch = [] + + for _ in range(num_inputs_to_gen): + model_input_by_rank = [] + for _ in range(world_size): + model_input_by_rank.append( + ModelInput.generate( + batch_size=batch_size, + num_float_features=num_float_features, + tables=tables, + weighted_tables=weighted_tables, + indices_dtype=torch.int32, + lengths_dtype=torch.int32, + ) + ) + + inputs_batch.append(model_input_by_rank) + + # Transpose to get inputs by rank: [R x B] format + inputs_by_rank = list(zip(*inputs_batch)) + + if rank >= 0: + warmup_inputs_cuda = [ + warmup_input.to(torch.device(f"{device_type}:{rank}")) + for warmup_input in inputs_by_rank[rank][:warmup_iters] + ] + bench_inputs_cuda = [ + bench_input.to(torch.device(f"{device_type}:{rank}")) + for bench_input in inputs_by_rank[rank][ + warmup_iters : warmup_iters + bench_iters + ] + ] + prof_inputs_cuda = [ + prof_input.to(torch.device(f"{device_type}:{rank}")) + for prof_input in inputs_by_rank[rank][-prof_iters:] + ] + else: + warmup_inputs_cuda = [ + warmup_input.to(torch.device(f"{device_type}:0")) + for warmup_input in inputs_by_rank[0][:warmup_iters] + ] + bench_inputs_cuda = [ + bench_input.to(torch.device(f"{device_type}:0")) + for bench_input in inputs_by_rank[0][ + warmup_iters : warmup_iters + bench_iters + ] + ] + prof_inputs_cuda = [ + prof_input.to(torch.device(f"{device_type}:0")) + for prof_input in inputs_by_rank[0][-prof_iters:] + ] + + with ( + MultiProcessContext( + rank, world_size, "nccl", use_deterministic_algorithms=False + ) + if rank != -1 + else contextlib.nullcontext() + ) as ctx: + # Create topology and planner + topology = Topology( + local_world_size=get_local_size(world_size), + world_size=world_size, + compute_device=device_type, + ) + + planner = generate_planner( + planner_type=planner_type, + topology=topology, + tables=tables, + weighted_tables=weighted_tables, + sharding_type=sharding_type, + compute_kernel=compute_kernel, + batch_sizes=batch_sizes[ + :num_benchmarks + ], # Use only benchmark batches for planning + ) + + # Prepare fused_params for sparse optimizer + fused_params = { + "optimizer": EmbOptimType.EXACT_ADAGRAD, + "learning_rate": 0.1, + } + + device = ctx.device if rank != -1 else torch.device(device_type) + pg = ctx.pg if rank != -1 else None + + sharded_model, _ = generate_sharded_model_and_optimizer( + model=module, + sharding_type=sharding_type.value, + kernel_type=compute_kernel.value, + pg=pg, + device=device, + fused_params=fused_params, + planner=planner, + ) + + def _func_to_benchmark( + model: torch.nn.Module, bench_inputs: List[KeyedJaggedTensor] + ) -> None: + with torch.inference_mode(): + for bench_input in bench_inputs: + model(bench_input) + + name = f"{sharding_type.value}-{planner_type}" + 
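+        # Run the shared benchmark harness on this rank's inputs; under
+        # multiprocessing the result is pushed onto the queue below so the
+        # parent process can aggregate results across ranks.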
+ res = benchmark( + name, + sharded_model, + warmup_inputs_cuda, + bench_inputs_cuda, + prof_inputs_cuda, + world_size=world_size, + output_dir="", + num_benchmarks=num_benchmarks, + func_to_benchmark=_func_to_benchmark, + benchmark_func_kwargs=None, + rank=rank, + device_type=device_type, + benchmark_unsharded_module=False, + ) + + if queue is not None: + queue.put(res) + + +def benchmark_module( + module: torch.nn.Module, + tables: List[EmbeddingBagConfig], + weighted_tables: Optional[List[EmbeddingBagConfig]] = None, + num_float_features: int = 0, + sharding_type: ShardingType = ShardingType.TABLE_WISE, + planner_type: str = "embedding", + world_size: int = 2, + num_benchmarks: int = 5, + batch_size: int = 2048, + compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.FUSED, + device_type: str = "cuda", +) -> BenchmarkResult: + """ + Benchmark any PyTorch module with distributed sharding. + + This function provides a simple interface to benchmark arbitrary PyTorch modules + using TorchRec's distributed sharding capabilities. It uses the provided embedding + tables to generate input data, sets up multiprocessing for distributed training, + and returns comprehensive benchmark results. + + Args: + module: PyTorch module to benchmark + tables: List of unweighted embedding table configurations + weighted_tables: Optional list of weighted embedding table configurations + sharding_type: Strategy for sharding embedding tables across devices + planner_type: Type of planner to use ("embedding" or "hetero") + world_size: Number of processes/GPUs to use for distributed training + num_benchmarks: Number of iterations to run for statistical analysis + batch_size: Batch size to use for benchmarking + compute_kernel: Compute kernel to use for embedding tables + device_type: Device type to use ("cuda" or "cpu") + + Returns: + BenchmarkResult containing timing and memory statistics + + Example: + from torchrec.modules.embedding_modules import EmbeddingBagCollection + from torchrec.modules.embedding_configs import EmbeddingBagConfig + + # Create embedding tables + tables = [ + EmbeddingBagConfig( + name="table_0", embedding_dim=128, num_embeddings=100000, + feature_names=["feature_0"] + ) + ] + + # Create a simple EBC module + ebc = EmbeddingBagCollection(tables=tables) + + # Benchmark it + result = benchmark_module( + module=ebc, + tables=tables, + world_size=2, + num_benchmarks=10 + ) + print(result) + """ + logger.info(f"Starting benchmark for module: {type(module).__name__}") + logger.info(f"Sharding type: {sharding_type}") + logger.info(f"Planner type: {planner_type}") + logger.info(f"World size: {world_size}") + logger.info(f"Batch size: {batch_size}") + logger.info(f"Number of benchmarks: {num_benchmarks}") + + assert ( + num_benchmarks > 2 + ), "num_benchmarks needs to be greater than 2 for statistical analysis" + + # Use provided tables or default to empty list for weighted tables + if weighted_tables is None: + weighted_tables = [] + + # Use multiprocessing for distributed benchmarking (always assume train mode) + res = multi_process_benchmark( + callable=_init_module_and_run_benchmark, + module=module, + sharding_type=sharding_type, + planner_type=planner_type, + compute_kernel=compute_kernel, + tables=tables, + weighted_tables=weighted_tables, + batch_size=batch_size, + num_benchmarks=num_benchmarks, + world_size=world_size, + num_float_features=num_float_features, + device_type=device_type, + ) + + return res + + # pyre-ignore [24] def cmd_conf(func: Callable) -> Callable: @@ 
-685,89 +1019,6 @@ def init_argparse_and_args() -> argparse.Namespace: return args -def transform_module( - module: torch.nn.Module, - device: torch.device, - inputs: List[KeyedJaggedTensor], - sharder: ModuleSharder[T], - sharding_type: ShardingType, - compile_mode: CompileMode, - world_size: int, - batch_size: int, - # pyre-fixme[24]: Generic type `ContextManager` expects 1 type parameter. - ctx: ContextManager, - benchmark_unsharded_module: bool = False, -) -> torch.nn.Module: - def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module: - eager_module(inputs[0]) - graph_module = symbolic_trace( - eager_module, leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"] - ) - scripted_module = torch.jit.script(graph_module) - return scripted_module - - set_propogate_device(True) - - sharded_module = None - - if not benchmark_unsharded_module: - topology: Topology = Topology(world_size=world_size, compute_device=device.type) - planner = EmbeddingShardingPlanner( - topology=topology, - batch_size=batch_size, - enumerator=EmbeddingEnumerator( - topology=topology, - batch_size=batch_size, - estimator=[ - EmbeddingPerfEstimator(topology=topology), - EmbeddingStorageEstimator(topology=topology), - ], - ), - ) - - # Don't want to modify the module outright - # Since module is on cpu, won't cause cuda oom. - copied_module = copy.deepcopy(module) - # pyre-ignore [6] - plan = planner.plan(copied_module, [sharder]) - - if isinstance(ctx, MultiProcessContext): - sharded_module = DistributedModelParallel( - copied_module, - # pyre-ignore[6] - env=ShardingEnv.from_process_group(ctx.pg), - plan=plan, - # pyre-ignore[6] - sharders=[sharder], - device=ctx.device, - ) - else: - env = ShardingEnv.from_local(world_size=topology.world_size, rank=0) - - sharded_module = _shard_modules( - module=copied_module, - # pyre-fixme[6]: For 2nd argument expected - # `Optional[List[ModuleSharder[Module]]]` but got - # `List[ModuleSharder[Variable[T (bound to Module)]]]`. - sharders=[sharder], - device=device, - plan=plan, - env=env, - ) - - if compile_mode == CompileMode.FX_SCRIPT: - return fx_script_module( - # pyre-fixme[6]: For 1st argument expected `Module` but got - # `Optional[Module]`. - sharded_module - if not benchmark_unsharded_module - else module - ) - else: - # pyre-fixme[7]: Expected `Module` but got `Optional[Module]`. 
- return sharded_module if not benchmark_unsharded_module else module - - def _run_benchmark_core( name: str, run_iter_fn: Callable[[], None], @@ -1030,345 +1281,3 @@ def _profile_iter_fn(prof: torch.profiler.profile) -> None: export_stacks=export_stacks, reset_accumulated_memory_stats=True, ) - - -def benchmark_type_name(compile_mode: CompileMode, sharding_type: ShardingType) -> str: - if sharding_type == ShardingType.TABLE_WISE: - name = "tw-sharded" - elif sharding_type == ShardingType.ROW_WISE: - name = "rw-sharded" - elif sharding_type == ShardingType.COLUMN_WISE: - name = "cw-sharded" - else: - raise Exception(f"Unknown sharding type {sharding_type}") - - if compile_mode == CompileMode.EAGER: - name += "-eager" - elif compile_mode == CompileMode.FX_SCRIPT: - name += "-fxjit" - - return name - - -def init_module_and_run_benchmark( - module: torch.nn.Module, - sharder: ModuleSharder[T], - device: torch.device, - sharding_type: ShardingType, - compile_mode: CompileMode, - world_size: int, - batch_size: int, - warmup_inputs: List[List[KeyedJaggedTensor]], - bench_inputs: List[List[KeyedJaggedTensor]], - prof_inputs: List[List[KeyedJaggedTensor]], - tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]], - output_dir: str, - num_benchmarks: int, - # pyre-ignore[2] - func_to_benchmark: Any, - benchmark_func_kwargs: Optional[Dict[str, Any]], - rank: int = -1, - queue: Optional[mp.Queue] = None, - pooling_configs: Optional[List[int]] = None, - benchmark_unsharded_module: bool = False, -) -> BenchmarkResult: - """ - There are a couple of caveats here as to why the module has to be initialized - here: - 1. Device. To accurately track memory usage, when sharding modules the initial - placement of the module should be on CPU. This is to avoid double counting - memory allocations and also to prevent CUDA OOMs. - 2. Garbage Collector. Since torch.fx.GraphModule has circular references, - garbage collection us funky and can lead to ooms. 
Since this frame is - called by the loop through compile modes and sharding types, returning the - benchmark result will mean that the reference to module is lost instead of - existing in the loop - """ - - if rank >= 0: - warmup_inputs_cuda = [ - warmup_input.to(torch.device(f"{device.type}:{rank}")) - for warmup_input in warmup_inputs[rank] - ] - bench_inputs_cuda = [ - bench_input.to(torch.device(f"{device.type}:{rank}")) - for bench_input in bench_inputs[rank] - ] - prof_inputs_cuda = [ - prof_input.to(torch.device(f"{device.type}:{rank}")) - for prof_input in prof_inputs[rank] - ] - else: - warmup_inputs_cuda = [ - warmup_input.to(torch.device(f"{device.type}:0")) - for warmup_input in warmup_inputs[0] - ] - bench_inputs_cuda = [ - bench_input.to(torch.device(f"{device.type}:0")) - for bench_input in bench_inputs[0] - ] - prof_inputs_cuda = [ - prof_input.to(torch.device(f"{device.type}:0")) - for prof_input in prof_inputs[0] - ] - - with ( - MultiProcessContext(rank, world_size, "nccl", None) - if rank != -1 - else contextlib.nullcontext() - ) as ctx: - module = transform_module( - module=module, - device=device, - inputs=warmup_inputs_cuda, - sharder=sharder, - sharding_type=sharding_type, - compile_mode=compile_mode, - world_size=world_size, - batch_size=batch_size, - # pyre-ignore[6] - ctx=ctx, - benchmark_unsharded_module=benchmark_unsharded_module, - ) - - if benchmark_unsharded_module: - name = "unsharded" + compile_mode.name - else: - name = benchmark_type_name(compile_mode, sharding_type) - - res = benchmark( - name, - module, - warmup_inputs_cuda, - bench_inputs_cuda, - prof_inputs_cuda, - world_size=world_size, - output_dir=output_dir, - num_benchmarks=num_benchmarks, - func_to_benchmark=func_to_benchmark, - benchmark_func_kwargs=benchmark_func_kwargs, - rank=rank, - device_type=device.type, - benchmark_unsharded_module=benchmark_unsharded_module, - ) - - if queue is not None: - queue.put(res) - - while not queue.empty(): - time.sleep(1) - - return res - - -def multi_process_benchmark( - callable: Callable[ - ..., - None, - ], - # pyre-ignore - **kwargs, -) -> BenchmarkResult: - - def setUp() -> None: - if "MASTER_ADDR" not in os.environ: - os.environ["MASTER_ADDR"] = str("localhost") - os.environ["MASTER_PORT"] = str(get_free_port()) - - assert "world_size" in kwargs - world_size = kwargs["world_size"] - - setUp() - benchmark_res_per_rank = [] - # kineto has a known problem with fork-server: it'll hang - # when dumping the trace. 
Workaround with spawn - ctx = mp.get_context("spawn") - qq = ctx.SimpleQueue() - processes = [] - - for rank in range(world_size): - kwargs["rank"] = rank - kwargs["world_size"] = world_size - kwargs["queue"] = qq - p = ctx.Process( - target=callable, - kwargs=kwargs, - ) - p.start() - processes.append(p) - - for _ in range(world_size): - res = qq.get() - - benchmark_res_per_rank.append(res) - assert len(res.gpu_mem_stats) == 1 - assert len(res.cpu_mem_stats) == 1 - - for p in processes: - p.join() - assert 0 == p.exitcode - - total_benchmark_res = BenchmarkResult( - short_name=benchmark_res_per_rank[0].short_name, - gpu_elapsed_time=benchmark_res_per_rank[0].gpu_elapsed_time, - cpu_elapsed_time=benchmark_res_per_rank[0].cpu_elapsed_time, - gpu_mem_stats=[GPUMemoryStats(rank, 0, 0, 0) for rank in range(world_size)], - cpu_mem_stats=[CPUMemoryStats(rank, 0) for rank in range(world_size)], - rank=0, - ) - - for res in benchmark_res_per_rank: - # Each rank's BenchmarkResult contains 1 GPU and 1 CPU memory measurement - total_benchmark_res.gpu_mem_stats[res.rank] = res.gpu_mem_stats[0] - total_benchmark_res.cpu_mem_stats[res.rank] = res.cpu_mem_stats[0] - - return total_benchmark_res - - -def benchmark_module( - module: torch.nn.Module, - sharder: ModuleSharder[T], - sharding_types: List[ShardingType], - compile_modes: List[CompileMode], - tables: Union[List[EmbeddingBagConfig], List[EmbeddingConfig]], - warmup_iters: int = 20, - bench_iters: int = 500, - prof_iters: int = 20, - batch_size: int = 2048, - world_size: int = 2, - num_benchmarks: int = 5, - output_dir: str = "", - benchmark_unsharded: bool = False, - func_to_benchmark: Callable[..., None] = default_func_to_benchmark, - benchmark_func_kwargs: Optional[Dict[str, Any]] = None, - pooling_configs: Optional[List[int]] = None, - variable_batch_embeddings: bool = False, - device_type: str = "cuda", -) -> List[BenchmarkResult]: - """ - Args: - eager_module: Eager mode module to be benchmarked - sharding_types: Sharding types to be benchmarked - compile_modes: Compilation modes to be benchmarked - warmup_iters: Number of iterations to run before profiling - bench_iters: Number of iterations to run during profiling - prof_iters: Number of iterations to run after profiling - batch_size: Batch size used in the model - world_size: World size used in the - num_benchmarks: How many times to run over benchmark inputs for statistics - output_dir: Directory to output profiler outputs (traces, stacks) - pooling_configs: The pooling factor for the tables. 
- (Optional; if not set, we'll use 10 as default) - func_to_benchmark: Custom function to benchmark, check out default_func_to_benchmark for default - benchmark_func_kwargs: Custom keyword arguments to pass to func_to_benchmark - - Returns: - A list of BenchmarkResults - """ - - # logging.info(f"###### Benchmarking Module: {eager_module} ######\n") - logging.info(f"Warmup iterations: {warmup_iters}") - logging.info(f"Benchmark iterations: {bench_iters}") - logging.info(f"Profile iterations: {prof_iters}") - logging.info(f"Batch Size: {batch_size}") - logging.info(f"World Size: {world_size}") - logging.info(f"Number of Benchmarks: {num_benchmarks}") - logging.info(f"Output Directory: {output_dir}") - - assert ( - num_benchmarks > 2 - ), "num_benchmarks needs to be greater than 2 for statistical analysis" - if isinstance(module, QuantEmbeddingBagCollection) or isinstance( - module, QuantEmbeddingCollection - ): - train = False - else: - train = True - - benchmark_results: List[BenchmarkResult] = [] - - if isinstance(tables[0], EmbeddingBagConfig): - wrapped_module = EBCWrapper(module) - else: - wrapped_module = ECWrapper(module) - - num_inputs_to_gen: int = warmup_iters + bench_iters + prof_iters - inputs = get_inputs( - tables, - batch_size, - world_size, - num_inputs_to_gen, - train, - pooling_configs, - variable_batch_embeddings, - ) - - warmup_inputs = [rank_inputs[:warmup_iters] for rank_inputs in inputs] - bench_inputs = [ - rank_inputs[warmup_iters : (warmup_iters + bench_iters)] - for rank_inputs in inputs - ] - prof_inputs = [rank_inputs[-prof_iters:] for rank_inputs in inputs] - - for sharding_type in sharding_types if not benchmark_unsharded else ["Unsharded"]: - for compile_mode in compile_modes: - if not benchmark_unsharded: - # Test sharders should have a singular sharding_type - sharder._sharding_type = sharding_type.value - # pyre-ignore [6] - benchmark_type = benchmark_type_name(compile_mode, sharding_type) - else: - benchmark_type = "unsharded" + compile_mode.name - - logging.info( - f"\n\n###### Running Benchmark Type: {benchmark_type} ######\n" - ) - - if train: - res = multi_process_benchmark( - # pyre-ignore[6] - callable=init_module_and_run_benchmark, - module=wrapped_module, - sharder=sharder, - device=torch.device(device_type), - sharding_type=sharding_type, - compile_mode=compile_mode, - world_size=world_size, - batch_size=batch_size, - warmup_inputs=warmup_inputs, - bench_inputs=bench_inputs, - prof_inputs=prof_inputs, - tables=tables, - num_benchmarks=num_benchmarks, - output_dir=output_dir, - func_to_benchmark=func_to_benchmark, - benchmark_func_kwargs=benchmark_func_kwargs, - pooling_configs=pooling_configs, - ) - else: - res = init_module_and_run_benchmark( - module=wrapped_module, - sharder=sharder, - device=torch.device(device_type), - # pyre-ignore - sharding_type=sharding_type, - compile_mode=compile_mode, - world_size=world_size, - batch_size=batch_size, - warmup_inputs=warmup_inputs, - bench_inputs=bench_inputs, - prof_inputs=prof_inputs, - tables=tables, - num_benchmarks=num_benchmarks, - output_dir=output_dir, - func_to_benchmark=func_to_benchmark, - benchmark_func_kwargs=benchmark_func_kwargs, - pooling_configs=pooling_configs, - benchmark_unsharded_module=benchmark_unsharded, - ) - - gc.collect() - - benchmark_results.append(res) - - return benchmark_results diff --git a/torchrec/distributed/benchmark/examples/module_benchmark_config.yaml b/torchrec/distributed/benchmark/examples/module_benchmark_config.yaml new file mode 100644 index 
000000000..9ec9feb45 --- /dev/null +++ b/torchrec/distributed/benchmark/examples/module_benchmark_config.yaml @@ -0,0 +1,27 @@ +# Example YAML configuration for module-level benchmarking +# Usage: python -m torchrec.distributed.benchmark.unified_benchmark_runner --yaml_config=module_benchmark_config.yaml + +UnifiedBenchmarkConfig: + benchmark_type: "module" + module_path: "torchrec.models.deepfm" + module_class: "SimpleDeepFMNNWrapper" + module_kwargs: + hidden_layer_size: 20 + deep_fm_dimension: 5 + +RunOptions: + world_size: 2 + num_batches: 10 + sharding_type: "table_wise" + compute_kernel: "fused" + input_type: "kjt" + planner_type: "embedding" + dense_optimizer: "SGD" + dense_lr: 0.1 + sparse_optimizer: "EXACT_ADAGRAD" + sparse_lr: 0.1 + +EmbeddingTablesConfig: + num_unweighted_features: 100 + num_weighted_features: 100 + embedding_feature_dim: 128 diff --git a/torchrec/distributed/benchmark/examples/pipeline_benchmark_config.yaml b/torchrec/distributed/benchmark/examples/pipeline_benchmark_config.yaml new file mode 100644 index 000000000..3a5e16a10 --- /dev/null +++ b/torchrec/distributed/benchmark/examples/pipeline_benchmark_config.yaml @@ -0,0 +1,38 @@ +# Example YAML configuration for pipeline-level benchmarking +# Usage: python -m torchrec.distributed.benchmark.unified_benchmark_runner --yaml_config=pipeline_benchmark_config.yaml + +UnifiedBenchmarkConfig: + benchmark_type: "pipeline" + +PipelineConfig: + pipeline: "sparse" + emb_lookup_stream: "data_dist" + apply_jit: false + +RunOptions: + world_size: 2 + num_batches: 10 + sharding_type: "table_wise" + compute_kernel: "fused" + input_type: "kjt" + planner_type: "embedding" + dense_optimizer: "SGD" + dense_lr: 0.1 + sparse_optimizer: "EXACT_ADAGRAD" + sparse_lr: 0.1 + +ModelSelectionConfig: + model_name: "test_sparse_nn" + batch_size: 8192 + num_float_features: 10 + feature_pooling_avg: 10 + use_offsets: false + long_kjt_indices: true + long_kjt_offsets: true + long_kjt_lengths: true + pin_memory: true + +EmbeddingTablesConfig: + num_unweighted_features: 100 + num_weighted_features: 100 + embedding_feature_dim: 128 diff --git a/torchrec/distributed/benchmark/unified_benchmark_runner.py b/torchrec/distributed/benchmark/unified_benchmark_runner.py new file mode 100644 index 000000000..13a4b94dd --- /dev/null +++ b/torchrec/distributed/benchmark/unified_benchmark_runner.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +""" +Unified benchmark runner that can handle both module and pipeline level benchmarking +from a single YAML configuration file. 
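+
+The benchmark type ("module" or "pipeline") and its sub-configurations are read
+from the YAML file passed via --yaml_config; see the example configs under
+torchrec/distributed/benchmark/examples/ for both modes.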
+""" + +from typing import Optional + +from torchrec.distributed.benchmark.benchmark_pipeline_utils import ( + BaseModelConfig, + create_model_config, +) +from torchrec.distributed.benchmark.benchmark_train_pipeline import ( + EmbeddingTablesConfig, + ModelSelectionConfig, + PipelineConfig, + run_module_benchmark, + run_pipeline, + RunOptions, + UnifiedBenchmarkConfig, +) +from torchrec.distributed.benchmark.benchmark_utils import cmd_conf + + +@cmd_conf +def main( + run_option: RunOptions, + table_config: EmbeddingTablesConfig, + unified_config: UnifiedBenchmarkConfig, + model_selection: ModelSelectionConfig, + pipeline_config: PipelineConfig, + model_config: Optional[BaseModelConfig] = None, +) -> None: + """ + Unified main function that routes to appropriate benchmark type. + + Args: + run_option: Configuration for running the benchmark + table_config: Configuration for embedding tables + unified_config: Unified configuration specifying benchmark type and module details + model_selection: Configuration for model selection (only needed for pipeline benchmarks) + pipeline_config: Configuration for pipeline (only needed for pipeline benchmarks) + model_config: Optional model configuration + """ + + print(f"Starting {unified_config.benchmark_type} benchmark...") + + if unified_config.benchmark_type == "module": + # Run module-level benchmark + if not unified_config.module_path or not unified_config.module_class: + raise ValueError( + "For module benchmarking, both module_path and module_class must be specified" + ) + + result = run_module_benchmark(unified_config, table_config, run_option) + print(f"Module benchmark completed: {result}") + + elif unified_config.benchmark_type == "pipeline": + # Run pipeline-level benchmark + if model_selection is None: + # Create default model selection config + model_selection = ModelSelectionConfig() + + if pipeline_config is None: + # Create default pipeline config + pipeline_config = PipelineConfig() + + if model_config is None: + model_config = create_model_config( + model_name=model_selection.model_name, + batch_size=model_selection.batch_size, + batch_sizes=model_selection.batch_sizes, + num_float_features=model_selection.num_float_features, + feature_pooling_avg=model_selection.feature_pooling_avg, + use_offsets=model_selection.use_offsets, + dev_str=model_selection.dev_str, + long_kjt_indices=model_selection.long_kjt_indices, + long_kjt_offsets=model_selection.long_kjt_offsets, + long_kjt_lengths=model_selection.long_kjt_lengths, + pin_memory=model_selection.pin_memory, + embedding_groups=model_selection.embedding_groups, + feature_processor_modules=model_selection.feature_processor_modules, + max_feature_lengths=model_selection.max_feature_lengths, + over_arch_clazz=model_selection.over_arch_clazz, + postproc_module=model_selection.postproc_module, + zch=model_selection.zch, + hidden_layer_size=model_selection.hidden_layer_size, + deep_fm_dimension=model_selection.deep_fm_dimension, + dense_arch_layer_sizes=model_selection.dense_arch_layer_sizes, + over_arch_layer_sizes=model_selection.over_arch_layer_sizes, + ) + + result = run_pipeline(run_option, table_config, pipeline_config, model_config) + print(f"Pipeline benchmark completed: {result}") + + else: + raise ValueError( + f"Unknown benchmark_type: {unified_config.benchmark_type}. Must be 'module' or 'pipeline'" + ) + + +if __name__ == "__main__": + main()