@@ -534,25 +534,26 @@ torch::Tensor awq_gemm(
534
534
if (num_out_channels % group_size != 0 )
535
535
throw std::invalid_argument (" OC is not multiple of Group size" );
536
536
537
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
537
538
if (num_out_channels % 128 == 0 )
538
539
{
539
540
int j_factors1 = num_out_channels / 128 / 1 ;
540
541
dim3 num_blocks ((num_out_feats + 16 - 1 ) / 16 * j_factors1 * split_k_iters);
541
542
// threadIdx.x: 32
542
543
// threadIdx.y: i_factors[2] * j_factors[2]
543
544
dim3 threads_per_block (32 , 2 );
544
- vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>> (
545
+ vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0 , stream >>> (
545
546
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
546
547
}
547
548
else if (num_out_channels % 64 == 0 )
548
549
{
549
- int j_factors1 = num_out_channels / 64 / 1 ;
550
+ int j_factors1 = num_out_channels / 64 / 1 ;
550
551
dim3 num_blocks (1 * (num_out_feats + 16 - 1 ) / 16 * j_factors1 * split_k_iters);
551
552
552
553
// threadIdx.x: 32
553
554
// threadIdx.y: i_factors[2] * j_factors[2]
554
555
dim3 threads_per_block (32 , 2 );
555
- vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>> (
556
+ vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0 , stream >>> (
556
557
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
557
558
}
558
559
return _out_feats.sum (0 );
0 commit comments