@@ -77,12 +77,12 @@ torch::Tensor gather(const torch::Tensor& input,
   if (!process_group) {
     return input;
   }
-  const auto world_size = process_group->world_size();
+  const int32_t world_size = process_group->world_size();
   if (world_size == 1) {
     return input;
   }
 
-  const auto rank = process_group->rank();
+  const int32_t rank = process_group->rank();
   std::vector<torch::Tensor> tensors(world_size);
   for (int64_t i = 0; i < world_size; ++i) {
     tensors[i] = torch::empty_like(input);
@@ -98,8 +98,8 @@ torch::Tensor gather(const torch::Tensor& input,
   if (!process_group) {
     return input;
   }
-  const auto world_size = process_group->world_size();
-  const auto rank = process_group->rank();
+  const int32_t world_size = process_group->world_size();
+  const int32_t rank = process_group->rank();
   if (world_size == 1) {
     return input;
   }
@@ -131,11 +131,50 @@ torch::Tensor gather(const torch::Tensor& input,
       gathered_input, max_num_tokens, token_num_list);
 }
 
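+// All-gathers `input` from every rank, then reorders the result so that
+// chunk i from each rank sits next to chunk i from every other rank along
+// the last dimension, instead of simply concatenating whole per-rank tensors.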
+torch::Tensor all_gather_interleaved(const torch::Tensor& input,
+                                     ProcessGroup* process_group) {
+  if (!process_group) {
+    return input;
+  }
+  const int32_t world_size = process_group->world_size();
+  const int32_t rank = process_group->rank();
+  if (world_size == 1) {
+    return input;
+  }
+
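+  // Collect every rank's local tensor into gathered_tensors.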
+  std::vector<torch::Tensor> gathered_tensors(world_size);
+  for (int64_t i = 0; i < world_size; ++i) {
+    gathered_tensors[i] = torch::empty_like(input);
+  }
+  process_group->allgather(input, gathered_tensors);
+
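+  // Reorder the gathered shards: split each rank's tensor into num_chunks
+  // equal slices along the last dim, then emit slice 0 from every rank,
+  // followed by slice 1 from every rank, and so on. num_chunks is fixed
+  // at 3 here; input.size(dim) is assumed to divide evenly by it.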
+  int32_t dim = -1;
+  size_t num_chunks = 3;
+  std::vector<torch::Tensor> ordered_tensors;
+  int64_t shard_size = input.size(dim) / num_chunks;
+  for (size_t i = 0; i < num_chunks; ++i) {
+    for (size_t j = 0; j < world_size; ++j) {
+      auto shard_tensor =
+          gathered_tensors[j].slice(dim, shard_size * i, shard_size * (i + 1));
+      ordered_tensors.push_back(shard_tensor);
+    }
+  }
+  return torch::cat(ordered_tensors, dim).contiguous();
+}
+
 torch::Tensor reduce(torch::Tensor& input, ProcessGroup* process_group) {
   if (!process_group) {
     return input;
   }
-  const auto world_size = process_group->world_size();
+  const int32_t world_size = process_group->world_size();
   if (world_size == 1) {
     return input;
   }
@@ -149,20 +188,20 @@ torch::Tensor scatter(torch::Tensor input,
   if (!process_group) {
     return input;
   }
-  const auto world_size = process_group->world_size();
+  const int32_t world_size = process_group->world_size();
   if (world_size == 1) {
     return input;
   }
 
   // get the size for last dimension
-  const auto dim_size = input.size(dim);
+  const int32_t dim_size = input.size(dim);
   CHECK(dim_size % world_size == 0)
       << "dim_size " << dim_size << " cannot be divided by world_size "
       << world_size;
 
   // torch::split does not create contiguous tensors by default.
   const auto tensor_list = input.split(dim_size / world_size, dim);
-  const auto rank = process_group->rank();
+  const int32_t rank = process_group->rank();
   return tensor_list[rank];
 }