Skip to content

Commit 448a738

Browse files
committed
TP: Device-global abort flag for cleaner shutdown in case of timeout
1 parent 33a4f7b commit 448a738

File tree

13 files changed

+110
-83
lines changed

13 files changed

+110
-83
lines changed

exllamav3/exllamav3_ext/parallel/all_reduce.cu

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,14 @@ void pg_all_reduce_kernel
2525
uint8_t* __restrict__ data_ptr,
2626
uint8_t* __restrict__ shbuf_ptr,
2727
size_t data_size,
28-
size_t shbuf_size
28+
size_t shbuf_size,
29+
uint32_t* abort_flag
2930
)
3031
{
3132
int t = threadIdx.x;
3233
auto grid = cg::this_grid();
3334

34-
__shared__ bool timeout;
3535
__shared__ bool r;
36-
timeout = false;
37-
3836
int dir = blockIdx.x;
3937

4038
int num_ranks = __popc(device_mask);
@@ -79,8 +77,8 @@ void pg_all_reduce_kernel
7977
{
8078
__nanosleep(sleep);
8179
if (sleep < SYNC_MAX_SLEEP) sleep <<= 1;
82-
else timeout = check_timeout(ctx, deadline, "all_reduce");
83-
if (timeout) break;
80+
else *abort_flag = check_timeout(ctx, deadline, "all_reduce");
81+
if (*abort_flag) break;
8482
}
8583
}
8684
__syncthreads();
@@ -168,8 +166,8 @@ void pg_all_reduce_kernel
168166
{
169167
__nanosleep(sleep);
170168
if (sleep < SYNC_MAX_SLEEP) sleep <<= 1;
171-
else timeout = check_timeout(ctx, deadline, "all_reduce (1)");
172-
if (timeout) break;
169+
else *abort_flag = check_timeout(ctx, deadline, "all_reduce (1)");
170+
if (*abort_flag) break;
173171
}
174172
}
175173
}
@@ -204,21 +202,21 @@ void pg_all_reduce_kernel
204202
{
205203
__nanosleep(sleep);
206204
if (sleep < SYNC_MAX_SLEEP) sleep <<= 1;
207-
else timeout = check_timeout(ctx, deadline, "all_reduce (2)");
208-
if (timeout) break;
205+
else *abort_flag = check_timeout(ctx, deadline, "all_reduce (2)");
206+
if (*abort_flag) break;
209207
}
210208
}
211209

212210
// Wait for destination to finish receiving
213211
wait_min_stage(ctx->reduce_stage_consumed + dst_rank, stage_end, deadline);
214212
}
215213

216-
if (timeout) break;
214+
if (*abort_flag) break;
217215
grid.sync();
218216
}
219217

220218
// Finished. Reset counters for next kernel
221-
pg_barrier_inner(ctx, device_mask, this_device, master_device);
219+
pg_barrier_inner(ctx, device_mask, this_device, master_device, abort_flag);
222220

223221
if (t == 0)
224222
{
@@ -237,7 +235,8 @@ void pg_all_reduce
237235
int master_device,
238236
at::Tensor& tensor,
239237
uintptr_t shbuf,
240-
size_t shbuf_size
238+
size_t shbuf_size,
239+
at::Tensor& abort_flag
241240
)
242241
{
243242
const at::cuda::OptionalCUDAGuard device_guard(this_device);
@@ -256,6 +255,7 @@ void pg_all_reduce
256255
int threads = (int) CEIL_DIVIDE(CEIL_DIVIDE(data_size / 16ll, num_ranks), 32ll) * 32ll;
257256
threads = MIN(threads, MAX_NUM_THREADS);
258257

258+
uint32_t* abort_flag_ptr = (uint32_t*) abort_flag.data_ptr();
259259
void* kernelArgs[] =
260260
{
261261
(void*)& ctx,
@@ -266,6 +266,7 @@ void pg_all_reduce
266266
(void*)& shbuf_ptr,
267267
(void*)& data_size,
268268
(void*)& shbuf_size,
269+
(void*)& abort_flag_ptr
269270
};
270271

271272
dim3 block_grid(2);

exllamav3/exllamav3_ext/parallel/all_reduce.cuh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ void pg_all_reduce
1010
int master_device,
1111
at::Tensor& tensor,
1212
uintptr_t shbuf,
13-
size_t shbuf_size
13+
size_t shbuf_size,
14+
at::Tensor& abort_flag
1415
);
1516

1617
void pg_all_reduce_cpu
@@ -23,7 +24,8 @@ void pg_all_reduce_cpu
2324
bool contributor,
2425
uintptr_t shbuf,
2526
size_t shbuf_size,
26-
bool is_master
27+
bool is_master,
28+
at::Tensor& abort_flag
2729
);
2830

2931
void run_cpu_reduce_jobs

exllamav3/exllamav3_ext/parallel/all_reduce_cpu.cu

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,6 @@ void perform_cpu_reduce
296296
}
297297
}
298298

299-
__device__ bool pg_all_reduce_cpu_kernel_timeout;
300-
301299
#define PARCK_MODE_FLOAT 0
302300
#define PARCK_MODE_HALF 1
303301
#define PARCK_MODE_BF16 2
@@ -314,7 +312,8 @@ void pg_all_reduce_cpu_kernel
314312
uint8_t* __restrict__ shbuf_ptr,
315313
size_t data_size,
316314
size_t shbuf_size,
317-
bool contributor
315+
bool contributor,
316+
uint32_t* abort_flag
318317
)
319318
{
320319
// Indexing
@@ -335,8 +334,6 @@ void pg_all_reduce_cpu_kernel
335334

336335
int t = threadIdx.x;
337336
int dir = blockIdx.x;
338-
if (t == 0)
339-
pg_all_reduce_cpu_kernel_timeout = false;
340337
auto grid = cg::this_grid();
341338

342339
// Get device stage
@@ -453,14 +450,12 @@ void pg_all_reduce_cpu_kernel
453450
if (sleep < SYNC_MAX_SLEEP) sleep <<= 1;
454451
else if (check_timeout(ctx, deadline, "pg_all_reduce_cpu_kernel"))
455452
{
456-
DBGI2(ep, stage);
457-
to = true;
453+
*abort_flag = 1;
458454
break;
459455
}
460456
}
461457
}
462458
__syncthreads();
463-
if (to) pg_all_reduce_cpu_kernel_timeout = true;
464459

465460
// Recv float
466461
if constexpr (dtype == PARCK_MODE_FLOAT)
@@ -512,7 +507,7 @@ void pg_all_reduce_cpu_kernel
512507
}
513508

514509
grid.sync();
515-
if (pg_all_reduce_cpu_kernel_timeout) break;
510+
if (*abort_flag) break;
516511
}
517512
}
518513

@@ -526,7 +521,8 @@ void pg_all_reduce_cpu
526521
bool contributor,
527522
uintptr_t shbuf,
528523
size_t shbuf_size,
529-
bool is_master
524+
bool is_master,
525+
at::Tensor& abort_flag
530526
)
531527
{
532528
const at::cuda::OptionalCUDAGuard device_guard(this_device);
@@ -543,6 +539,7 @@ void pg_all_reduce_cpu
543539

544540
TORCH_CHECK(cpu_data_size % 16 == 0, "data_size must be multiple of 16");
545541

542+
uint32_t* abort_flag_ptr = (uint32_t*) abort_flag.data_ptr();
546543
void* kernelArgs[] =
547544
{
548545
(void*)& ctx,
@@ -553,7 +550,8 @@ void pg_all_reduce_cpu
553550
(void*)& shbuf_ptr,
554551
(void*)& device_data_size,
555552
(void*)& shbuf_size,
556-
(void*)& contributor
553+
(void*)& contributor,
554+
(void*)& abort_flag_ptr
557555
};
558556

559557
dim3 block_grid(2);

exllamav3/exllamav3_ext/parallel/barrier.cu

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,19 @@ __global__ void pg_barrier_kernel
1414
PGContext* __restrict__ ctx,
1515
uint32_t device_mask,
1616
int this_device,
17-
int coordinator_device
17+
int coordinator_device,
18+
uint32_t* abort_flag
1819
)
1920
{
20-
pg_barrier_inner(ctx, device_mask, this_device, coordinator_device);
21+
pg_barrier_inner(ctx, device_mask, this_device, coordinator_device, abort_flag);
2122
}
2223

23-
2424
void pg_barrier
2525
(
2626
uintptr_t ctx,
2727
std::vector<uintptr_t> devices,
28-
int this_device
28+
int this_device,
29+
at::Tensor& abort_flag
2930
)
3031
{
3132
const at::cuda::OptionalCUDAGuard device_guard(this_device);
@@ -40,7 +41,8 @@ void pg_barrier
4041
(PGContext*) ctx, // Shared, pinned
4142
device_mask,
4243
this_device,
43-
devices[0]
44+
devices[0],
45+
(uint32_t*) abort_flag.data_ptr()
4446
);
4547
cuda_check(cudaPeekAtLastError());
4648
}

exllamav3/exllamav3_ext/parallel/barrier.cuh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ void pg_barrier
77
(
88
uintptr_t ctx,
99
std::vector<uintptr_t> devices,
10-
int this_device
10+
int this_device,
11+
at::Tensor& abort_flag
1112
);

exllamav3/exllamav3_ext/parallel/barrier_inner.cuh

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@ __device__ __forceinline__ void pg_barrier_inner
44
PGContext* __restrict__ ctx,
55
uint32_t device_mask,
66
int this_device,
7-
int coordinator_device
7+
int coordinator_device,
8+
uint32_t* abort_flag
89
)
910
{
10-
bool timeout = false;
11-
1211
if (!blockIdx.x && !blockIdx.y && !blockIdx.z && !threadIdx.x && !threadIdx.y && !threadIdx.z)
1312
{
1413
uint32_t* epoch_ptr = &ctx->barrier_epoch;
@@ -47,8 +46,8 @@ __device__ __forceinline__ void pg_barrier_inner
4746
{
4847
__nanosleep(sleep);
4948
if (sleep < SYNC_MAX_SLEEP) sleep <<= 1;
50-
else timeout = check_timeout(ctx, deadline, "barrier");
51-
if (timeout) break;
49+
else *abort_flag = check_timeout(ctx, deadline, "barrier");
50+
if (*abort_flag) break;
5251
}
5352
else sleep = SYNC_MIN_SLEEP;
5453
}
@@ -66,8 +65,8 @@ __device__ __forceinline__ void pg_barrier_inner
6665
{
6766
__nanosleep(sleep);
6867
if (sleep < SYNC_MAX_SLEEP) sleep <<= 1;
69-
else timeout = check_timeout(ctx, deadline, "barrier");
70-
if (timeout) break;
68+
else *abort_flag = check_timeout(ctx, deadline, "barrier");
69+
if (*abort_flag) break;
7170
}
7271
}
7372
}

0 commit comments

Comments
 (0)