Commit 140e979
isururanawaka authored and facebook-github-bot committed
helper for collectives to pass tensors in sharding tests (#3161)
Summary: Pull Request resolved: #3161

- Create gather_all_tensors function
- Replace dist.all_gather operation

Reviewed By: aporialiao
Differential Revision: D77819104
fbshipit-source-id: f82d19c5aae6ae955b7e100f17dedd206ec1a0b4
1 parent a6b2e8f · commit 140e979

File tree

1 file changed: +33 −12 lines changed


torchrec/distributed/test_utils/test_sharding.py

Lines changed: 33 additions & 12 deletions
@@ -76,6 +76,34 @@ class SharderType(Enum):
     EMBEDDING_COLLECTION = "embedding_collection"


+def _gather_all_tensors(
+    local_tensor: torch.Tensor,
+    world_size: int,
+    pg: Optional[dist.ProcessGroup] = None,
+) -> List[torch.Tensor]:
+    """
+    Gathers tensors from all processes in a distributed group.
+
+    This function collects tensors from all processes in the specified
+    process group and returns a list of tensors, where each tensor
+    corresponds to the data from one process.
+
+    Args:
+        local_tensor (torch.Tensor): The tensor to be gathered from the local process.
+        world_size (int): The number of processes in the distributed group.
+        pg (Optional[ProcessGroup]): The process group to use for communication.
+            If not provided, the default process group is used.
+
+    Returns:
+        List[torch.Tensor]: A list of tensors gathered from all processes.
+    """
+    all_local_tensors: List[torch.Tensor] = []
+    for _ in range(world_size):
+        all_local_tensors.append(torch.empty_like(local_tensor))
+    dist.all_gather(all_local_tensors, local_tensor, pg)
+    return all_local_tensors
+
+
 def create_test_sharder(
     sharder_type: str,
     sharding_type: str,
@@ -558,14 +586,10 @@ def dynamic_sharding_test(
     )

     # TODO: support non-sharded forward with zero batch size KJT
-    all_local_pred_m1 = []
-    for _ in range(world_size):
-        all_local_pred_m1.append(torch.empty_like(local_m1_pred))
-    dist.all_gather(all_local_pred_m1, local_m1_pred, group=ctx.pg)
-    all_local_pred_m2 = []
-    for _ in range(world_size):
-        all_local_pred_m2.append(torch.empty_like(local_m2_pred))
-    dist.all_gather(all_local_pred_m2, local_m2_pred, group=ctx.pg)
+
+    all_local_pred_m1 = _gather_all_tensors(local_m1_pred, world_size, ctx.pg)
+
+    all_local_pred_m2 = _gather_all_tensors(local_m2_pred, world_size, ctx.pg)

     # Compare predictions of sharded vs unsharded models.
     if qcomms_config is None:
@@ -895,10 +919,7 @@ def _custom_hook(input: List[torch.Tensor]) -> None:

     # TODO: support non-sharded forward with zero batch size KJT
     if not allow_zero_batch_size:
-        all_local_pred = []
-        for _ in range(world_size):
-            all_local_pred.append(torch.empty_like(local_pred))
-        dist.all_gather(all_local_pred, local_pred, group=pg)
+        all_local_pred = _gather_all_tensors(local_pred, world_size, pg)

     # Run second training step of the unsharded model.
     assert optim == EmbOptimType.EXACT_SGD
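For reference, the following is a minimal, self-contained sketch (not part of the commit) showing how the new _gather_all_tensors helper behaves. It assumes a CPU-only "gloo" backend, two processes spawned on localhost, and a free rendezvous port 29500; the helper body is copied from the diff above, while _worker and the rendezvous settings are hypothetical test scaffolding.

import os
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _gather_all_tensors(
    local_tensor: torch.Tensor,
    world_size: int,
    pg: Optional[dist.ProcessGroup] = None,
) -> List[torch.Tensor]:
    # Pre-allocate one receive buffer per rank, then all_gather into them
    # (same body as the helper added in this commit).
    all_local_tensors: List[torch.Tensor] = []
    for _ in range(world_size):
        all_local_tensors.append(torch.empty_like(local_tensor))
    dist.all_gather(all_local_tensors, local_tensor, pg)
    return all_local_tensors


def _worker(rank: int, world_size: int) -> None:
    # Hypothetical rendezvous settings for a local two-process test.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    local = torch.full((2,), float(rank))  # each rank contributes distinct data
    gathered = _gather_all_tensors(local, world_size)
    # Every rank ends up with [tensor([0., 0.]), tensor([1., 1.])].
    assert torch.equal(gathered[rank], local)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)

Pre-allocating empty_like buffers is needed because dist.all_gather writes in place into the output list; the helper centralizes that boilerplate so each test site calls a single function instead of repeating the allocation loop.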
