
Commit 6ef00b0

Enable CUDA graph for GPTQ & SqueezeLLM (#2318)
1 parent 9140561 commit 6ef00b0

File tree

csrc/quantization/gptq/q_gemm.cu
csrc/quantization/squeezellm/quant_cuda_kernel.cu
vllm/config.py

3 files changed: +15 −13 lines

csrc/quantization/gptq/q_gemm.cu

Lines changed: 12 additions & 6 deletions
@@ -287,7 +287,8 @@ void gemm_half_q_half_cuda_part
 
     fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
 
-    kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    kernel<<<gridDim, blockDim, 0, stream>>>
     (
         a,
         b_q_weight,
@@ -434,7 +435,8 @@ void reconstruct_exllama
     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
 
-    reconstruct_exllama_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
         b_q_perm,
@@ -567,7 +569,8 @@ void gemm_half_q_half_alt
     gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
     gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
 
-    gemm_half_q_half_alt_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         (const half2*) a,
         b_q_weight,
@@ -639,7 +642,8 @@ void reconstruct_gptq
     blockDim.y = 1;
     gridDim.y = DIVIDE(height, 8);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-    reconstruct_gptq_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
         b_gptq_scales,
@@ -794,7 +798,8 @@ void shuffle_exllama_weight
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = height / 8;
 
-    make_sequential_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         q_weight,
         new_qweight,
@@ -813,7 +818,8 @@ void shuffle_exllama_weight
     blockDim.y = 1;
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = 1;
-    shuffle_kernel<<<gridDim, blockDim>>>(q_weight, height, width);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
 }
 
 } // namespace gptq
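
Every hunk in this file applies the same one-line pattern: fetch the stream PyTorch currently has active via at::cuda::getCurrentCUDAStream() and pass it as the fourth launch parameter, instead of launching on the implicit legacy default stream, so the launch can be recorded during CUDA graph capture. A minimal standalone sketch of that pattern, assuming PyTorch's ATen CUDA headers; the scale_kernel and launch_scale wrapper below are hypothetical, not code from this commit:

// Minimal sketch of the stream-aware launch pattern (hypothetical kernel, not vLLM code).
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

__global__ void scale_kernel(const float* in, float* out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = in[i] * 2.0f;
}

void launch_scale(const float* in, float* out, int n)
{
    dim3 blockDim(256);
    dim3 gridDim((n + blockDim.x - 1) / blockDim.x);

    // Before: scale_kernel<<<gridDim, blockDim>>>(in, out, n);  // legacy default stream
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    scale_kernel<<<gridDim, blockDim, 0, stream>>>(in, out, n);  // PyTorch's current stream
}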

csrc/quantization/squeezellm/quant_cuda_kernel.cu

Lines changed: 3 additions & 1 deletion
@@ -200,8 +200,10 @@ void squeezellm_gemm(
     (width + BLOCKWIDTH - 1) / BLOCKWIDTH
   );
   dim3 threads(BLOCKWIDTH);
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
-  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
 #ifndef USE_ROCM
     (half2*) vec.data<at::Half>(),
 #else

vllm/config.py

Lines changed: 0 additions & 6 deletions
@@ -181,12 +181,6 @@ def _verify_cuda_graph(self) -> None:
181181
self.max_context_len_to_capture = self.max_model_len
182182
self.max_context_len_to_capture = min(self.max_context_len_to_capture,
183183
self.max_model_len)
184-
if (self.quantization in ["gptq", "squeezellm"]
185-
and not self.enforce_eager):
186-
# Related issue: https://github.com/vllm-project/vllm/issues/2147
187-
logger.warning(f"{self.quantization} does not support CUDA graph "
188-
"yet. Disabling CUDA graph.")
189-
self.enforce_eager = True
190184

191185
def verify_with_parallel_config(
192186
self,
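
Dropping this fallback is what actually re-enables CUDA graph for GPTQ and SqueezeLLM models (the fallback was added for https://github.com/vllm-project/vllm/issues/2147). It is safe only because the kernels above now run on the current stream, and CUDA graph capture records only the work submitted to the stream being captured. A toy end-to-end capture sketch using plain CUDA runtime calls (the add_one kernel and the driver are hypothetical illustrations, not vLLM code):

// Toy sketch of CUDA graph stream capture (hypothetical, not vLLM code).
#include <cuda_runtime.h>

__global__ void add_one(float* x) { x[threadIdx.x] += 1.0f; }

int main()
{
    float* d;
    cudaMalloc(&d, 32 * sizeof(float));

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Work enqueued on `stream` between Begin/EndCapture is recorded, not executed.
    cudaGraph_t graph;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    add_one<<<1, 32, 0, stream>>>(d);   // captured as a graph node
    // add_one<<<1, 32>>>(d);           // legacy default stream: would break the capture
    cudaStreamEndCapture(stream, &graph);

    // Instantiate once, then replay the whole graph with a single launch call.
    cudaGraphExec_t graph_exec;
    cudaGraphInstantiateWithFlags(&graph_exec, graph, 0);
    cudaGraphLaunch(graph_exec, stream);
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(graph_exec);
    cudaGraphDestroy(graph);
    cudaStreamDestroy(stream);
    cudaFree(d);
    return 0;
}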

0 commit comments
