Skip to content

Commit 29678cd

Browse files
authored
Minor fix on AWQ kernel launch (#1356)
1 parent d0740df commit 29678cd

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

csrc/quantization/awq/gemm_kernels.cu

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -534,25 +534,26 @@ torch::Tensor awq_gemm(
534534
if (num_out_channels % group_size != 0)
535535
throw std::invalid_argument("OC is not multiple of Group size");
536536

537+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
537538
if (num_out_channels % 128 == 0)
538539
{
539540
int j_factors1 = num_out_channels / 128 / 1;
540541
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
541542
// threadIdx.x: 32
542543
// threadIdx.y: i_factors[2] * j_factors[2]
543544
dim3 threads_per_block(32, 2);
544-
vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
545+
vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
545546
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
546547
}
547548
else if (num_out_channels % 64 == 0)
548549
{
549-
int j_factors1 = num_out_channels / 64 / 1;
550+
int j_factors1 = num_out_channels / 64 / 1;
550551
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
551552

552553
// threadIdx.x: 32
553554
// threadIdx.y: i_factors[2] * j_factors[2]
554555
dim3 threads_per_block(32, 2);
555-
vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
556+
vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
556557
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
557558
}
558559
return _out_feats.sum(0);

0 commit comments

Comments
 (0)