@@ -76,6 +76,34 @@ class SharderType(Enum):
     EMBEDDING_COLLECTION = "embedding_collection"
 
 
+def _gather_all_tensors(
+    local_tensor: torch.Tensor,
+    world_size: int,
+    pg: Optional[dist.ProcessGroup] = None,
+) -> List[torch.Tensor]:
+ """
85
+ Gathers tensors from all processes in a distributed group.
86
+
87
+ This function collects tensors from all processes in the specified
88
+ process group and returns a list of tensors, where each tensor
89
+ corresponds to the data from one process.
90
+
91
+ Args:
92
+ local_tensor (torch.Tensor): The tensor to be gathered from the local process.
93
+ world_size (int): The number of processes in the distributed group.
94
+ pg (Optional[ProcessGroup]): The process group to use for communication.
95
+ If not provided, a default ProcessGroup will be created.
96
+
97
+ Returns:
98
+ List[torch.Tensor]: A list of tensors gathered from all processes.
99
+ """
100
+    all_local_tensors: List[torch.Tensor] = []
+    for _ in range(world_size):
+        all_local_tensors.append(torch.empty_like(local_tensor))
+    dist.all_gather(all_local_tensors, local_tensor, pg)
+    return all_local_tensors
+
+
 def create_test_sharder(
     sharder_type: str,
     sharding_type: str,
@@ -558,14 +586,10 @@ def dynamic_sharding_test(
     )
 
     # TODO: support non-sharded forward with zero batch size KJT
-    all_local_pred_m1 = []
-    for _ in range(world_size):
-        all_local_pred_m1.append(torch.empty_like(local_m1_pred))
-    dist.all_gather(all_local_pred_m1, local_m1_pred, group=ctx.pg)
-    all_local_pred_m2 = []
-    for _ in range(world_size):
-        all_local_pred_m2.append(torch.empty_like(local_m2_pred))
-    dist.all_gather(all_local_pred_m2, local_m2_pred, group=ctx.pg)
+
+    all_local_pred_m1 = _gather_all_tensors(local_m1_pred, world_size, ctx.pg)
+
+    all_local_pred_m2 = _gather_all_tensors(local_m2_pred, world_size, ctx.pg)
 
     # Compare predictions of sharded vs unsharded models.
     if qcomms_config is None:
@@ -895,10 +919,7 @@ def _custom_hook(input: List[torch.Tensor]) -> None:
 
     # TODO: support non-sharded forward with zero batch size KJT
     if not allow_zero_batch_size:
-        all_local_pred = []
-        for _ in range(world_size):
-            all_local_pred.append(torch.empty_like(local_pred))
-        dist.all_gather(all_local_pred, local_pred, group=pg)
+        all_local_pred = _gather_all_tensors(local_pred, world_size, pg)
 
     # Run second training step of the unsharded model.
     assert optim == EmbOptimType.EXACT_SGD
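
For reference, a minimal single-process sketch (not part of this commit) of how the new `_gather_all_tensors` helper is called. The gloo backend, master address/port, single-rank process group, and the commented import path are illustrative assumptions; in the actual tests the helper runs inside an already-initialized multi-rank group.

```python
import os

import torch
import torch.distributed as dist

# from torchrec.distributed.test_utils.test_sharding import _gather_all_tensors  # hypothetical import path

# Single-process sketch: with world_size == 1 the gathered list simply holds a
# copy of the local tensor, but the call pattern matches the multi-rank tests.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

local_pred = torch.arange(4, dtype=torch.float32)  # stand-in for a rank's local predictions
gathered = _gather_all_tensors(local_pred, dist.get_world_size())  # pg omitted -> default group

assert len(gathered) == dist.get_world_size()
assert torch.equal(gathered[0], local_pred)

dist.destroy_process_group()
```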