Add HH support for col-wise sharding #3269

Closed · wants to merge 1 commit into from
41 changes: 39 additions & 2 deletions torchrec/distributed/sharding_plan.py
@@ -626,13 +626,21 @@ def column_wise(
ranks: Optional[List[int]] = None,
size_per_rank: Optional[List[int]] = None,
compute_kernel: Optional[str] = None,
device_types: Optional[List[str]] = None,
) -> ParameterShardingGenerator:
"""
Returns a generator of ParameterShardingPlan for `ShardingType::COLUMN_WISE` for construct_module_sharding_plan.
- Table will the sharded column-wise evenly across specified ranks (and can reuse ranks).
+ Table will be sharded column-wise across specified ranks (and can reuse ranks).

Args:
- ranks (List[int]): ranks to place columns
+ ranks (Optional[List[int]]): Ranks to place columns. Required if size_per_rank is None.
size_per_rank (Optional[List[int]]): List specifying the number of columns per rank.
If provided, the columns will be distributed according to these sizes.
device_types (Optional[List[str]]): List of device types (e.g., "cpu", "cuda") for each shard.
Used to specify different device placements for different shards.

Returns:
ParameterShardingGenerator: A function that generates parameter sharding configuration.

Example::

@@ -652,6 +660,23 @@ def _parameter_sharding_generator(
device_type: str,
sharder: ModuleSharder[nn.Module],
) -> ParameterSharding:
"""
Internal function that generates the parameter sharding configuration.

Args:
param: The parameter tensor to be sharded.
local_size: Number of devices in the local process group.
world_size: Total number of devices across all process groups.
device_type: Type of device (e.g., "cuda", "cpu").
sharder: The module sharder instance.

Returns:
ParameterSharding: The sharding configuration for the parameter.

Raises:
ValueError: If the parameter cannot be evenly divided across ranks or
if the specified sizes cannot fit the tensor.
"""
if size_per_rank is None:
assert ranks is not None

@@ -688,13 +713,25 @@ def _parameter_sharding_generator(
f"Cannot fit tensor of {rows, cols} into sizes_ranks_placements = {size_per_rank}"
)

placements: List[str] = []
if device_types is not None:
assert len(device_types) == len(
size_offset_ranks
), "device_types must be the same length as ranks"
index: int = 0
for device in device_types:
placements.append(placement_helper(device, index, index))
if device != "cpu":
index += 1

return _get_parameter_sharding(
param,
ShardingType.COLUMN_WISE.value,
size_offset_ranks,
local_size,
device_type,
sharder,
placements=placements if placements else None,
compute_kernel=compute_kernel,
)

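Taken together, the new parameters let a caller pin uneven column splits and mixed device placement directly in the sharding plan. A minimal usage sketch (the table name, dimensions, and world size are illustrative; the imports mirror the tests below):

from torchrec.distributed.sharding_plan import (
    column_wise,
    construct_module_sharding_plan,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# Hypothetical 512 x 80 table: 20 columns on a CPU shard, 30 on a GPU
# shard, and 30 on another CPU shard.
tables = [
    EmbeddingBagConfig(
        name="table_0",
        feature_names=["feature_0"],
        embedding_dim=80,
        num_embeddings=512,
    )
]

plan = construct_module_sharding_plan(
    EmbeddingBagCollection(tables=tables),
    per_param_sharding={
        "table_0": column_wise(
            size_per_rank=[20, 30, 30],
            device_types=["cpu", "cuda", "cpu"],
        ),
    },
    local_size=3,
    world_size=3,
    device_type="cuda",
)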
235 changes: 235 additions & 0 deletions torchrec/distributed/tests/test_sharding_plan.py
@@ -887,6 +887,241 @@ def test_column_wise(self, data_type: DataType) -> None:
}
self.assertDictEqual(expected, module_sharding_plan)

# pyre-fixme[56]
@given(data_type=st.sampled_from([DataType.FP32, DataType.FP16]))
@settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
def test_column_wise_size_per_rank(self, data_type: DataType) -> None:
"""Test column_wise sharding with custom size_per_rank parameter."""

embedding_bag_config = [
EmbeddingBagConfig(
name="table_0",
feature_names=["feature_0"],
embedding_dim=100, # Total columns that will be split as [30, 40, 30]
num_embeddings=1024,
data_type=data_type,
)
]

# Test uneven column distribution: rank 0 gets 30 cols, rank 1 gets 40 cols, rank 2 gets 30 cols
module_sharding_plan = construct_module_sharding_plan(
EmbeddingBagCollection(tables=embedding_bag_config),
per_param_sharding={
"table_0": column_wise(size_per_rank=[30, 40, 30]),
},
local_size=3,
world_size=3,
device_type="cuda",
)

expected = {
"table_0": ParameterSharding(
sharding_type="column_wise",
compute_kernel="dense",
ranks=[0, 1, 2],
sharding_spec=EnumerableShardingSpec(
shards=[
ShardMetadata(
shard_offsets=[0, 0],
shard_sizes=[1024, 30],
placement="rank:0/cuda:0",
),
ShardMetadata(
shard_offsets=[0, 30],
shard_sizes=[1024, 40],
placement="rank:1/cuda:1",
),
ShardMetadata(
shard_offsets=[0, 70],
shard_sizes=[1024, 30],
placement="rank:2/cuda:2",
),
]
),
),
}
self.assertDictEqual(expected, module_sharding_plan)
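The shard_offsets in this expectation follow from a running sum over size_per_rank; a quick sketch of that arithmetic (illustrative, not the library's internal code):

size_per_rank = [30, 40, 30]

# Each shard's column offset is the cumulative width of the shards before it.
offsets, acc = [], 0
for width in size_per_rank:
    offsets.append(acc)
    acc += width

print(offsets)  # [0, 30, 70], matching the column component of shard_offsets above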

# pyre-fixme[56]
@given(data_type=st.sampled_from([DataType.FP32, DataType.FP16]))
@settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
def test_column_wise_device_types(self, data_type: DataType) -> None:
"""Test column_wise sharding with mixed device types."""

embedding_bag_config = [
EmbeddingBagConfig(
name="table_0",
feature_names=["feature_0"],
embedding_dim=64,
num_embeddings=1024,
data_type=data_type,
)
]

# Test mixed device types: cpu, cuda, cpu, cuda
module_sharding_plan = construct_module_sharding_plan(
EmbeddingBagCollection(tables=embedding_bag_config),
per_param_sharding={
"table_0": column_wise(
ranks=[0, 1, 2, 3],
device_types=["cpu", "cuda", "cpu", "cuda"],
),
},
local_size=4,
world_size=4,
device_type="cuda",
)

expected = {
"table_0": ParameterSharding(
sharding_type="column_wise",
compute_kernel="dense",
ranks=[0, 1, 2, 3],
sharding_spec=EnumerableShardingSpec(
shards=[
ShardMetadata(
shard_offsets=[0, 0],
shard_sizes=[1024, 16],
placement="rank:0/cpu",
),
ShardMetadata(
shard_offsets=[0, 16],
shard_sizes=[1024, 16],
placement="rank:0/cuda:0",
),
ShardMetadata(
shard_offsets=[0, 32],
shard_sizes=[1024, 16],
placement="rank:0/cpu",
),
ShardMetadata(
shard_offsets=[0, 48],
shard_sizes=[1024, 16],
placement="rank:1/cuda:1",
),
]
),
),
}
self.assertDictEqual(expected, module_sharding_plan)
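The expected placements come from placement_helper, called above as placement_helper(device, index, index): CPU shards appear to be pinned to rank:0/cpu, and the device index advances only on non-CPU shards, which is why cuda:1 lands on the fourth shard. A simplified reimplementation of that behavior as this test assumes it (a sketch, not the torchrec source):

def placement_helper_sketch(device_type: str, index: int, rank: int) -> str:
    # CPU shards are addressed as host memory on rank 0; only accelerator
    # shards carry a per-device index such as "cuda:1".
    if device_type == "cpu":
        return f"rank:0/{device_type}"
    return f"rank:{rank}/{device_type}:{index}"

placements, index = [], 0
for device in ["cpu", "cuda", "cpu", "cuda"]:
    placements.append(placement_helper_sketch(device, index, index))
    if device != "cpu":
        index += 1

print(placements)
# ['rank:0/cpu', 'rank:0/cuda:0', 'rank:0/cpu', 'rank:1/cuda:1']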

def test_column_wise_size_per_rank_insufficient_columns(self) -> None:
"""Test that column_wise raises error when size_per_rank doesn't cover all columns."""

embedding_bag_config = [
EmbeddingBagConfig(
name="table_0",
feature_names=["feature_0"],
embedding_dim=100,
num_embeddings=1024,
data_type=DataType.FP32,
)
]

with self.assertRaises(ValueError) as context:
construct_module_sharding_plan(
EmbeddingBagCollection(tables=embedding_bag_config),
per_param_sharding={
"table_0": column_wise(
size_per_rank=[30, 40]
), # Only covers 70/100 columns
},
local_size=2,
world_size=2,
device_type="cuda",
)

self.assertIn(
"Cannot fit tensor of (1024, 100) into sizes_ranks_placements = [30, 40]",
str(context.exception),
)
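The coverage check implied by this message is a comparison of the summed shard widths against the table's column count; a sketch under that assumption (not the library code):

rows, cols = 1024, 100
size_per_rank = [30, 40]

# 30 + 40 = 70 columns assigned, but the table has 100.
if sum(size_per_rank) < cols:
    raise ValueError(
        f"Cannot fit tensor of {(rows, cols)} into sizes_ranks_placements = {size_per_rank}"
    )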

def test_column_wise_size_per_rank_with_device_types(self) -> None:
"""Test column_wise sharding with both size_per_rank and device_types parameters."""

embedding_bag_config = [
EmbeddingBagConfig(
name="table_0",
feature_names=["feature_0"],
embedding_dim=80, # Total columns that will be split as [20, 30, 30]
num_embeddings=512,
data_type=DataType.FP32,
)
]

# Test combining custom column sizes with mixed device types
module_sharding_plan = construct_module_sharding_plan(
EmbeddingBagCollection(tables=embedding_bag_config),
per_param_sharding={
"table_0": column_wise(
size_per_rank=[20, 30, 30],
device_types=["cpu", "cuda", "cpu"],
),
},
local_size=3,
world_size=3,
device_type="cuda",
)

expected = {
"table_0": ParameterSharding(
sharding_type="column_wise",
compute_kernel="dense",
ranks=[0, 1, 2],
sharding_spec=EnumerableShardingSpec(
shards=[
ShardMetadata(
shard_offsets=[0, 0],
shard_sizes=[512, 20],
placement="rank:0/cpu",
),
ShardMetadata(
shard_offsets=[0, 20],
shard_sizes=[512, 30],
placement="rank:0/cuda:0",
),
ShardMetadata(
shard_offsets=[0, 50],
shard_sizes=[512, 30],
placement="rank:0/cpu",
),
]
),
),
}
self.assertDictEqual(expected, module_sharding_plan)

def test_column_wise_uneven_division_error(self) -> None:
"""Test that column_wise raises error when columns can't be evenly divided across ranks."""

embedding_bag_config = [
EmbeddingBagConfig(
name="table_0",
feature_names=["feature_0"],
embedding_dim=65, # Cannot be evenly divided by 2
num_embeddings=1024,
data_type=DataType.FP32,
)
]

with self.assertRaises(ValueError) as context:
construct_module_sharding_plan(
EmbeddingBagCollection(tables=embedding_bag_config),
per_param_sharding={
"table_0": column_wise(
ranks=[0, 1]
), # 65 columns cannot be evenly divided by 2 ranks
},
local_size=2,
world_size=2,
device_type="cuda",
)

self.assertIn(
"column dim of 65 cannot be evenly divided across [0, 1]",
str(context.exception),
)
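Without size_per_rank, every rank must receive the same number of columns, so the failure here is plain arithmetic: 65 is not divisible by 2. A sketch of the check this error message implies (illustrative, not the library code):

embedding_dim, ranks = 65, [0, 1]

# 65 % 2 == 1, so an even column split is impossible.
if embedding_dim % len(ranks) != 0:
    raise ValueError(
        f"column dim of {embedding_dim} cannot be evenly divided across {ranks}"
    )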


class ShardingPlanTest(unittest.TestCase):
def test_str(self) -> None: