@@ -70,27 +70,45 @@ ccl::event allgatherv_large(const void* send_buf,
7070 dep.wait ();
7171 }
7272 }
73- std::vector<void *> ptrs{ (void *)send_buf, recv_buf }; // index 0 and 1
74- auto [sched, exchange_entry] = do_ipc_exchange (comm, global_stream, ptrs);
7573
76- sycl_ptrs.xelink_ptrs_rd = get_ipc_ptrs<void , MAX_GPUS>(even_comm, 0 , (void *)send_buf, sched);
77- sycl_ptrs.xelink_ptrs_wr = get_ipc_ptrs<void , MAX_GPUS>(even_comm, 1 , recv_buf, sched);
78- // use full vector (>= 8 bytes) if remote buffers and data size are 4 byte aligned
79- use_full_vector = use_full_vector &&
80- all_aligned (sycl_ptrs.xelink_ptrs_rd .data (), even_comm->size (), send_count * dsize, 4 ) &&
81- all_aligned (sycl_ptrs.xelink_ptrs_wr .data (), even_comm->size (), send_count * dsize, 4 );
74+ if (is_arc_card (ccl::ze::get_device_family (global_stream->get_ze_device ()))) {
75+ // only need output buffer
76+ std::vector<void *> ptrs{ recv_buf }; // index 0
77+ auto [sched, exchange_entry] = do_ipc_exchange (comm, global_stream, ptrs);
8278
83- if (pair_comm->size () > 1 ) {
84- assert (pair_comm->size () == MAX_TILES);
85- int peer_pair_rank = pair_comm->rank () ? 0 : 1 ;
86- sycl_ptrs.mdfi_ptr_rd =
87- get_ipc_ptrs<void , MAX_TILES>(pair_comm, 0 , (void *)send_buf, sched)[peer_pair_rank];
88- sycl_ptrs.mdfi_ptr_wr = get_ipc_ptrs<void , MAX_TILES>(pair_comm, 1 , recv_buf, sched)[peer_pair_rank];
89- use_full_vector = use_full_vector && all_aligned (&sycl_ptrs.mdfi_ptr_rd , 1 , send_count * dsize, 4 ) &&
90- all_aligned (&sycl_ptrs.mdfi_ptr_wr , 1 , send_count * dsize, 4 );
79+ std::shared_ptr<ccl_comm> node_comm = comm->get_node_comm ();
80+ sycl_ptrs.node_ptrs_wr = get_ipc_ptrs<void , MAX_NODE_RANKS>(node_comm, 0 , recv_buf, sched);
81+
82+ delete exchange_entry;
83+ delete sched;
84+ }
85+ else {
86+ std::vector<void *> ptrs{ (void *)send_buf, recv_buf }; // index 0 and 1
87+ auto [sched, exchange_entry] = do_ipc_exchange (comm, global_stream, ptrs);
88+
89+ sycl_ptrs.xelink_ptrs_rd = get_ipc_ptrs<void , MAX_GPUS>(even_comm, 0 , (void *)send_buf, sched);
90+ sycl_ptrs.xelink_ptrs_wr = get_ipc_ptrs<void , MAX_GPUS>(even_comm, 1 , recv_buf, sched);
91+ // use full vector (>= 8 bytes) if remote buffers and data size are 4 byte aligned
92+ use_full_vector =
93+ use_full_vector &&
94+ all_aligned (sycl_ptrs.xelink_ptrs_rd .data (), even_comm->size (), send_count * dsize, 4 ) &&
95+ all_aligned (sycl_ptrs.xelink_ptrs_wr .data (), even_comm->size (), send_count * dsize, 4 );
96+
97+ if (pair_comm->size () > 1 ) {
98+ assert (pair_comm->size () == MAX_TILES);
99+ int peer_pair_rank = pair_comm->rank () ? 0 : 1 ;
100+ sycl_ptrs.mdfi_ptr_rd =
101+ get_ipc_ptrs<void , MAX_TILES>(pair_comm, 0 , (void *)send_buf, sched)[peer_pair_rank];
102+ sycl_ptrs.mdfi_ptr_wr =
103+ get_ipc_ptrs<void , MAX_TILES>(pair_comm, 1 , recv_buf, sched)[peer_pair_rank];
104+ use_full_vector = use_full_vector &&
105+ all_aligned (&sycl_ptrs.mdfi_ptr_rd , 1 , send_count * dsize, 4 ) &&
106+ all_aligned (&sycl_ptrs.mdfi_ptr_wr , 1 , send_count * dsize, 4 );
107+ }
108+
109+ delete exchange_entry;
110+ delete sched;
91111 }
92- delete exchange_entry;
93- delete sched;
94112
95113 // coll_init(comm, global_stream);
96114 }
0 commit comments