diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index 9dbfd142b..9ed10d07f 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -338,6 +338,16 @@ def _get_sharded_local_buckets_for_zero_collision(
     return sharded_local_buckets
 
 
+@dataclass
+class ShardParams:
+    optimizer_states: List[Optional[Tuple[torch.Tensor]]]
+    optimizer_states_keys: List[torch.Tensor]
+    local_metadata: List[ShardMetadata]
+    global_metadata: ShardedTensorMetadata
+    embedding_weights: List[torch.Tensor]
+    dtensor_metadata: List[DTensorMetadata]
+
+
 class KeyValueEmbeddingFusedOptimizer(FusedOptimizer):
     def __init__(
         self,
@@ -352,8 +362,7 @@ def __init__(
         self._emb_module: SSDTableBatchedEmbeddingBags = emb_module
         self._pg = pg
 
-        # TODO: support optimizer states checkpointing once FBGEMM support
-        # split_optimizer_states API
+        # Initializing all required variables
 
         # pyre-ignore [33]
         state: Dict[Any, Any] = {}
@@ -361,11 +370,399 @@ def __init__(
             "params": [],
             "lr": emb_module.get_learning_rate(),
         }
-        params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {}
+        all_optimizer_states = emb_module.get_optimizer_state(None)
+        table_to_shard_params: Dict[str, ShardParams] = {}
+
+        # Changing weights placement to CPU
+        embedding_weights_by_table, _, _, _ = emb_module.split_embedding_weights()
+        for emb_table in config.embedding_tables:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+
+        # [Step 1] Create ShardParams for every embedding table
+        for (
+            table_config,
+            optimizer_states,
+            weight,
+        ) in itertools.zip_longest(
+            config.embedding_tables,
+            all_optimizer_states,
+            embedding_weights_by_table,
+        ):
+            # Creating a placeholder ShardParams for every embedding table
+            if table_config.name not in table_to_shard_params:
+                table_to_shard_params[table_config.name] = ShardParams(
+                    optimizer_states=[],
+                    local_metadata=[],
+                    embedding_weights=[],
+                    dtensor_metadata=[],
+                    global_metadata=ShardedTensorMetadata(),
+                    optimizer_states_keys=[],
+                )
+
+            optimizer_state_values = None
+            if optimizer_states:
+                optimizer_state_values = tuple(optimizer_states.values())
+                for optimizer_state_value in optimizer_state_values:
+                    assert (
+                        table_config.local_rows == optimizer_state_value.size(0)
+                        or optimizer_state_value.nelement() == 1  # single value state
+                    )
+                # Saving the optimizer keys for every table
+                table_to_shard_params[table_config.name].optimizer_states_keys.append(
+                    optimizer_states.keys()
+                )
+
+            # Adding data to the shard params for every table
+            table_to_shard_params[table_config.name].optimizer_states.append(
+                optimizer_state_values
+            )
+            table_to_shard_params[table_config.name].local_metadata.append(
+                table_config.local_metadata
+            )
+            table_to_shard_params[table_config.name].dtensor_metadata.append(
+                table_config.dtensor_metadata
+            )
+            table_to_shard_params[table_config.name].embedding_weights.append(weight)
+            table_to_shard_params[table_config.name].global_metadata = (
+                table_config.global_metadata
+            )
+
+        # Loop through every table
+        seen_tables = set()
+        for table_config in config.embedding_tables:
+            if table_config.name in seen_tables:
+                continue
+            seen_tables.add(table_config.name)
+            shard_params: ShardParams = table_to_shard_params[table_config.name]
+
+            local_weight_shards = []
+            for local_weight, local_metadata in zip(
+                shard_params.embedding_weights, shard_params.local_metadata
+            ):
+                # Creating a shard for every tensor -> this will have the tensor and its metadata
+                local_weight_shards.append(Shard(local_weight, local_metadata))
+                shard_params.global_metadata.tensor_properties.dtype = (
+                    local_weight.dtype
+                )
+                shard_params.global_metadata.tensor_properties.requires_grad = (
+                    local_weight.requires_grad
+                )
+            # Creating a ShardedTensor using all the shards created above
+            weight = ShardedTensor._init_from_local_shards_and_global_metadata(
+                local_shards=local_weight_shards,
+                sharded_tensor_metadata=shard_params.global_metadata,
+                process_group=self._pg,
+            )
+            param_key = table_config.name + ".weight"
+
+            # Saving the sharded tensor
+            state[weight] = {}
+            param_group["params"].append(weight)
+            params[param_key] = weight
+
+            # Update sharding dimension and grid sharding for every embedding table
+            self.sharding_dim: int = (
+                1 if table_config.local_cols != table_config.embedding_dim else 0
+            )
+
+            self.is_grid_sharded: bool = (
+                True
+                if table_config.local_cols != table_config.embedding_dim
+                and table_config.local_rows != table_config.num_embeddings
+                else False
+            )
+
+            # Going through optimizers
+            if all(
+                opt_state is not None for opt_state in shard_params.optimizer_states
+            ):
+                # Number of optimizer states for the table
+                num_states: int = min(
+                    # pyre-ignore
+                    [len(opt_state) for opt_state in shard_params.optimizer_states]
+                )
+                optimizer_state_keys = []
+                if num_states > 0:
+                    optimizer_state_keys = table_to_shard_params[
+                        table_config.name
+                    ].optimizer_states_keys
+
+                for cur_state_idx in range(0, num_states):
+                    if cur_state_idx == 0:
+                        # for backward compatibility
+                        # If only one optimizer state is present, we assume it is the momentum1 state
+                        cur_state_key = "momentum1"
+                    else:
+                        cur_state_key = optimizer_state_keys[cur_state_idx]
+                    # Creating the ShardedTensor for the optimizer weights for the table
+                    state[weight][f"{table_config.name}.{cur_state_key}"] = (
+                        self.get_sharded_optim_state(
+                            cur_state_idx + 1, cur_state_key, shard_params, table_config
+                        )
+                    )
+
+        logger.info("Completed initializing KeyValueEmbeddingFusedOptimizer")
 
         super().__init__(params, state, [param_group])
 
+    def get_sharded_optim_state(
+        self,
+        momentum_idx: int,
+        state_key: str,
+        shard_params: ShardParams,
+        table_config: ShardedEmbeddingTable,
+    ) -> Union[ShardedTensor, DTensor]:
+
+        momentum_local_shards: List[Shard] = []
+        optimizer_sharded_tensor_metadata: ShardedTensorMetadata
+
+        # Momentum idx is minimum 1
+        # pyre-ignore [16]
+        optim_state = shard_params.optimizer_states[0][momentum_idx - 1]
+        if (
+            optim_state.nelement() == 1 and state_key != "momentum1"
+        ):  # special handling for backward compatibility: momentum1 is the rowwise state for rowwise_adagrad
+            # single value state: one value per table
+            (
+                table_shard_metadata_to_optimizer_shard_metadata,
+                optimizer_sharded_tensor_metadata,
+            ) = self.get_optimizer_single_value_shard_metadata_and_global_metadata(
+                shard_params.global_metadata,
+                optim_state,
+            )
+        elif optim_state.dim() == 1:
+            # rowwise state: param.shape[0] == state.shape[0], state.shape[1] == 1
+            (
+                table_shard_metadata_to_optimizer_shard_metadata,
+                optimizer_sharded_tensor_metadata,
+            ) = self.get_optimizer_rowwise_shard_metadata_and_global_metadata(
+                shard_params.global_metadata,
+                optim_state,
+                self.sharding_dim,
+                self.is_grid_sharded,
+            )
+        else:
+            # pointwise state: param.shape == state.shape
+            (
+                table_shard_metadata_to_optimizer_shard_metadata,
+                optimizer_sharded_tensor_metadata,
+            ) = self.get_optimizer_pointwise_shard_metadata_and_global_metadata(
+                shard_params.global_metadata,
+                optim_state,
+            )
+
+        for optimizer_state, table_shard_local_metadata in zip(
+            shard_params.optimizer_states, shard_params.local_metadata
+        ):
+            local_optimizer_shard_metadata = (
+                table_shard_metadata_to_optimizer_shard_metadata[
+                    table_shard_local_metadata
+                ]
+            )
+            momentum_local_shards.append(
+                Shard(
+                    optimizer_state[momentum_idx - 1],
+                    local_optimizer_shard_metadata,
+                )
+            )
+
+        # Convert optimizer state to DTensor if enabled
+        if (
+            table_config.dtensor_metadata is not None
+            and table_config.dtensor_metadata.mesh
+        ):
+            dtensor_metadata = table_config.dtensor_metadata
+            # if rowwise state we do Shard(0), regardless of how the table is sharded
+            if optim_state.dim() == 1:
+                stride = (1,)
+                placements = (
+                    (Replicate(), DTensorShard(0))
+                    if dtensor_metadata.mesh is not None
+                    and dtensor_metadata.mesh.ndim == 2
+                    else (DTensorShard(0),)
+                )
+            else:
+                stride = dtensor_metadata.stride
+                placements = dtensor_metadata.placements
+
+            return DTensor.from_local(
+                local_tensor=LocalShardsWrapper(
+                    local_shards=[x.tensor for x in momentum_local_shards],
+                    local_offsets=[  # pyre-ignore[6]
+                        x.metadata.shard_offsets for x in momentum_local_shards
+                    ],
+                ),
+                device_mesh=dtensor_metadata.mesh,
+                placements=placements,
+                shape=optimizer_sharded_tensor_metadata.size,
+                stride=stride,
+                run_check=False,
+            )
+        else:
+            # TODO: we should be creating this in SPMD fashion (e.g. init_from_local_shards)
+            # and let it derive global metadata.
+            return ShardedTensor._init_from_local_shards_and_global_metadata(
+                local_shards=momentum_local_shards,
+                sharded_tensor_metadata=optimizer_sharded_tensor_metadata,
+                process_group=self._pg,
+            )
+
+    def get_optimizer_single_value_shard_metadata_and_global_metadata(
+        self,
+        table_global_metadata: ShardedTensorMetadata,
+        optimizer_state: torch.Tensor,
+    ) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]:
+        table_global_shards_metadata: List[ShardMetadata] = (
+            table_global_metadata.shards_metadata
+        )
+
+        table_shard_metadata_to_optimizer_shard_metadata = {}
+        for offset, table_shard_metadata in enumerate(table_global_shards_metadata):
+            # pyre-ignore [16]
+            table_shard_metadata.placement._device = optimizer_state.device
+            # Creating ShardMetadata
+            table_shard_metadata_to_optimizer_shard_metadata[table_shard_metadata] = (
+                ShardMetadata(
+                    shard_sizes=[1],  # single value optimizer state
+                    shard_offsets=[offset],  # offset increases by 1 for each shard
+                    placement=table_shard_metadata.placement,
+                )
+            )
+
+        tensor_properties = TensorProperties(
+            dtype=optimizer_state.dtype,
+            layout=optimizer_state.layout,
+            requires_grad=False,
+        )
+        # Creating ShardedTensorMetadata
+        single_value_optimizer_st_metadata = ShardedTensorMetadata(
+            shards_metadata=list(
+                table_shard_metadata_to_optimizer_shard_metadata.values()
+            ),
+            size=torch.Size([len(table_global_shards_metadata)]),
+            tensor_properties=tensor_properties,
+        )
+
+        return (
+            table_shard_metadata_to_optimizer_shard_metadata,
+            single_value_optimizer_st_metadata,
+        )
+
+    def get_optimizer_rowwise_shard_metadata_and_global_metadata(
+        self,
+        table_global_metadata: ShardedTensorMetadata,
+        optimizer_state: torch.Tensor,
+        sharding_dim: int,
+        is_grid_sharded: bool = False,
+    ) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]:
+        table_global_shards_metadata: List[ShardMetadata] = (
+            table_global_metadata.shards_metadata
+        )
+
+        if sharding_dim == 1:
+            # column-wise sharding
+            # sort the metadata based on column offset and
+            # construct the momentum tensor in a row-wise sharded way
+            table_global_shards_metadata = sorted(
+                table_global_shards_metadata,
+                key=lambda shard: shard.shard_offsets[1],
+            )
+
+        table_shard_metadata_to_optimizer_shard_metadata = {}
+        rolling_offset = 0
+        for idx, table_shard_metadata in enumerate(table_global_shards_metadata):
+            offset = table_shard_metadata.shard_offsets[0]
+
+            if is_grid_sharded:
+                # use a rolling offset to calculate the current offset for the shard,
+                # to account for the uneven row-wise case across shards
+                offset = rolling_offset
+                rolling_offset += table_shard_metadata.shard_sizes[0]
+            elif sharding_dim == 1:
+                # for column-wise sharding, we still create row-wise sharded metadata
+                # for the optimizer, so manually create a row-wise offset
+                offset = idx * table_shard_metadata.shard_sizes[0]
+
+            # pyre-ignore [16]
+            table_shard_metadata.placement._device = optimizer_state.device
+            table_shard_metadata_to_optimizer_shard_metadata[table_shard_metadata] = (
+                ShardMetadata(
+                    shard_sizes=[table_shard_metadata.shard_sizes[0]],
+                    shard_offsets=[offset],
+                    placement=table_shard_metadata.placement,
+                )
+            )
+
+        tensor_properties = TensorProperties(
+            dtype=optimizer_state.dtype,
+            layout=optimizer_state.layout,
+            requires_grad=False,
+        )
+        len_rw_shards = (
+            len(table_shard_metadata_to_optimizer_shard_metadata)
+            if sharding_dim == 1 and not is_grid_sharded
+            else 1
+        )
+        # for grid sharding, the row dimension is replicated CW shard times
+        grid_shard_nodes = (
+            len(table_global_shards_metadata) // get_node_group_size()
+            if is_grid_sharded
+            else 1
+        )
+        rowwise_optimizer_st_metadata = ShardedTensorMetadata(
+            shards_metadata=list(
+                table_shard_metadata_to_optimizer_shard_metadata.values()
+            ),
+            size=torch.Size(
+                [table_global_metadata.size[0] * len_rw_shards * grid_shard_nodes]
+            ),
+            tensor_properties=tensor_properties,
+        )
+
+        return (
+            table_shard_metadata_to_optimizer_shard_metadata,
+            rowwise_optimizer_st_metadata,
+        )
+
+    def get_optimizer_pointwise_shard_metadata_and_global_metadata(
+        self,
+        table_global_metadata: ShardedTensorMetadata,
+        optimizer_state: torch.Tensor,
+    ) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]:
+        table_global_shards_metadata: List[ShardMetadata] = (
+            table_global_metadata.shards_metadata
+        )
+
+        table_shard_metadata_to_optimizer_shard_metadata = {}
+
+        for table_shard_metadata in table_global_shards_metadata:
+            # pyre-ignore [16]
+            table_shard_metadata.placement._device = optimizer_state.device
+            table_shard_metadata_to_optimizer_shard_metadata[table_shard_metadata] = (
+                ShardMetadata(
+                    shard_sizes=table_shard_metadata.shard_sizes,
+                    shard_offsets=table_shard_metadata.shard_offsets,
+                    placement=table_shard_metadata.placement,
+                )
+            )
+        tensor_properties = TensorProperties(
+            dtype=optimizer_state.dtype,
+            layout=optimizer_state.layout,
+            requires_grad=False,
+        )
+        pointwise_optimizer_st_metadata = ShardedTensorMetadata(
+            shards_metadata=list(
+                table_shard_metadata_to_optimizer_shard_metadata.values()
+            ),
+            size=table_global_metadata.size,
+            tensor_properties=tensor_properties,
+        )
+
+        return (
+            table_shard_metadata_to_optimizer_shard_metadata,
+            pointwise_optimizer_st_metadata,
+        )
+
     def zero_grad(self, set_to_none: bool = False) -> None:
         # pyre-ignore [16]
         self._emb_module.set_learning_rate(self.param_groups[0]["lr"])
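Not part of the patch, just for illustration: a self-contained sketch of the row-wise mapping that get_optimizer_rowwise_shard_metadata_and_global_metadata builds for a column-wise sharded table. The table and shard sizes below are made-up example values, and plain dicts stand in for ShardMetadata.

# Hypothetical example: a 100 x 16 table split column-wise into two 100 x 8 shards.
table_shards = [
    {"shard_sizes": [100, 8], "shard_offsets": [0, 0]},
    {"shard_sizes": [100, 8], "shard_offsets": [0, 8]},
]
sharding_dim = 1  # column-wise

optimizer_shards = []
for idx, shard in enumerate(table_shards):
    # each column-wise table shard carries a full row-wise state for its columns,
    # so its optimizer shard is 1-D with one entry per local row
    offset = idx * shard["shard_sizes"][0] if sharding_dim == 1 else shard["shard_offsets"][0]
    optimizer_shards.append({"shard_sizes": [shard["shard_sizes"][0]], "shard_offsets": [offset]})

# the global row-wise optimizer tensor is rows * number of column-wise shards long
global_size = 100 * len(table_shards)  # 200
assert [s["shard_offsets"][0] for s in optimizer_shards] == [0, 100]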
diff --git a/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py
index 2d253cb47..45f373e42 100644
--- a/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py
+++ b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py
@@ -15,6 +15,7 @@
 import torch.distributed as dist
 import torch.nn as nn
 from fbgemm_gpu.split_embedding_configs import EmbOptimType
+from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags
 from hypothesis import given, settings, strategies as st, Verbosity
 from torchrec.distributed.batched_embedding_kernel import (
     KeyValueEmbedding,
@@ -50,6 +51,7 @@
     EmbeddingConfig,
 )
 from torchrec.optim import RowWiseAdagrad
+from torchrec.optim.keyed import CombinedOptimizer
 
 
 def _load_split_embedding_weights(
@@ -94,6 +96,65 @@ def _create_tables(self) -> None:
             for i in range(num_features)
         ]
 
+    @staticmethod
+    def _compare_ssd_fused_optimizer(m1: DistributedModelParallel) -> None:
+        """
+        Util function to compare optimizer weights from the SSD TBE and from
+        DistributedModelParallel.
+        """
+        # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`.
+        for lookup1 in m1.module.sparse.ebc._lookups:
+            for emb_module1 in lookup1._emb_modules:
+                ssd_emb_modules = {KeyValueEmbeddingBag, KeyValueEmbedding}
+                if type(emb_module1) in ssd_emb_modules:
+                    emb_module1.create_rocksdb_hard_link_snapshot()
+                    # Getting the optimizer weights from the embedding module
+                    optimizer_weights_from_emb_bag = (
+                        emb_module1._emb_module.get_optimizer_state(
+                            sorted_id_tensor=None
+                        )
+                    )
+                    # Getting the optimizer weights from the DistributedModelParallel,
+                    # which holds the combined fused optimizer
+                    optimizer_weights_from_optim = m1._optim.state
+                    # Assert statements for flow verification
+                    # Assumptions:
+                    # 1. Embedding module is SSDTableBatchedEmbeddingBags
+                    # 2. Optimizer is CombinedOptimizer
+                    assert isinstance(
+                        m1._optim,
+                        CombinedOptimizer,
+                    ), f"Optimizer class should only be CombinedOptimizer, but got type: {type(m1._optim)}"
+                    assert isinstance(
+                        emb_module1._emb_module, SSDTableBatchedEmbeddingBags
+                    ), f"Embedding module should only be SSDTableBatchedEmbeddingBags, but got type: {type(emb_module1._emb_module)}"
+
+                    # Checking that the optimizer weights are not None
+                    assert (
+                        optimizer_weights_from_emb_bag is not None
+                        and optimizer_weights_from_optim is not None
+                    ), "Expect optimizer weights to be not None."
+
+                    optimizer_weights_from_optim_list = list(
+                        optimizer_weights_from_optim.values()
+                    )
+                    for weight_from_emb_bag, weight_from_optim in zip(
+                        optimizer_weights_from_emb_bag,
+                        optimizer_weights_from_optim_list,
+                    ):
+                        optim_weight_from_emb_tensor = list(
+                            weight_from_emb_bag.values()
+                        )[0]
+
+                        optim_weight_from_optim_tensor = [
+                            shard.tensor
+                            for shard in (list(weight_from_optim.values()))[
+                                0
+                            ].local_shards()
+                        ][0]
+                        assert torch.equal(
+                            optim_weight_from_emb_tensor,
+                            optim_weight_from_optim_tensor,
+                        ), "Expect optimizer weights from emb and optim to be equal."
+
     @staticmethod
     def _copy_ssd_emb_modules(
         m1: DistributedModelParallel, m2: DistributedModelParallel
@@ -462,6 +523,76 @@ def test_ssd_fused_optimizer(
         )
         self._compare_models(base_model, test_model, is_deterministic=is_deterministic)
 
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    # pyre-ignore[56]
+    @given(
+        sharder_type=st.sampled_from(
+            [
+                SharderType.EMBEDDING_BAG_COLLECTION.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.KEY_VALUE.value,
+            ]
+        ),
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE.value,
+                # TODO: uncomment when ssd ckpt supports cw sharding
+                # ShardingType.COLUMN_WISE.value,
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                # TODO: uncomment when ssd ckpt supports cw sharding
+                # ShardingType.TABLE_COLUMN_WISE.value,
+            ]
+        ),
+        stochastic_rounding=st.booleans(),
+        dtype=st.sampled_from([DataType.FP32, DataType.FP16]),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None)
+    def test_ssd_key_value_fused_optimizer(
+        self,
+        sharder_type: str,
+        kernel_type: str,
+        sharding_type: str,
+        stochastic_rounding: bool,
+        dtype: DataType,
+    ) -> None:
+        """
+        Verifies that initialization of the KeyValueEmbeddingFusedOptimizer works as
+        expected for SSD offloading use cases.
+        """
+        constraints = {
+            table.name: ParameterConstraints(
+                sharding_types=[sharding_type],
+                compute_kernels=[kernel_type],
+                key_value_params=KeyValueParams(bulk_init_chunk_size=1024),
+            )
+            for _, table in enumerate(self.tables)
+        }
+
+        base_sharders = [
+            create_test_sharder(
+                sharder_type,
+                sharding_type,
+                kernel_type,
+                fused_params={
+                    "learning_rate": 0.2,
+                    "stochastic_rounding": stochastic_rounding,
+                },
+            ),
+        ]
+        models, _ = self._generate_dmps_and_batch(
+            base_sharders,  # pyre-ignore
+            constraints=constraints,
+        )
+        model, _ = models
+
+        self._compare_ssd_fused_optimizer(model)
+
     @unittest.skipIf(
         not torch.cuda.is_available(),
         "Not enough GPUs, this test requires at least one GPU",
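Not part of the patch: a minimal sketch of how the new per-table optimizer state can be inspected once a model is sharded with the KEY_VALUE compute kernel. It follows the same access pattern as _compare_ssd_fused_optimizer above; `model` is assumed to be a DistributedModelParallel instance such as one returned by _generate_dmps_and_batch.

# model._optim is a CombinedOptimizer; its .state maps each sharded weight to a dict of
# per-table optimizer states, e.g. {"table_0.momentum1": ShardedTensor(...)}.
for sharded_weight, per_table_states in model._optim.state.items():
    for state_key, sharded_state in per_table_states.items():
        # ShardedTensor case; DTensor-backed states expose .to_local() instead
        for shard in sharded_state.local_shards():
            print(state_key, shard.metadata.shard_offsets, shard.tensor.shape)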