diff --git a/torchrec/distributed/tests/test_dynamic_sharding.py b/torchrec/distributed/tests/test_dynamic_sharding.py index 92a7db0ce..7c03d1bf9 100644 --- a/torchrec/distributed/tests/test_dynamic_sharding.py +++ b/torchrec/distributed/tests/test_dynamic_sharding.py @@ -8,14 +8,17 @@ # pyre-strict +import copy import random import unittest +from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple, Type import hypothesis.strategies as st import torch +from fbgemm_gpu.split_embedding_configs import EmbOptimType from hypothesis import assume, given, settings, Verbosity @@ -27,9 +30,25 @@ KeyedJaggedTensor, optim as trec_optim, ) + +from torchrec.distributed.benchmark.benchmark_pipeline_utils import ( + BaseModelConfig, + create_model_config, + generate_data, + generate_pipeline, + generate_planner, + generate_sharded_model_and_optimizer, + generate_tables, +) +from torchrec.distributed.benchmark.benchmark_train_pipeline import ( + PipelineConfig, + RunOptions, +) +from torchrec.distributed.comm import get_local_size from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection from torchrec.distributed.fbgemm_qcomm_codec import CommType, QCommsConfig +from torchrec.distributed.planner import Topology from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta from torchrec.distributed.sharding_plan import ( @@ -44,6 +63,7 @@ MultiProcessTestBase, ) from torchrec.distributed.test_utils.test_input import ModelInput +from torchrec.distributed.test_utils.test_model import TestOverArchLarge from torchrec.distributed.test_utils.test_model_parallel import ModelParallelTestShared from torchrec.distributed.test_utils.test_sharding import ( copy_state_dict, @@ -510,7 +530,7 @@ def test_dynamic_sharding_ebc_cw( @skip_if_asan_class class MultiRankDMPDynamicShardingTest(ModelParallelTestShared): @unittest.skipIf( - torch.cuda.device_count() <= 1, + torch.cuda.device_count() <= 8, "Not enough GPUs, this test requires at least two GPUs", ) @given( # Pyre-ignore @@ -616,7 +636,6 @@ def test_sharding( data_type=data_type, sharding_type=sharding_type_e, random_seed=random_seed, - world_size=world_size, ) @@ -695,3 +714,425 @@ def test_output_sharding_plan_delta(self) -> None: ) # NOTE there are other attributes to test for equivalence in ParameterSharding type # but the ones included here are the most important. 
+ + +def _test_pipeline_resharding( + rank: int, + world_size: int, + backend: str, + tables: List[EmbeddingBagConfig], + initial_state_dict: Dict[str, Any], + weighted_tables: List[EmbeddingBagConfig], + model_config: BaseModelConfig, + pipeline_config: PipelineConfig, + kjt_input_per_rank: List[KeyedJaggedTensor], + module_sharding_plan: EmbeddingModuleShardingPlan, + new_module_sharding_plan: EmbeddingModuleShardingPlan, + local_size: Optional[int] = None, +) -> None: + trec_dist.comm_ops.set_gradient_division(False) + torch.autograd.set_detect_anomaly(True) + run_option = RunOptions() + with MultiProcessContext( + rank, world_size, backend, local_size, use_deterministic_algorithms=False + ) as ctx: + unsharded_model = model_config.generate_model( + tables=tables, + weighted_tables=weighted_tables, + dense_device=ctx.device, + ) + + # Create a topology for sharding + topology = Topology( + local_world_size=get_local_size(world_size), + world_size=world_size, + compute_device=ctx.device.type, + ) + + batch_sizes = model_config.batch_sizes + + if batch_sizes is None: + batch_sizes = [model_config.batch_size] * run_option.num_batches + else: + assert ( + len(batch_sizes) == run_option.num_batches + ), "The length of batch_sizes must match the number of batches." + + # Create a planner for sharding based on the specified type + planner = generate_planner( + planner_type=run_option.planner_type, + topology=topology, + tables=tables, + weighted_tables=weighted_tables, + sharding_type=run_option.sharding_type, + compute_kernel=run_option.compute_kernel, + batch_sizes=batch_sizes, + pooling_factors=run_option.pooling_factors, + num_poolings=run_option.num_poolings, + ) + bench_inputs = generate_data( + tables=tables, + weighted_tables=weighted_tables, + model_config=model_config, + batch_sizes=batch_sizes, + ) + + # Prepare fused_params for sparse optimizer + fused_params = { + "optimizer": getattr(EmbOptimType, run_option.sparse_optimizer.upper()), + "learning_rate": run_option.sparse_lr, + } + + # Add momentum and weight_decay to fused_params if provided + if run_option.sparse_momentum is not None: + fused_params["momentum"] = run_option.sparse_momentum + + if run_option.sparse_weight_decay is not None: + fused_params["weight_decay"] = run_option.sparse_weight_decay + + sharded_model, optimizer = generate_sharded_model_and_optimizer( + model=unsharded_model, + sharding_type=run_option.sharding_type.value, + kernel_type=run_option.compute_kernel.value, + # pyre-ignore + pg=ctx.pg, + device=ctx.device, + fused_params=fused_params, + dense_optimizer=run_option.dense_optimizer, + dense_lr=run_option.dense_lr, + dense_momentum=run_option.dense_momentum, + dense_weight_decay=run_option.dense_weight_decay, + planner=planner, + ) + + pipeline = generate_pipeline( + pipeline_type=pipeline_config.pipeline, + emb_lookup_stream=pipeline_config.emb_lookup_stream, + model=sharded_model, + opt=optimizer, + device=ctx.device, + apply_jit=pipeline_config.apply_jit, + ) + pipeline.progress(iter(bench_inputs)) + + dataloader = iter(bench_inputs) + i = 0 + while True: + try: + if i == 3: + # Extract existing sharding plan + existing_sharding_plan = pipeline._model.module.sparse.ebc.module_sharding_plan # pyre-ignore + fqn_to_local_shards = "sparse.ebc" + # Modify existing sharding plan - Hard code + sharding_param = copy.deepcopy(existing_sharding_plan["table_0"]) + new_device = 1 if sharding_param.ranks[0] == 0 else 0 + sharding_param.ranks = [new_device] + sharding_param.sharding_spec.shards[0].placement = ( + 
torch.distributed._remote_device(
+                                f"rank:{new_device}/cuda:{new_device}"
+                            )
+                        )
+
+                        new_sharding_plan = {}
+                        new_sharding_plan["table_0"] = sharding_param
+                        # Reshard
+                        pipeline.progress_with_reshard(  # pyre-ignore
+                            dataloader_iter=dataloader,
+                            reshard_params=new_sharding_plan,
+                            sharded_module_fqn=fqn_to_local_shards,
+                        )
+                        i += 1
+                    else:
+                        pipeline.progress(dataloader)
+                        i += 1
+            except StopIteration:
+                break
+
+
+@dataclass
+class RunOptions:
+    """
+    Configuration options for running sparse neural network benchmarks.
+
+    This class defines the parameters that control how the benchmark is executed,
+    including distributed training settings, batch configuration, and profiling options.
+
+    Args:
+        world_size (int): Number of processes/GPUs to use for distributed training.
+            Default is 2.
+        num_batches (int): Number of batches to process during the benchmark.
+            Default is 10.
+        sharding_type (ShardingType): Strategy for sharding embedding tables across devices.
+            Default is ShardingType.TABLE_WISE (entire tables are placed on single devices).
+        compute_kernel (EmbeddingComputeKernel): Compute kernel to use for embedding tables.
+            Default here is EmbeddingComputeKernel.DENSE.
+        input_type (str): Type of input format to use for the model.
+            Default is "kjt" (KeyedJaggedTensor).
+        profile (str): Directory to save profiling results. If empty, profiling is disabled.
+            Default is "" (disabled).
+        planner_type (str): Type of sharding planner to use. Options are:
+            - "embedding": EmbeddingShardingPlanner (default)
+            - "hetero": HeteroEmbeddingShardingPlanner
+        pooling_factors (Optional[List[float]]): Pooling factors for each feature of the table.
+            This is the average number of values each sample has for the feature.
+        num_poolings (Optional[List[float]]): Number of poolings for each feature of the table.
+        dense_optimizer (str): Optimizer to use for dense parameters.
+            Default is "SGD".
+        dense_lr (float): Learning rate for dense parameters.
+            Default is 0.1.
+        dense_momentum (Optional[float]): Momentum for the dense optimizer, if set.
+        dense_weight_decay (Optional[float]): Weight decay for the dense optimizer, if set.
+        sparse_optimizer (str): Optimizer to use for sparse parameters.
+            Default is "EXACT_ADAGRAD".
+        sparse_lr (float): Learning rate for sparse parameters.
+            Default is 0.1.
+        sparse_momentum (Optional[float]): Momentum for the sparse optimizer, if set.
+        sparse_weight_decay (Optional[float]): Weight decay for the sparse optimizer, if set.
+        export_stacks (bool): Whether to export profiler stacks when profiling is enabled.
+            Default is False.
+    """
+
+    world_size: int = 2
+    num_batches: int = 10
+    sharding_type: ShardingType = ShardingType.TABLE_WISE
+    compute_kernel: EmbeddingComputeKernel = EmbeddingComputeKernel.DENSE
+    input_type: str = "kjt"
+    profile: str = ""
+    planner_type: str = "embedding"
+    pooling_factors: Optional[List[float]] = None
+    num_poolings: Optional[List[float]] = None
+    dense_optimizer: str = "SGD"
+    dense_lr: float = 0.1
+    dense_momentum: Optional[float] = None
+    dense_weight_decay: Optional[float] = None
+    sparse_optimizer: str = "EXACT_ADAGRAD"
+    sparse_lr: float = 0.1
+    sparse_momentum: Optional[float] = None
+    sparse_weight_decay: Optional[float] = None
+    export_stacks: bool = False
+
+
+@dataclass
+class EmbeddingTablesConfig:
+    """
+    Configuration for embedding tables.
+
+    This class defines the parameters for generating embedding tables with both weighted
+    and unweighted features.
+
+    Args:
+        num_unweighted_features (int): Number of unweighted features to generate.
+            Default is 100.
+        num_weighted_features (int): Number of weighted features to generate.
+            Default is 100.
+        embedding_feature_dim (int): Dimension of the embedding vectors.
+            Default is 128.
+    """
+
+    num_unweighted_features: int = 100
+    num_weighted_features: int = 100
+    embedding_feature_dim: int = 128
+
+
+@dataclass
+class PipelineConfig:
+    """
+    Configuration for training pipelines.
+
+    This class defines the parameters for configuring the training pipeline.
+
+    Args:
+        pipeline (str): The type of training pipeline to use. Options include:
+            - "base": Basic training pipeline
+            - "sparse": Pipeline optimized for sparse operations
+            - "fused": Pipeline with fused sparse distribution
+            - "semi": Semi-synchronous training pipeline
+            - "prefetch": Pipeline with prefetching for sparse distribution
+            Default here is "sparse".
+        emb_lookup_stream (str): The stream to use for embedding lookups.
+            Only used by certain pipeline types (e.g., "fused").
+            Default is "data_dist".
+        apply_jit (bool): Whether to apply JIT (Just-In-Time) compilation to the model.
+            Default is False.
+    """
+
+    pipeline: str = "sparse"
+    emb_lookup_stream: str = "data_dist"
+    apply_jit: bool = False
+
+
+class MultiRankPipelineDynamicShardingTest(MultiProcessTestBase):
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 4,
+        "Not enough GPUs, this test requires more than four GPUs",
+    )
+    @given(  # pyre-ignore
+        # Model selection config parameters
+        model_name=st.sampled_from(
+            [
+                "test_sparse_nn",
+                # "test_tower_sparse_nn",
+                # "test_tower_collection_sparse_nn",
+                # "deepfm",
+                # "dlrm",
+            ]
+        ),
+        batch_size=st.integers(10, 9000),
+        batch_sizes=st.lists(st.integers(10, 100), min_size=10, max_size=10),
+        num_float_features=st.integers(5, 20),
+        feature_pooling_avg=st.integers(5, 20),
+        use_offsets=st.booleans(),
+        dev_str=st.just(""),
+        long_kjt_indices=st.booleans(),
+        long_kjt_offsets=st.booleans(),
+        long_kjt_lengths=st.booleans(),
+        pin_memory=st.booleans(),
+        zch=st.booleans(),
+        hidden_layer_size=st.integers(10, 50),
+        deep_fm_dimension=st.integers(3, 10),
+        dense_arch_layer_sizes=st.lists(st.integers(10, 200), min_size=2, max_size=2),
+        over_arch_layer_sizes=st.lists(st.integers(1, 10), min_size=2, max_size=2),
+        # PipelineConfig parameters
+        pipeline=st.sampled_from(
+            ["sparse"]
+        ),  # other options: "base", "fused", "semi", "prefetch"
+        emb_lookup_stream=st.sampled_from(["data_dist", "compute"]),
+        apply_jit=st.booleans(),
+        # EmbeddingTablesConfig parameters
+        num_unweighted_features=st.integers(50, 200),
+        num_weighted_features=st.integers(50, 200),
+        embedding_feature_dim=st.integers(64, 256),
+        # RunOptions parameters
+        num_batches=st.integers(5, 20),
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE,
+                ShardingType.COLUMN_WISE,
+                ShardingType.ROW_WISE,
+                ShardingType.DATA_PARALLEL,
+            ]
+        ),
+        compute_kernel=st.sampled_from(
+            [
+                EmbeddingComputeKernel.DENSE,
+                # TODO: Move to CPU
+                # EmbeddingComputeKernel.FUSED,
+                # EmbeddingComputeKernel.FUSED_UVM,
+                # EmbeddingComputeKernel.FUSED_UVM_CACHING,
+            ]
+        ),
+        input_type=st.just("kjt"),
+        profile=st.just(""),
+        planner_type=st.sampled_from(["embedding", "hetero"]),
+        dense_optimizer=st.sampled_from(["SGD", "Adam", "Adagrad"]),
+        dense_lr=st.floats(min_value=0.01, max_value=1.0),
+        sparse_optimizer=st.sampled_from(
+            ["EXACT_ADAGRAD", "EXACT_ROWWISE_ADAGRAD", "EXACT_SGD"]
+        ),
+        sparse_lr=st.floats(min_value=0.01, max_value=1.0),
+        export_stacks=st.booleans(),
+        world_size=st.sampled_from([2, 4]),  # 8
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
+    def test_pipeline_resharding(
+        self,
+        # ModelSelectionConfig parameters
+        model_name: str,
+        batch_size: int,
+        batch_sizes: List[int],
+        num_float_features: int,
+        feature_pooling_avg: int,
+        use_offsets: bool,
+        dev_str: str,
+        long_kjt_indices: bool,
+        long_kjt_offsets: bool,
+        long_kjt_lengths: bool,
+        pin_memory: bool,
+        zch: bool,
hidden_layer_size: int, + deep_fm_dimension: int, + dense_arch_layer_sizes: List[int], + over_arch_layer_sizes: List[int], + # PipelineConfig parameters + pipeline: str, + emb_lookup_stream: str, + apply_jit: bool, + # EmbeddingTablesConfig parameters + num_unweighted_features: int, + num_weighted_features: int, + embedding_feature_dim: int, + # RunOptions parameters + num_batches: int, + sharding_type: ShardingType, + compute_kernel: EmbeddingComputeKernel, + input_type: str, + profile: str, + planner_type: str, + dense_optimizer: str, + dense_lr: float, + sparse_optimizer: str, + sparse_lr: float, + export_stacks: bool, + world_size: int, + ) -> None: + # Create run options + run_options = RunOptions( + world_size=world_size, + num_batches=num_batches, + sharding_type=sharding_type, + compute_kernel=compute_kernel, + input_type=input_type, + profile=profile, + planner_type=planner_type, + dense_optimizer=dense_optimizer, + dense_lr=dense_lr, + sparse_optimizer=sparse_optimizer, + sparse_lr=sparse_lr, + export_stacks=export_stacks, + ) + + # Generate tables using embedding tables config parameters + tables, weighted_tables = generate_tables( + num_unweighted_features=num_unweighted_features, + num_weighted_features=num_weighted_features, + embedding_feature_dim=embedding_feature_dim, + ) + + # Create model config + model_config = create_model_config( + model_name=model_name, + batch_size=batch_size, + batch_sizes=batch_sizes, + num_float_features=num_float_features, + feature_pooling_avg=feature_pooling_avg, + use_offsets=use_offsets, + dev_str=dev_str, + long_kjt_indices=long_kjt_indices, + long_kjt_offsets=long_kjt_offsets, + long_kjt_lengths=long_kjt_lengths, + pin_memory=pin_memory, + embedding_groups=None, + feature_processor_modules=None, + max_feature_lengths=None, + over_arch_clazz=TestOverArchLarge, + postproc_module=None, + zch=zch, + hidden_layer_size=hidden_layer_size, + deep_fm_dimension=deep_fm_dimension, + dense_arch_layer_sizes=dense_arch_layer_sizes, + over_arch_layer_sizes=over_arch_layer_sizes, + ) + + # Create pipeline config + pipeline_config = PipelineConfig( + pipeline=pipeline, + emb_lookup_stream=emb_lookup_stream, + apply_jit=apply_jit, + ) + self._run_multi_process_test( + callable=_test_pipeline_resharding, + world_size=world_size, + tables=tables, + weighted_tables=weighted_tables, + pipeline_config=pipeline_config, + model_config=model_config, + initial_state_dict=None, + kjt_input_per_rank=None, + backend="nccl", + module_sharding_plan=None, + new_module_sharding_plan=None, + ) diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index c26f16983..301f728f6 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -66,7 +66,7 @@ DataLoadingThread, use_context_for_postprocs, ) -from torchrec.distributed.types import Awaitable +from torchrec.distributed.types import Awaitable, ParameterSharding from torchrec.pt2.checks import is_torchdynamo_compiling from torchrec.pt2.utils import default_pipeline_input_transformer from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @@ -657,7 +657,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out: # the input_dist of batches[0] has be invoked in previous iter. TODO: fact check self._wait_for_batch() - if len(self.batches) >= 2: + if len(self.batches) >= 2: # at 4 - will be only 1 self.batches left? 
             # invoke splits all_to_all comms (first part of input_dist)
             self.start_sparse_data_dist(self.batches[1], self.contexts[1])

@@ -696,6 +696,81 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         self.dequeue_batch()
         return output

+    def progress_with_reshard(
+        self,
+        dataloader_iter: Iterator[In],
+        reshard_params: Dict[str, ParameterSharding],
+        sharded_module_fqn: Optional[str] = None,
+    ) -> Out:
+        """
+        Runs one training iteration like ``progress`` and then reshards the given tables.
+        Because resharding changes tensor placements, the usual pipeline overlap is
+        temporarily undone for this iteration.
+        """
+        # Unlike ``progress``, this assumes the pipeline batches are already filled:
+        # # attach the model just in case the user forgets to call it, especially when the user
+        # # pauses pipeline.progress and detaches the model for other purposes.
+        # if not self._model_attached:
+        #     self.attach(self._model)
+
+        # # filling the pipeline is only needed at the beginning, when the pipeline (batches) is empty
+        # self.fill_pipeline(dataloader_iter)
+
+        # expected stop once all batches have been exhausted
+        if not self.batches:
+            raise StopIteration
+
+        # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
+        self._set_module_context(self.contexts[0])
+
+        if self._model.training:
+            with record_function("## zero_grad ##"):
+                self._optimizer.zero_grad()
+
+        # wait for batches[0] being available on device, this should always be completed since
+        # the input_dist of batches[0] has been invoked in the previous iter. TODO: fact check
+        self._wait_for_batch()
+
+        # The reshard happens after this forward/backward/optimizer step.
+        # forward
+        with record_function("## forward ##"):
+            losses, output = self._model_fwd(self.batches[0])
+
+        if self._model.training:
+            # backward
+            self._backward(losses)
+
+            self.sync_embeddings(
+                self._model,
+                self._dmp_collection_sync_interval_batches,
+                self.contexts[0],
+            )
+
+            # update
+            with record_function("## optimizer ##"):
+                self._optimizer.step()
+
+        # Reshard
+        self._model.reshard(  # pyre-ignore
+            sharded_module_fqn=sharded_module_fqn,
+            changed_shard_to_params=reshard_params,
+        )
+
+        # The reshard must complete before the next input_dist is started below.
+        if len(self.batches) >= 2:
+            # invoke splits all_to_all comms (first part of input_dist)
+            self.start_sparse_data_dist(self.batches[1], self.contexts[1])
+            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
+            self.wait_sparse_data_dist(self.contexts[1])
+        # Assume _enqueue_batch_after_forward is True (current implementation); it is not
+        # relevant here because this iteration does no pipelining, so enqueue the next batch directly.
+        self.enqueue_batch(
+            dataloader_iter
+        )  # TODO: batch i+1 was enqueued during iteration i-1, i.e. before the resharded plan existed
+        self.dequeue_batch()
+        return output
+
     def _create_context(self) -> TrainPipelineContext:
         context = self._context_type(index=self._next_index, version=1)
         self._next_index += 1
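
Below is a minimal usage sketch (not part of the patch) of how the new `progress_with_reshard` entry point might be driven from a training loop; `pipeline`, `existing_plan`, and `batches` are hypothetical placeholders for the objects that `_test_pipeline_resharding` builds above.

import copy

import torch

# Hypothetical handles: a TrainPipelineSparseDist-style pipeline that exposes
# progress_with_reshard, the current sharding plan of its "sparse.ebc" module,
# and an iterator over pre-generated batches.
new_plan = {"table_0": copy.deepcopy(existing_plan["table_0"])}
target_rank = 1 if new_plan["table_0"].ranks[0] == 0 else 0
new_plan["table_0"].ranks = [target_rank]
new_plan["table_0"].sharding_spec.shards[0].placement = torch.distributed._remote_device(
    f"rank:{target_rank}/cuda:{target_rank}"
)

# One training step that also moves table_0 to the new rank before the next input_dist.
output = pipeline.progress_with_reshard(
    dataloader_iter=batches,
    reshard_params=new_plan,
    sharded_module_fqn="sparse.ebc",
)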