1919#include " coll/algorithms/utils/sycl_kernels.hpp"
2020#include " coll/algorithms/utils/sycl_coll_base.hpp"
2121
22+ // Kernel name templates for allgatherv_large
23+ template <typename T, int vec_size, int GPUS>
24+ class oneccl_allgatherv_large_ipc {};
25+
26+ template <typename T, int vec_size, int GPUS>
27+ class oneccl_allgatherv_large_main {};
28+
29+ template <typename T, int vec_size, int GPUS>
30+ class oneccl_allgatherv_large_epilogue {};
31+
2232template <typename T>
2333ccl::event allgatherv_large_impl_ipc_ce (sycl::queue& q,
2434 const void * send_buf,
2535 size_t send_count,
2636 void * recv_buf,
2737 const ccl::vector_class<size_t >& recv_counts,
28- size_t orig_count,
29- size_t offset,
38+ const ccl::vector_class<size_t >& offsets,
3039 ccl::datatype dtype,
3140 ccl_comm* comm,
3241 ccl_stream* global_stream,
@@ -80,7 +89,7 @@ ccl::event allgatherv_large_impl_ipc_ce(sycl::queue& q,
8089 int r = (i + even_comm->rank ()) % even_comm->size ();
8190 // TODO: make sure that get_node_rank() (or get_global_rank()) return the ABSOLUTE (i.e. MPI_COMM_WORLD) rank in the node
8291 const int global_rank = even_comm->get_node_rank (r);
83- const size_t offset_bytes = offset + orig_count * global_rank * dsize;
92+ const size_t offset_bytes = !offsets. empty () ? offsets[global_rank] : send_count * global_rank * dsize;
8493
8594 void * src = (char *)sycl_ptrs.xelink_ptrs_rd [r];
8695 void * local = (char *)recv_buf + offset_bytes;
@@ -108,7 +117,7 @@ ccl::event allgatherv_large_impl_ipc_ce(sycl::queue& q,
108117 else {
109118 LOG_DEBUG (" allgatherv large copy engine write" );
110119 const int my_global_rank = node_comm->rank ();
111- const size_t my_offset_bytes = orig_count * my_global_rank * dsize;
120+ const size_t my_offset_bytes = send_count * my_global_rank * dsize;
112121
113122 // TODO: can we delete this barrier
114123 sycl::event barrier_event0 = invoke_barrier (node_comm, q, dep_events, is_cpu_barrier);
@@ -129,7 +138,8 @@ ccl::event allgatherv_large_impl_ipc_ce(sycl::queue& q,
129138 std::vector<sycl::event> cp_events2 (even_comm->size ());
130139 for (int i = 0 ; i < even_comm->size (); i++) {
131140 const int global_rank = even_comm->get_node_rank (i);
132- const size_t offset_bytes = offset + orig_count * global_rank * dsize;
141+ const size_t offset_bytes =
142+ !offsets.empty () ? offsets[global_rank] : send_count * global_rank * dsize;
133143
134144 void * src = (char *)recv_buf + offset_bytes;
135145 void * dst = (char *)sycl_ptrs.mdfi_ptr_wr + offset_bytes;
@@ -154,8 +164,7 @@ ccl::event allgatherv_large_impl_ipc(sycl::queue& q,
154164 size_t send_count,
155165 void * recv_buf,
156166 const ccl::vector_class<size_t >& recv_counts,
157- size_t orig_count,
158- size_t offset,
167+ const ccl::vector_class<size_t >& offsets,
159168 ccl::datatype dtype,
160169 ccl_comm* comm,
161170 ccl_stream* global_stream,
@@ -177,7 +186,7 @@ ccl::event allgatherv_large_impl_ipc(sycl::queue& q,
177186 for (int i = 0 ; i < even_comm->size (); i++) {
178187 // offsets for read_write kernel
179188 const int global_rank = even_comm->get_node_rank (i);
180- const size_t offset_bytes = offset + orig_count * global_rank * dsize;
189+ const size_t offset_bytes = !offsets. empty () ? offsets[global_rank] : send_count * global_rank * dsize;
181190 local_peer_even_ptrs[i] = (char *)sycl_ptrs.xelink_ptrs_rd [i];
182191 local_local_ptrs[i] = (char *)recv_buf + offset_bytes;
183192 local_peer_pair_ptrs[i] = (char *)sycl_ptrs.mdfi_ptr_wr + offset_bytes;
@@ -194,7 +203,7 @@ ccl::event allgatherv_large_impl_ipc(sycl::queue& q,
194203
195204 sycl::event kernel_event = q.submit ([=](sycl::handler& h) {
196205 h.depends_on (barrier_event1);
197- h.parallel_for (
206+ h.parallel_for <oneccl_allgatherv_large_ipc<T, vec_size, N>> (
198207 sycl::nd_range<1 >(kernel_size, work_group_size),
199208 [=](sycl::nd_item<1 > it) [[sycl::reqd_sub_group_size (work_group_size)]] {
200209 read_write<T, N, vec_size>(
@@ -212,8 +221,7 @@ ccl::event allgatherv_large_impl_tmp(sycl::queue& q,
212221 size_t send_count,
213222 void * recv_buf,
214223 const ccl::vector_class<size_t >& recv_counts,
215- size_t orig_count,
216- size_t offset,
224+ const ccl::vector_class<size_t >& offsets,
217225 ccl::datatype dtype,
218226 ccl_comm* comm,
219227 ccl_stream* global_stream,
@@ -289,7 +297,8 @@ ccl::event allgatherv_large_impl_tmp(sycl::queue& q,
289297 // offsets for read_write kernel
290298 int global_rank = comm->is_multi_thread_instance () ? i * pair_comm->size () + pair_comm->rank ()
291299 : even_comm->get_node_rank (i);
292- const size_t offset_bytes = offset + (orig_count * global_rank + chunk_count * nc) * dsize;
300+ const size_t offset_bytes = !offsets.empty () ? offsets[global_rank] + chunk_count * nc * dsize
301+ : (send_count * global_rank + chunk_count * nc) * dsize;
293302 const size_t offset_bytes_tmp = chunk_count * global_rank * dsize;
294303
295304 // xelink and mdfi ptrs are the tmp buffers in the other ranks
@@ -303,7 +312,9 @@ ccl::event allgatherv_large_impl_tmp(sycl::queue& q,
303312 if (global_rank % pair_comm_size == 0 ) {
304313 global_rank_neighbor = global_rank_neighbor + 1 ;
305314 }
306- const size_t offset_bytes_c = offset + (orig_count * global_rank_neighbor + chunk_count * nc) * dsize;
315+ const size_t offset_bytes_c = !offsets.empty ()
316+ ? offsets[global_rank_neighbor] + chunk_count * nc * dsize
317+ : (send_count * global_rank_neighbor + chunk_count * nc) * dsize;
307318 const size_t offset_bytes_c_tmp = chunk_count * global_rank_neighbor * dsize;
308319 recv_buf_dst_ptrs[i] = (char *)recv_buf + offset_bytes_c;
309320 tmp_buf_src_ptrs[i] = (char *)tmp_buf_use + offset_bytes_c_tmp;
@@ -351,7 +362,7 @@ ccl::event allgatherv_large_impl_tmp(sycl::queue& q,
351362
352363 sycl::event kernel_event = q.submit ([=](sycl::handler& h) {
353364 h.depends_on (barrier_event1);
354- h.parallel_for (
365+ h.parallel_for <oneccl_allgatherv_large_main<T, vec_size, N>> (
355366 sycl::nd_range<1 >(kernel_size, work_group_size),
356367 [=](sycl::nd_item<1 > it) [[sycl::reqd_sub_group_size (work_group_size)]] {
357368 read_write<T, N, vec_size>(local_peer_even_ptrs,
@@ -433,7 +444,7 @@ ccl::event allgatherv_large_impl_tmp(sycl::queue& q,
433444 const size_t kernel_threads = data_count / vec_size + data_count % vec_size;
434445 const size_t kernel_size =
435446 ((kernel_threads + work_group_size - 1 ) / work_group_size) * work_group_size;
436- h.parallel_for (
447+ h.parallel_for <oneccl_allgatherv_large_epilogue<T, vec_size, N>> (
437448 sycl::nd_range<1 >(kernel_size, work_group_size),
438449 [=](sycl::nd_item<1 > it) [[sycl::reqd_sub_group_size (work_group_size)]] {
439450 copy_data<T, N, vec_size>(recv_buf_dst_ptrs, tmp_buf_src_ptrs, data_count, it);
@@ -476,8 +487,7 @@ ccl::event allgatherv_large_impl(sycl::queue& q,
476487 size_t send_count,
477488 void * recv_buf,
478489 const ccl::vector_class<size_t >& recv_counts,
479- size_t orig_count,
480- size_t offset,
490+ const ccl::vector_class<size_t >& offsets,
481491 ccl::datatype dtype,
482492 ccl_comm* comm,
483493 ccl_stream* global_stream,
@@ -496,46 +506,16 @@ ccl::event allgatherv_large_impl(sycl::queue& q,
496506 ccl::event e;
497507 // TODO: copy engines currently does not support tmp buf
498508 if (ccl::global_data::env ().sycl_copy_engine ) {
499- e = allgatherv_large_impl_ipc_ce<T>(q,
500- send_buf,
501- send_count,
502- recv_buf,
503- recv_counts,
504- orig_count,
505- offset,
506- dtype,
507- comm,
508- global_stream,
509- sycl_ptrs,
510- deps);
509+ e = allgatherv_large_impl_ipc_ce<T>(
510+ q, send_buf, send_count, recv_buf, recv_counts, offsets, dtype, comm, global_stream, sycl_ptrs, deps);
511511 }
512512 else if (!is_tmp_used) {
513- e = allgatherv_large_impl_ipc<T, N, vec_size_use>(q,
514- send_buf,
515- send_count,
516- recv_buf,
517- recv_counts,
518- orig_count,
519- offset,
520- dtype,
521- comm,
522- global_stream,
523- sycl_ptrs,
524- deps);
513+ e = allgatherv_large_impl_ipc<T, N, vec_size_use>(
514+ q, send_buf, send_count, recv_buf, recv_counts, offsets, dtype, comm, global_stream, sycl_ptrs, deps);
525515 }
526516 else {
527- e = allgatherv_large_impl_tmp<T, N, vec_size_use>(q,
528- send_buf,
529- send_count,
530- recv_buf,
531- recv_counts,
532- orig_count,
533- offset,
534- dtype,
535- comm,
536- global_stream,
537- sycl_ptrs,
538- deps);
517+ e = allgatherv_large_impl_tmp<T, N, vec_size_use>(
518+ q, send_buf, send_count, recv_buf, recv_counts, offsets, dtype, comm, global_stream, sycl_ptrs, deps);
539519 }
540520 return e;
541521}
0 commit comments