
Commit 99028fd

Fix INT8 quantization error on Blackwell GPUs (SM100+) (#25935)

Signed-off-by: padg9912 <[email protected]>
1 parent 1244948 commit 99028fd

File tree: 2 files changed, +9 −2 lines changed

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp (4 additions, 1 deletion)

```diff
@@ -25,7 +25,10 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
     if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
       int8_func(c, a, b, a_scales, b_scales, bias);
     } else {
-      TORCH_CHECK(false, "Int8 not supported for this architecture");
+      int32_t version_num = get_sm_version_num();
+      TORCH_CHECK(
+          false, "Int8 not supported on SM", version_num,
+          ". Use FP8 quantization instead, or run on older arch (SM < 100).");
     }
   }
 } else {
```
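The hunk above replaces a generic failure with a message that reports the detected SM version and suggests FP8. A minimal standalone sketch of the same dispatch-or-fail logic follows; the function name `dispatch_int8` and the `sm_version`/`has_int8_kernel` parameters are illustrative stand-ins, not vLLM's actual CUTLASS plumbing (in vLLM the SM version comes from `get_sm_version_num()` and the "kernel available" condition is the compile-time `Int8Func` check):

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical sketch: pick the INT8 GEMM when a kernel exists for this
// architecture, otherwise fail with the same guidance the commit adds.
std::string dispatch_int8(int32_t sm_version, bool has_int8_kernel) {
  if (has_int8_kernel && sm_version < 100) {
    return "int8_gemm";  // dispatch to the INT8 CUTLASS kernel
  }
  // Mirrors the commit's TORCH_CHECK message: report the SM version and
  // point the user at FP8 as the supported alternative on SM100+.
  throw std::runtime_error(
      "Int8 not supported on SM" + std::to_string(sm_version) +
      ". Use FP8 quantization instead, or run on older arch (SM < 100).");
}
```

On Hopper (SM90) this selects the INT8 path; on Blackwell (SM100+) it raises the new, more actionable error instead of the old generic one.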

docs/features/quantization/int8.md (5 additions, 1 deletion)

```diff
@@ -6,7 +6,11 @@ This quantization method is particularly useful for reducing model size while ma
 Please visit the HF collection of [quantized INT8 checkpoints of popular LLMs ready to use with vLLM](https://huggingface.co/collections/neuralmagic/int8-llms-for-vllm-668ec32c049dca0369816415).
 
 !!! note
-    INT8 computation is supported on NVIDIA GPUs with compute capability > 7.5 (Turing, Ampere, Ada Lovelace, Hopper, Blackwell).
+    INT8 computation is supported on NVIDIA GPUs with compute capability > 7.5 (Turing, Ampere, Ada Lovelace, Hopper).
+
+!!! warning
+    **Blackwell GPU Limitation**: INT8 is not supported on compute capability >= 100 (e.g., RTX 6000 Blackwell).
+    Use [FP8 quantization](fp8.md) instead, or run on Hopper/Ada/Ampere architectures.
 
 ## Prerequisites
```
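The docs change draws the supported window as compute capability above 7.5 (Turing through Hopper) but below SM100 (Blackwell). A small Python sketch of that check, as one might run it before choosing a quantization scheme; `int8_supported` is a hypothetical helper, not part of vLLM's API, and in practice the `(major, minor)` pair would come from something like `torch.cuda.get_device_capability()`:

```python
def int8_supported(major: int, minor: int) -> bool:
    """Return True if vLLM's INT8 CUTLASS kernels cover this GPU.

    Per the docs above: Turing (7.5) through Hopper (9.0) are supported;
    Blackwell (SM100 and newer) is not, so FP8 should be used there.
    """
    cap = major * 10 + minor  # e.g. (8, 0) -> 80, (10, 0) -> 100
    return 75 <= cap < 100
```

For example, an A100 (capability 8.0) passes the check, while a Blackwell GPU (capability 10.0 or 12.0) does not and should fall back to FP8.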
