diff --git a/torchrec/distributed/planner/enumerators.py b/torchrec/distributed/planner/enumerators.py
index 2336865bd..80a897466 100644
--- a/torchrec/distributed/planner/enumerators.py
+++ b/torchrec/distributed/planner/enumerators.py
@@ -189,17 +189,31 @@ def enumerate(
                         sharder.compute_kernels(sharding_type, self._compute_device),
                         sharding_type,
                     ):
-                        (
-                            shard_sizes,
-                            shard_offsets,
-                        ) = calculate_shard_sizes_and_offsets(
-                            tensor=param,
-                            world_size=self._world_size,
-                            local_world_size=self._local_world_size,
-                            sharding_type=sharding_type,
-                            col_wise_shard_dim=col_wise_shard_dim,
-                            device_memory_sizes=self._device_memory_sizes,
-                        )
+                        try:
+                            (
+                                shard_sizes,
+                                shard_offsets,
+                            ) = calculate_shard_sizes_and_offsets(
+                                tensor=param,
+                                world_size=self._world_size,
+                                local_world_size=self._local_world_size,
+                                sharding_type=sharding_type,
+                                col_wise_shard_dim=col_wise_shard_dim,
+                                device_memory_sizes=self._device_memory_sizes,
+                            )
+                        except ZeroDivisionError as e:
+                            # Re-raise with additional context about the table and module
+                            raise ValueError(
+                                f"Failed to calculate sharding plan for table '{name}': {str(e)} "
+                                f"Context: table_name='{name}', module_path='{child_path}', "
+                                f"module_type='{type(child_module).__name__}', "
+                                f"sharder='{sharder.__class__.__name__}', "
+                                f"tensor.shape={param.shape}, sharding_type='{sharding_type}', "
+                                f"compute_kernel='{compute_kernel}', world_size={self._world_size}, "
+                                f"local_world_size={self._local_world_size}, "
+                                f"col_wise_shard_dim={col_wise_shard_dim}, "
+                                f"compute_device='{self._compute_device}'"
+                            ) from e
                         dependency = None
                         if isinstance(child_module, EmbeddingTower):
                             dependency = child_path
diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py
index f0c8a847e..f67955715 100644
--- a/torchrec/distributed/sharding_plan.py
+++ b/torchrec/distributed/sharding_plan.py
@@ -268,23 +268,7 @@ def _calculate_cw_shard_sizes_and_offsets(
     rows: int,
     col_wise_shard_dim: Optional[int] = None,
 ) -> Tuple[List[List[int]], List[List[int]]]:
-    block_size: int = min(
-        (
-            _find_base_dim(col_wise_shard_dim, columns)
-            if col_wise_shard_dim
-            else _find_base_dim(MIN_CW_DIM, columns)
-        ),
-        columns,
-    )
-
-    if columns % block_size != 0:
-        warnings.warn(
-            f"Dim of {columns} cannot be evenly divided with column wise shard"
-            "dim {col_wise_shard_dim}, overriding block_size to embedding_dim={columns}",
-            UserWarning,
-            stacklevel=2,
-        )
-        block_size = columns
+    block_size = _get_block_size_for_cw_shard(columns, col_wise_shard_dim)
 
     num_col_wise_shards, _residual = divmod(columns, block_size)
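
Note: the definition of the new _get_block_size_for_cw_shard helper does not appear in this hunk. The following is a minimal sketch of what it presumably encapsulates, based on the inline logic removed above. It is meant to live in torchrec/distributed/sharding_plan.py, where _find_base_dim and MIN_CW_DIM are already in scope; the exact signature and warning wording are assumptions, and unlike the removed code the sketch formats the whole warning as a single f-string so col_wise_shard_dim is actually interpolated.

    from typing import Optional
    import warnings

    def _get_block_size_for_cw_shard(
        columns: int, col_wise_shard_dim: Optional[int]
    ) -> int:
        # Sketch only: mirrors the inline logic removed from
        # _calculate_cw_shard_sizes_and_offsets above.
        # Choose a base block size via the existing _find_base_dim helper,
        # capped at the full embedding dim (columns).
        block_size: int = min(
            (
                _find_base_dim(col_wise_shard_dim, columns)
                if col_wise_shard_dim
                else _find_base_dim(MIN_CW_DIM, columns)
            ),
            columns,
        )
        if columns % block_size != 0:
            # If the embedding dim is not evenly divisible, warn and fall back
            # to a single column-wise shard spanning the whole embedding dim.
            warnings.warn(
                f"Dim of {columns} cannot be evenly divided with column wise "
                f"shard dim {col_wise_shard_dim}, overriding block_size to "
                f"embedding_dim={columns}",
                UserWarning,
                stacklevel=2,
            )
            block_size = columns
        return block_size

With this extraction, _calculate_cw_shard_sizes_and_offsets reduces to a single call, and the block-size policy can presumably be reused and unit-tested in isolation.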