@@ -25,16 +25,14 @@ void pg_all_reduce_kernel
     uint8_t* __restrict__ data_ptr,
     uint8_t* __restrict__ shbuf_ptr,
     size_t data_size,
-    size_t shbuf_size
+    size_t shbuf_size,
+    uint32_t* abort_flag
 )
 {
     int t = threadIdx.x;
     auto grid = cg::this_grid();
 
-    __shared__ bool timeout;
     __shared__ bool r;
-    timeout = false;
-
     int dir = blockIdx.x;
 
     int num_ranks = __popc(device_mask);
@@ -79,8 +77,8 @@ void pg_all_reduce_kernel
         {
             __nanosleep(sleep);
             if (sleep < SYNC_MAX_SLEEP) sleep <<= 1;
-            else timeout = check_timeout(ctx, deadline, "all_reduce");
-            if (timeout) break;
+            else *abort_flag = check_timeout(ctx, deadline, "all_reduce");
+            if (*abort_flag) break;
         }
     }
     __syncthreads();
@@ -168,8 +166,8 @@ void pg_all_reduce_kernel
             {
                 __nanosleep(sleep);
                 if (sleep < SYNC_MAX_SLEEP) sleep <<= 1;
-                else timeout = check_timeout(ctx, deadline, "all_reduce (1)");
-                if (timeout) break;
+                else *abort_flag = check_timeout(ctx, deadline, "all_reduce (1)");
+                if (*abort_flag) break;
             }
         }
     }
@@ -204,21 +202,21 @@ void pg_all_reduce_kernel
                 {
                     __nanosleep(sleep);
                     if (sleep < SYNC_MAX_SLEEP) sleep <<= 1;
-                    else timeout = check_timeout(ctx, deadline, "all_reduce (2)");
-                    if (timeout) break;
+                    else *abort_flag = check_timeout(ctx, deadline, "all_reduce (2)");
+                    if (*abort_flag) break;
                 }
             }
 
             // Wait for destination to finish receiving
             wait_min_stage(ctx->reduce_stage_consumed + dst_rank, stage_end, deadline);
         }
 
-        if (timeout) break;
+        if (*abort_flag) break;
         grid.sync();
     }
 
     // Finished. Reset counters for next kernel
-    pg_barrier_inner(ctx, device_mask, this_device, master_device);
+    pg_barrier_inner(ctx, device_mask, this_device, master_device, abort_flag);
 
     if (t == 0)
     {
@@ -237,7 +235,8 @@ void pg_all_reduce
     int master_device,
     at::Tensor& tensor,
     uintptr_t shbuf,
-    size_t shbuf_size
+    size_t shbuf_size,
+    at::Tensor& abort_flag
 )
 {
     const at::cuda::OptionalCUDAGuard device_guard(this_device);
@@ -256,6 +255,7 @@ void pg_all_reduce
     int threads = (int) CEIL_DIVIDE(CEIL_DIVIDE(data_size / 16ll, num_ranks), 32ll) * 32ll;
     threads = MIN(threads, MAX_NUM_THREADS);
 
+    uint32_t* abort_flag_ptr = (uint32_t*) abort_flag.data_ptr();
     void* kernelArgs[] =
     {
         (void*)& ctx,
@@ -266,6 +266,7 @@ void pg_all_reduce
         (void*)& shbuf_ptr,
         (void*)& data_size,
         (void*)& shbuf_size,
+        (void*)& abort_flag_ptr
     };
 
     dim3 block_grid(2);
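
The change swaps the per-block `__shared__ bool timeout` for an `abort_flag` backed by a tensor in global memory, so the timeout state outlives the block and is shared with the host-side binding. A minimal caller-side sketch, assuming the flag is a one-element int32 CUDA tensor zeroed before launch; the `make_abort_flag` helper name is hypothetical and not part of this commit:

    #include <ATen/ATen.h>

    // Hypothetical helper: allocate the one-element device tensor that backs the
    // kernel's abort flag. pg_all_reduce casts its data_ptr() to uint32_t*.
    at::Tensor make_abort_flag(int device_index)
    {
        return at::zeros({1}, at::TensorOptions()
            .dtype(at::kInt)
            .device(at::Device(at::kCUDA, device_index)));
    }

Zeroing the tensor before each launch matters: the kernel now only writes the flag when a timeout is detected and no longer clears it itself, since the old `timeout = false;` initialization was removed.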