From 63fab0c20bc0e3e92012509b52467d3667768e7a Mon Sep 17 00:00:00 2001
From: Zijing Liu
Date: Mon, 11 Aug 2025 14:43:08 -0700
Subject: [PATCH] Add HH support for col-wise sharding (#3269)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/3269

As title, allows `col_wise` to take a list of customized device types so that,
for each shard, we are able to specify which device to place it on.
For example, from `([10, 20, 30], "cpu")` to `([10, 20, 30], ["cpu", "cpu", "gpu"])`.

Reviewed By: gyllstromk, faran928

Differential Revision: D79600418
---
 torchrec/distributed/sharding_plan.py        |  41 ++-
 .../distributed/tests/test_sharding_plan.py  | 235 ++++++++++++++++++
 2 files changed, 274 insertions(+), 2 deletions(-)

diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py
index f0c8a847e..fcb0a6074 100644
--- a/torchrec/distributed/sharding_plan.py
+++ b/torchrec/distributed/sharding_plan.py
@@ -626,13 +626,21 @@ def column_wise(
     ranks: Optional[List[int]] = None,
     size_per_rank: Optional[List[int]] = None,
     compute_kernel: Optional[str] = None,
+    device_types: Optional[List[str]] = None,
 ) -> ParameterShardingGenerator:
     """
     Returns a generator of ParameterShardingPlan for `ShardingType::COLUMN_WISE` for construct_module_sharding_plan.
 
-    Table will the sharded column-wise evenly across specified ranks (and can reuse ranks).
+    Table will be sharded column-wise across specified ranks (and can reuse ranks).
 
     Args:
-        ranks (List[int]): ranks to place columns
+        ranks (Optional[List[int]]): Ranks to place columns. Required if size_per_rank is None.
+        size_per_rank (Optional[List[int]]): List specifying the number of columns per rank.
+            If provided, the columns will be distributed according to these sizes.
+        device_types (Optional[List[str]]): List of device types (e.g., "cpu", "cuda") for each shard.
+            Used to specify different device placements for different shards.
+
+    Returns:
+        ParameterShardingGenerator: A function that generates parameter sharding configuration.
 
     Example::
@@ -652,6 +660,23 @@ def _parameter_sharding_generator(
         device_type: str,
         sharder: ModuleSharder[nn.Module],
     ) -> ParameterSharding:
+        """
+        Internal function that generates the parameter sharding configuration.
+
+        Args:
+            param: The parameter tensor to be sharded.
+            local_size: Number of devices in the local process group.
+            world_size: Total number of devices across all process groups.
+            device_type: Type of device (e.g., "cuda", "cpu").
+            sharder: The module sharder instance.
+
+        Returns:
+            ParameterSharding: The sharding configuration for the parameter.
+
+        Raises:
+            ValueError: If the parameter cannot be evenly divided across ranks or
+                if the specified sizes cannot fit the tensor.
+ """ if size_per_rank is None: assert ranks is not None @@ -688,6 +713,17 @@ def _parameter_sharding_generator( f"Cannot fit tensor of {rows, cols} into sizes_ranks_placements = {size_per_rank}" ) + placements: List[str] = [] + if device_types is not None: + assert len(device_types) == len( + size_offset_ranks + ), "device_types must be the same length as ranks" + index: int = 0 + for device in device_types: + placements.append(placement_helper(device, index, index)) + if device != "cpu": + index += 1 + return _get_parameter_sharding( param, ShardingType.COLUMN_WISE.value, @@ -695,6 +731,7 @@ def _parameter_sharding_generator( local_size, device_type, sharder, + placements=placements if placements else None, compute_kernel=compute_kernel, ) diff --git a/torchrec/distributed/tests/test_sharding_plan.py b/torchrec/distributed/tests/test_sharding_plan.py index 41460582d..02f64e859 100644 --- a/torchrec/distributed/tests/test_sharding_plan.py +++ b/torchrec/distributed/tests/test_sharding_plan.py @@ -887,6 +887,241 @@ def test_column_wise(self, data_type: DataType) -> None: } self.assertDictEqual(expected, module_sharding_plan) + # pyre-fixme[56] + @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16])) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_column_wise_size_per_rank(self, data_type: DataType) -> None: + """Test column_wise sharding with custom size_per_rank parameter.""" + + embedding_bag_config = [ + EmbeddingBagConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=100, # Total columns that will be split as [30, 40, 30] + num_embeddings=1024, + data_type=data_type, + ) + ] + + # Test uneven column distribution: rank 0 gets 30 cols, rank 1 gets 40 cols, rank 2 gets 30 cols + module_sharding_plan = construct_module_sharding_plan( + EmbeddingBagCollection(tables=embedding_bag_config), + per_param_sharding={ + "table_0": column_wise(size_per_rank=[30, 40, 30]), + }, + local_size=3, + world_size=3, + device_type="cuda", + ) + + expected = { + "table_0": ParameterSharding( + sharding_type="column_wise", + compute_kernel="dense", + ranks=[0, 1, 2], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[1024, 30], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_offsets=[0, 30], + shard_sizes=[1024, 40], + placement="rank:1/cuda:1", + ), + ShardMetadata( + shard_offsets=[0, 70], + shard_sizes=[1024, 30], + placement="rank:2/cuda:2", + ), + ] + ), + ), + } + self.assertDictEqual(expected, module_sharding_plan) + + # pyre-fixme[56] + @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16])) + @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) + def test_column_wise_device_types(self, data_type: DataType) -> None: + """Test column_wise sharding with mixed device types.""" + + embedding_bag_config = [ + EmbeddingBagConfig( + name="table_0", + feature_names=["feature_0"], + embedding_dim=64, + num_embeddings=1024, + data_type=data_type, + ) + ] + + # Test mixed device types: cpu, cuda, cpu, cuda + module_sharding_plan = construct_module_sharding_plan( + EmbeddingBagCollection(tables=embedding_bag_config), + per_param_sharding={ + "table_0": column_wise( + ranks=[0, 1, 2, 3], + device_types=["cpu", "cuda", "cpu", "cuda"], + ), + }, + local_size=4, + world_size=4, + device_type="cuda", + ) + + expected = { + "table_0": ParameterSharding( + sharding_type="column_wise", + compute_kernel="dense", + ranks=[0, 1, 2, 3], + 
+                sharding_spec=EnumerableShardingSpec(
+                    shards=[
+                        ShardMetadata(
+                            shard_offsets=[0, 0],
+                            shard_sizes=[1024, 16],
+                            placement="rank:0/cpu",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 16],
+                            shard_sizes=[1024, 16],
+                            placement="rank:0/cuda:0",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 32],
+                            shard_sizes=[1024, 16],
+                            placement="rank:0/cpu",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 48],
+                            shard_sizes=[1024, 16],
+                            placement="rank:1/cuda:1",
+                        ),
+                    ]
+                ),
+            ),
+        }
+        self.assertDictEqual(expected, module_sharding_plan)
+
+    def test_column_wise_size_per_rank_insufficient_columns(self) -> None:
+        """Test that column_wise raises error when size_per_rank doesn't cover all columns."""
+
+        embedding_bag_config = [
+            EmbeddingBagConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=100,
+                num_embeddings=1024,
+                data_type=DataType.FP32,
+            )
+        ]
+
+        with self.assertRaises(ValueError) as context:
+            construct_module_sharding_plan(
+                EmbeddingBagCollection(tables=embedding_bag_config),
+                per_param_sharding={
+                    "table_0": column_wise(
+                        size_per_rank=[30, 40]
+                    ),  # Only covers 70/100 columns
+                },
+                local_size=2,
+                world_size=2,
+                device_type="cuda",
+            )
+
+        self.assertIn(
+            "Cannot fit tensor of (1024, 100) into sizes_ranks_placements = [30, 40]",
+            str(context.exception),
+        )
+
+    def test_column_wise_size_per_rank_with_device_types(self) -> None:
+        """Test column_wise sharding with both size_per_rank and device_types parameters."""
+
+        embedding_bag_config = [
+            EmbeddingBagConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=80,  # Total columns that will be split as [20, 30, 30]
+                num_embeddings=512,
+                data_type=DataType.FP32,
+            )
+        ]
+
+        # Test combining custom column sizes with mixed device types
+        module_sharding_plan = construct_module_sharding_plan(
+            EmbeddingBagCollection(tables=embedding_bag_config),
+            per_param_sharding={
+                "table_0": column_wise(
+                    size_per_rank=[20, 30, 30],
+                    device_types=["cpu", "cuda", "cpu"],
+                ),
+            },
+            local_size=3,
+            world_size=3,
+            device_type="cuda",
+        )
+
+        expected = {
+            "table_0": ParameterSharding(
+                sharding_type="column_wise",
+                compute_kernel="dense",
+                ranks=[0, 1, 2],
+                sharding_spec=EnumerableShardingSpec(
+                    shards=[
+                        ShardMetadata(
+                            shard_offsets=[0, 0],
+                            shard_sizes=[512, 20],
+                            placement="rank:0/cpu",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 20],
+                            shard_sizes=[512, 30],
+                            placement="rank:0/cuda:0",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 50],
+                            shard_sizes=[512, 30],
+                            placement="rank:0/cpu",
+                        ),
+                    ]
+                ),
+            ),
+        }
+        self.assertDictEqual(expected, module_sharding_plan)
+
+    def test_column_wise_uneven_division_error(self) -> None:
+        """Test that column_wise raises error when columns can't be evenly divided across ranks."""
+
+        embedding_bag_config = [
+            EmbeddingBagConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=65,  # Cannot be evenly divided by 2
+                num_embeddings=1024,
+                data_type=DataType.FP32,
+            )
+        ]
+
+        with self.assertRaises(ValueError) as context:
+            construct_module_sharding_plan(
+                EmbeddingBagCollection(tables=embedding_bag_config),
+                per_param_sharding={
+                    "table_0": column_wise(
+                        ranks=[0, 1]
+                    ),  # 65 columns cannot be evenly divided by 2 ranks
+                },
+                local_size=2,
+                world_size=2,
+                device_type="cuda",
+            )
+
+        self.assertIn(
+            "column dim of 65 cannot be evenly divided across [0, 1]",
+            str(context.exception),
+        )
+
 
 class ShardingPlanTest(unittest.TestCase):
     def test_str(self) -> None:
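
Usage note: the following is a minimal sketch of how the new `device_types` argument combines with `size_per_rank`, mirroring the example in the summary and the new tests. The table name, feature name, dimensions, and local/world sizes are illustrative only, and the device strings follow the tests, which use "cuda" rather than the "gpu" shorthand in the summary.

from torchrec.distributed.sharding_plan import (
    column_wise,
    construct_module_sharding_plan,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# One 60-column table, to be split column-wise into shards of 10, 20, and 30 columns.
tables = [
    EmbeddingBagConfig(
        name="table_0",
        feature_names=["feature_0"],
        embedding_dim=60,
        num_embeddings=1024,
    )
]

# size_per_rank sets the column width of each shard; device_types supplies one
# device string per shard, so the first two shards stay on "cpu" while the
# last shard is placed on "cuda".
module_sharding_plan = construct_module_sharding_plan(
    EmbeddingBagCollection(tables=tables),
    per_param_sharding={
        "table_0": column_wise(
            size_per_rank=[10, 20, 30],
            device_types=["cpu", "cpu", "cuda"],
        ),
    },
    local_size=3,
    world_size=3,
    device_type="cuda",
)

Per the placement loop added in this patch, only non-"cpu" entries advance the device index, which is why the expected placements in test_column_wise_device_types keep the cpu shards at rank:0/cpu while the cuda shards land on cuda:0 and cuda:1.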