
Commit 2c40571

Zijing Liu authored and facebook-github-bot committed

Add HH support for col-wise sharding

Summary: As title, allows `col_wise` to take a list of customized device types so that, for each shard, we are able to specify which device to place it on. For example, from `([10, 20, 30], "cpu")` to `([10, 20, 30], ["cpu", "cpu", "gpu"])`.

Reviewed By: gyllstromk

Differential Revision: D79600418
1 parent 091ec6b commit 2c40571
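
To make the summary concrete, below is a minimal usage sketch (not part of the commit) based on the API exercised in the tests in this diff; the "gpu" shorthand in the summary corresponds to the "cuda" device type, and names like "table_0" are illustrative only.

# Minimal usage sketch, assuming the torchrec helpers shown in the diff/tests below:
# shard one table column-wise as [10, 20, 30] columns placed on cpu / cpu / cuda,
# mirroring the summary's ([10, 20, 30], ["cpu", "cpu", "gpu"]) example.
from torchrec.distributed.sharding_plan import (
    column_wise,
    construct_module_sharding_plan,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

tables = [
    EmbeddingBagConfig(
        name="table_0",
        feature_names=["feature_0"],
        embedding_dim=60,  # 10 + 20 + 30 columns
        num_embeddings=1024,
    )
]

module_sharding_plan = construct_module_sharding_plan(
    EmbeddingBagCollection(tables=tables),
    per_param_sharding={
        "table_0": column_wise(
            size_per_rank=[10, 20, 30],
            device_types=["cpu", "cpu", "cuda"],
        ),
    },
    local_size=3,
    world_size=3,
    device_type="cuda",
)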

File tree

2 files changed: +271 −2 lines changed

torchrec/distributed/sharding_plan.py

36 additions & 2 deletions

@@ -626,13 +626,21 @@ def column_wise(
     ranks: Optional[List[int]] = None,
     size_per_rank: Optional[List[int]] = None,
     compute_kernel: Optional[str] = None,
+    device_types: Optional[List[str]] = None,
 ) -> ParameterShardingGenerator:
     """
     Returns a generator of ParameterShardingPlan for `ShardingType::COLUMN_WISE` for construct_module_sharding_plan.
-    Table will the sharded column-wise evenly across specified ranks (and can reuse ranks).
+    Table will be sharded column-wise across specified ranks (and can reuse ranks).

     Args:
-        ranks (List[int]): ranks to place columns
+        ranks (Optional[List[int]]): Ranks to place columns. Required if size_per_rank is None.
+        size_per_rank (Optional[List[int]]): List specifying the number of columns per rank.
+            If provided, the columns will be distributed according to these sizes.
+        device_types (Optional[List[str]]): List of device types (e.g., "cpu", "cuda") for each shard.
+            Used to specify different device placements for different shards.
+
+    Returns:
+        ParameterShardingGenerator: A function that generates parameter sharding configuration.

     Example::

@@ -652,6 +660,23 @@ def _parameter_sharding_generator(
         device_type: str,
         sharder: ModuleSharder[nn.Module],
     ) -> ParameterSharding:
+        """
+        Internal function that generates the parameter sharding configuration.
+
+        Args:
+            param: The parameter tensor to be sharded.
+            local_size: Number of devices in the local process group.
+            world_size: Total number of devices across all process groups.
+            device_type: Type of device (e.g., "cuda", "cpu").
+            sharder: The module sharder instance.
+
+        Returns:
+            ParameterSharding: The sharding configuration for the parameter.
+
+        Raises:
+            ValueError: If the parameter cannot be evenly divided across ranks or
+                if the specified sizes cannot fit the tensor.
+        """
         if size_per_rank is None:
             assert ranks is not None

@@ -688,13 +713,22 @@ def _parameter_sharding_generator(
                 f"Cannot fit tensor of {rows, cols} into sizes_ranks_placements = {size_per_rank}"
             )

+        placements: List[str] = []
+        if device_types is not None:
+            index: int = 0
+            for device in device_types:
+                placements.append(placement_helper(device, index, index))
+                if device != "cpu":
+                    index += 1
+
         return _get_parameter_sharding(
             param,
             ShardingType.COLUMN_WISE.value,
             size_offset_ranks,
             local_size,
             device_type,
             sharder,
+            placements=placements if placements else None,
             compute_kernel=compute_kernel,
         )

torchrec/distributed/tests/test_sharding_plan.py

235 additions & 0 deletions

@@ -887,6 +887,241 @@ def test_column_wise(self, data_type: DataType) -> None:
         }
         self.assertDictEqual(expected, module_sharding_plan)

+    # pyre-fixme[56]
+    @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16]))
+    @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
+    def test_column_wise_size_per_rank(self, data_type: DataType) -> None:
+        """Test column_wise sharding with custom size_per_rank parameter."""
+
+        embedding_bag_config = [
+            EmbeddingBagConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=100,  # Total columns that will be split as [30, 40, 30]
+                num_embeddings=1024,
+                data_type=data_type,
+            )
+        ]
+
+        # Test uneven column distribution: rank 0 gets 30 cols, rank 1 gets 40 cols, rank 2 gets 30 cols
+        module_sharding_plan = construct_module_sharding_plan(
+            EmbeddingBagCollection(tables=embedding_bag_config),
+            per_param_sharding={
+                "table_0": column_wise(size_per_rank=[30, 40, 30]),
+            },
+            local_size=3,
+            world_size=3,
+            device_type="cuda",
+        )
+
+        expected = {
+            "table_0": ParameterSharding(
+                sharding_type="column_wise",
+                compute_kernel="dense",
+                ranks=[0, 1, 2],
+                sharding_spec=EnumerableShardingSpec(
+                    shards=[
+                        ShardMetadata(
+                            shard_offsets=[0, 0],
+                            shard_sizes=[1024, 30],
+                            placement="rank:0/cuda:0",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 30],
+                            shard_sizes=[1024, 40],
+                            placement="rank:1/cuda:1",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 70],
+                            shard_sizes=[1024, 30],
+                            placement="rank:2/cuda:2",
+                        ),
+                    ]
+                ),
+            ),
+        }
+        self.assertDictEqual(expected, module_sharding_plan)
+
+    # pyre-fixme[56]
+    @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16]))
+    @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
+    def test_column_wise_device_types(self, data_type: DataType) -> None:
+        """Test column_wise sharding with mixed device types."""
+
+        embedding_bag_config = [
+            EmbeddingBagConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=64,
+                num_embeddings=1024,
+                data_type=data_type,
+            )
+        ]
+
+        # Test mixed device types: cpu, cuda, cpu, cuda
+        module_sharding_plan = construct_module_sharding_plan(
+            EmbeddingBagCollection(tables=embedding_bag_config),
+            per_param_sharding={
+                "table_0": column_wise(
+                    ranks=[0, 1, 2, 3],
+                    device_types=["cpu", "cuda", "cpu", "cuda"],
+                ),
+            },
+            local_size=4,
+            world_size=4,
+            device_type="cuda",
+        )
+
+        expected = {
+            "table_0": ParameterSharding(
+                sharding_type="column_wise",
+                compute_kernel="dense",
+                ranks=[0, 1, 2, 3],
+                sharding_spec=EnumerableShardingSpec(
+                    shards=[
+                        ShardMetadata(
+                            shard_offsets=[0, 0],
+                            shard_sizes=[1024, 16],
+                            placement="rank:0/cpu",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 16],
+                            shard_sizes=[1024, 16],
+                            placement="rank:0/cuda:0",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 32],
+                            shard_sizes=[1024, 16],
+                            placement="rank:0/cpu",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 48],
+                            shard_sizes=[1024, 16],
+                            placement="rank:1/cuda:1",
+                        ),
+                    ]
+                ),
+            ),
+        }
+        self.assertDictEqual(expected, module_sharding_plan)
+
+    def test_column_wise_size_per_rank_insufficient_columns(self) -> None:
+        """Test that column_wise raises error when size_per_rank doesn't cover all columns."""
+
+        embedding_bag_config = [
+            EmbeddingBagConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=100,
+                num_embeddings=1024,
+                data_type=DataType.FP32,
+            )
+        ]
+
+        with self.assertRaises(ValueError) as context:
+            construct_module_sharding_plan(
+                EmbeddingBagCollection(tables=embedding_bag_config),
+                per_param_sharding={
+                    "table_0": column_wise(
+                        size_per_rank=[30, 40]
+                    ),  # Only covers 70/100 columns
+                },
+                local_size=2,
+                world_size=2,
+                device_type="cuda",
+            )
+
+        self.assertIn(
+            "Cannot fit tensor of (1024, 100) into sizes_ranks_placements = [30, 40]",
+            str(context.exception),
+        )
+
+    def test_column_wise_size_per_rank_with_device_types(self) -> None:
+        """Test column_wise sharding with both size_per_rank and device_types parameters."""
+
+        embedding_bag_config = [
+            EmbeddingBagConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=80,  # Total columns that will be split as [20, 30, 30]
+                num_embeddings=512,
+                data_type=DataType.FP32,
+            )
+        ]
+
+        # Test combining custom column sizes with mixed device types
+        module_sharding_plan = construct_module_sharding_plan(
+            EmbeddingBagCollection(tables=embedding_bag_config),
+            per_param_sharding={
+                "table_0": column_wise(
+                    size_per_rank=[20, 30, 30],
+                    device_types=["cpu", "cuda", "cpu"],
+                ),
+            },
+            local_size=3,
+            world_size=3,
+            device_type="cuda",
+        )
+
+        expected = {
+            "table_0": ParameterSharding(
+                sharding_type="column_wise",
+                compute_kernel="dense",
+                ranks=[0, 1, 2],
+                sharding_spec=EnumerableShardingSpec(
+                    shards=[
+                        ShardMetadata(
+                            shard_offsets=[0, 0],
+                            shard_sizes=[512, 20],
+                            placement="rank:0/cpu",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 20],
+                            shard_sizes=[512, 30],
+                            placement="rank:0/cuda:0",
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[0, 50],
+                            shard_sizes=[512, 30],
+                            placement="rank:0/cpu",
+                        ),
+                    ]
+                ),
+            ),
+        }
+        self.assertDictEqual(expected, module_sharding_plan)
+
+    def test_column_wise_uneven_division_error(self) -> None:
+        """Test that column_wise raises error when columns can't be evenly divided across ranks."""
+
+        embedding_bag_config = [
+            EmbeddingBagConfig(
+                name="table_0",
+                feature_names=["feature_0"],
+                embedding_dim=65,  # Cannot be evenly divided by 2
+                num_embeddings=1024,
+                data_type=DataType.FP32,
+            )
+        ]
+
+        with self.assertRaises(ValueError) as context:
+            construct_module_sharding_plan(
+                EmbeddingBagCollection(tables=embedding_bag_config),
+                per_param_sharding={
+                    "table_0": column_wise(
+                        ranks=[0, 1]
+                    ),  # 65 columns cannot be evenly divided by 2 ranks
+                },
+                local_size=2,
+                world_size=2,
+                device_type="cuda",
+            )
+
+        self.assertIn(
+            "column dim of 65 cannot be evenly divided across [0, 1]",
+            str(context.exception),
+        )
+

 class ShardingPlanTest(unittest.TestCase):
     def test_str(self) -> None:
