@@ -287,7 +287,8 @@ void gemm_half_q_half_cuda_part
287
287
288
288
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel (true , m_count);
289
289
290
- kernel<<<gridDim , blockDim >>>
290
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
291
+ kernel<<<gridDim , blockDim , 0 , stream>>>
291
292
(
292
293
a,
293
294
b_q_weight,
@@ -434,7 +435,8 @@ void reconstruct_exllama
434
435
gridDim .y = DIVIDE (height, BLOCK_KN_SIZE);
435
436
gridDim .x = DIVIDE (width, BLOCK_KN_SIZE);
436
437
437
- reconstruct_exllama_kernel<<<gridDim , blockDim >>>
438
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
439
+ reconstruct_exllama_kernel<<<gridDim , blockDim , 0 , stream>>>
438
440
(
439
441
b_q_weight,
440
442
b_q_perm,
@@ -567,7 +569,8 @@ void gemm_half_q_half_alt
567
569
gridDim .y = DIVIDE (size_m, BLOCK_M_SIZE_MAX);
568
570
gridDim .z = DIVIDE (size_k, BLOCK_KN_SIZE);
569
571
570
- gemm_half_q_half_alt_kernel<<<gridDim , blockDim >>>
572
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
573
+ gemm_half_q_half_alt_kernel<<<gridDim , blockDim , 0 , stream>>>
571
574
(
572
575
(const half2*) a,
573
576
b_q_weight,
@@ -639,7 +642,8 @@ void reconstruct_gptq
639
642
blockDim .y = 1 ;
640
643
gridDim .y = DIVIDE (height, 8 );
641
644
gridDim .x = DIVIDE (width, BLOCK_KN_SIZE);
642
- reconstruct_gptq_kernel<<<gridDim , blockDim >>>
645
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
646
+ reconstruct_gptq_kernel<<<gridDim , blockDim , 0 , stream>>>
643
647
(
644
648
b_q_weight,
645
649
b_gptq_scales,
@@ -794,7 +798,8 @@ void shuffle_exllama_weight
794
798
gridDim .x = DIVIDE (width, THREADS_X);
795
799
gridDim .y = height / 8 ;
796
800
797
- make_sequential_kernel<<<gridDim , blockDim >>>
801
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
802
+ make_sequential_kernel<<<gridDim , blockDim , 0 , stream>>>
798
803
(
799
804
q_weight,
800
805
new_qweight,
@@ -813,7 +818,8 @@ void shuffle_exllama_weight
813
818
blockDim .y = 1 ;
814
819
gridDim .x = DIVIDE (width, THREADS_X);
815
820
gridDim .y = 1 ;
816
- shuffle_kernel<<<gridDim , blockDim >>> (q_weight, height, width);
821
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
822
+ shuffle_kernel<<<gridDim , blockDim , 0 , stream>>> (q_weight, height, width);
817
823
}
818
824
819
825
} // namespace gptq
0 commit comments