Commit 774d0c0

[Perf] Cuda Kernel for Per Token Group Quant (#21083)
Signed-off-by: yewentao256 <[email protected]>
1 parent 2c8db17 commit 774d0c0

File tree

6 files changed: +285 additions, -4 deletions

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -245,6 +245,7 @@ set(VLLM_EXT_SRC
   "csrc/quantization/gptq/q_gemm.cu"
   "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
   "csrc/quantization/fp8/common.cu"
+  "csrc/quantization/fp8/per_token_group_quant.cu"
   "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
   "csrc/quantization/gguf/gguf_kernel.cu"
   "csrc/quantization/activation_kernels.cu"

csrc/ops.h

Lines changed: 5 additions & 0 deletions
@@ -297,6 +297,11 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                                torch::Tensor& scales,
                                std::optional<torch::Tensor> const& azp);

+void per_token_group_quant_fp8(const torch::Tensor& input,
+                               torch::Tensor& output_q, torch::Tensor& output_s,
+                               int64_t group_size, double eps, double fp8_min,
+                               double fp8_max, bool scale_ue8m0);
+
 torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                         torch::Tensor b_gptq_qzeros,
                         torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
csrc/quantization/fp8/per_token_group_quant.cu (new file)

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>

#include <cmath>

#include <cuda_fp16.h>
#include <cuda_bf16.h>

#include <torch/all.h>

#include "../vectorization.cuh"
#include "../vectorization_utils.cuh"
#include "../../dispatch_utils.h"

__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
  unsigned mask = 0xffff;

  val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
  return val;
}

template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false,
          bool SCALE_UE8M0 = false, typename scale_packed_t = float>
__global__ void per_token_group_quant_8bit_kernel(
    const T* __restrict__ input, void* __restrict__ output_q,
    scale_packed_t* __restrict__ output_s, const int group_size,
    const int num_groups, const int groups_per_block, const float eps,
    const float min_8bit, const float max_8bit, const int scale_num_rows = 0,
    const int scale_stride = 0) {
  const int threads_per_group = 16;
  const int64_t local_group_id = threadIdx.x / threads_per_group;
  const int lane_id = threadIdx.x % threads_per_group;

  const int64_t block_group_id = blockIdx.x * groups_per_block;
  const int64_t global_group_id = block_group_id + local_group_id;
  const int64_t block_group_offset = global_group_id * group_size;

  float local_absmax = eps;

  using scale_element_t = float;
  static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);

  const T* group_input = input + block_group_offset;
  DST_DTYPE* group_output =
      static_cast<DST_DTYPE*>(output_q) + block_group_offset;
  scale_element_t* scale_output;

  if constexpr (IS_COLUMN_MAJOR) {
    const int num_elems_per_pack =
        static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
    const int scale_num_rows_element = scale_num_rows * num_elems_per_pack;
    const int row_idx = global_group_id / scale_num_rows_element;
    const int col_idx_raw = global_group_id % scale_num_rows_element;
    const int col_idx = col_idx_raw / num_elems_per_pack;
    const int pack_idx = col_idx_raw % num_elems_per_pack;
    scale_output = reinterpret_cast<scale_element_t*>(output_s) +
                   (col_idx * scale_stride * num_elems_per_pack +
                    row_idx * num_elems_per_pack + pack_idx);
  } else {
    scale_output = output_s + global_group_id;
  }

  // shared memory to cache each group's data to avoid double DRAM reads.
  extern __shared__ __align__(16) char smem_raw[];
  T* smem = reinterpret_cast<T*>(smem_raw);
  T* smem_group = smem + local_group_id * group_size;

  constexpr int vec_size = 16 / sizeof(T);
  using vec_t = vllm::vec_n_t<T, vec_size>;

  // copy global -> shared & compute absmax
  auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
    float abs_v = fabsf(static_cast<float>(src));
    local_absmax = fmaxf(local_absmax, abs_v);
    dst = src;
  };

  vllm::vectorize_with_alignment<vec_size>(
      group_input,        // in
      smem_group,         // out (shared)
      group_size,         // elements per group
      lane_id,            // thread id
      threads_per_group,  // stride in group
      scalar_op_cache);   // scalar handler

  local_absmax = GroupReduceMax(local_absmax, lane_id);

  float y_s = local_absmax / max_8bit;
  if constexpr (SCALE_UE8M0) {
    y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
  }

  scale_element_t y_s_quant = y_s;

  if (lane_id == 0) {
    *scale_output = y_s_quant;
  }

  __syncthreads();

  // quantize shared -> global 8-bit
  auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
    float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
    dst = DST_DTYPE(q);
  };

  vllm::vectorize_with_alignment<vec_size>(
      smem_group,         // in (shared)
      group_output,       // out (global quant tensor)
      group_size,         // elements
      lane_id,            // tid
      threads_per_group,  // stride
      scalar_op_quant);   // scalar handler
}

void per_token_group_quant_8bit(const torch::Tensor& input,
                                torch::Tensor& output_q,
                                torch::Tensor& output_s, int64_t group_size,
                                double eps, double min_8bit, double max_8bit,
                                bool scale_ue8m0 = false) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(output_q.is_contiguous());

  const int num_groups = input.numel() / group_size;

  TORCH_CHECK(input.numel() % group_size == 0);
  TORCH_CHECK(output_s.dim() == 2);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  constexpr int THREADS_PER_GROUP = 16;

  int groups_per_block = 1;

  if (num_groups % 16 == 0) {
    groups_per_block = 16;
  } else if (num_groups % 8 == 0) {
    groups_per_block = 8;
  } else if (num_groups % 4 == 0) {
    groups_per_block = 4;
  } else if (num_groups % 2 == 0) {
    groups_per_block = 2;
  }

  auto dst_type = output_q.scalar_type();
  const int num_blocks = num_groups / groups_per_block;
  const int num_threads = groups_per_block * THREADS_PER_GROUP;

  const bool is_column_major = output_s.stride(0) < output_s.stride(1);
  const int scale_num_rows = output_s.size(1);
  const int scale_stride = output_s.stride(1);

#define LAUNCH_KERNEL(T, DST_DTYPE)                                        \
  do {                                                                     \
    dim3 grid(num_blocks);                                                 \
    dim3 block(num_threads);                                               \
    size_t smem_bytes =                                                    \
        static_cast<size_t>(groups_per_block) * group_size * sizeof(T);    \
    if (is_column_major) {                                                 \
      if (scale_ue8m0) {                                                   \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true>        \
            <<<grid, block, smem_bytes, stream>>>(                         \
                static_cast<T*>(input.data_ptr()), output_q.data_ptr(),    \
                static_cast<float*>(output_s.data_ptr()), group_size,      \
                num_groups, groups_per_block, (float)eps, (float)min_8bit, \
                (float)max_8bit, scale_num_rows, scale_stride);            \
      } else {                                                             \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false>       \
            <<<grid, block, smem_bytes, stream>>>(                         \
                static_cast<T*>(input.data_ptr()), output_q.data_ptr(),    \
                static_cast<float*>(output_s.data_ptr()), group_size,      \
                num_groups, groups_per_block, (float)eps, (float)min_8bit, \
                (float)max_8bit, scale_num_rows, scale_stride);            \
      }                                                                    \
    } else {                                                               \
      if (scale_ue8m0) {                                                   \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, true>       \
            <<<grid, block, smem_bytes, stream>>>(                         \
                static_cast<T*>(input.data_ptr()), output_q.data_ptr(),    \
                static_cast<float*>(output_s.data_ptr()), group_size,      \
                num_groups, groups_per_block, (float)eps, (float)min_8bit, \
                (float)max_8bit);                                          \
      } else {                                                             \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, false>      \
            <<<grid, block, smem_bytes, stream>>>(                         \
                static_cast<T*>(input.data_ptr()), output_q.data_ptr(),    \
                static_cast<float*>(output_s.data_ptr()), group_size,      \
                num_groups, groups_per_block, (float)eps, (float)min_8bit, \
                (float)max_8bit);                                          \
      }                                                                    \
    }                                                                      \
  } while (0)

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "per_token_group_quant_8bit", ([&] {
        if (dst_type == at::ScalarType::Float8_e4m3fn) {
          LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
        }
      }));

#undef LAUNCH_KERNEL
}

void per_token_group_quant_fp8(const torch::Tensor& input,
                               torch::Tensor& output_q, torch::Tensor& output_s,
                               int64_t group_size, double eps, double fp8_min,
                               double fp8_max, bool scale_ue8m0) {
  per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
                             fp8_min, fp8_max, scale_ue8m0);
}
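The SCALE_UE8M0 path rounds each group's scale up to the nearest power of two before it is stored. As a rough illustration (my own sketch in plain PyTorch, not part of this diff), the same rounding looks like this:

import torch

def ue8m0_round(scale: torch.Tensor) -> torch.Tensor:
    # Mirror of exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f)))): clamp away
    # from zero, round the exponent up, return a power-of-two scale.
    return torch.exp2(torch.ceil(torch.log2(scale.abs().clamp_min(1e-10))))

print(ue8m0_round(torch.tensor([0.013, 0.05, 0.9])))
# tensor([0.0156, 0.0625, 1.0000])  ->  2**-6, 2**-4, 2**0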

csrc/torch_bindings.cpp

Lines changed: 9 additions & 0 deletions
@@ -601,6 +601,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);

+  // Compute per-token-group FP8 quantized tensor and scaling factor.
+  ops.def(
+      "per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
+      "output_s, "
+      "int group_size, float eps, float fp8_min, float fp8_max, bool "
+      "scale_ue8m0) -> ()");
+  ops.impl("per_token_group_fp8_quant", torch::kCUDA,
+           &per_token_group_quant_fp8);
+
   // Mamba selective scan kernel
   ops.def(
       "selective_scan_fwd(Tensor! u, Tensor! delta,"
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import patch

import pytest
import torch

from vllm.model_executor.layers.quantization.utils import fp8_utils


@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
@pytest.mark.parametrize("column_major", [False, True])
@pytest.mark.parametrize("scale_ue8m0", [False, True])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_per_token_group_quant_fp8(shape, column_major: bool,
                                   scale_ue8m0: bool, group_size: int):
    device = "cuda"

    torch.manual_seed(42)
    num_tokens, hidden_dim = shape

    x = (torch.randn(
        (num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8)

    # cuda path
    out_q, scale = fp8_utils.per_token_group_quant_fp8(
        x,
        group_size,
        column_major_scales=column_major,
        use_ue8m0=scale_ue8m0,
    )

    # triton ref
    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
        ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
            x,
            group_size,
            column_major_scales=column_major,
            use_ue8m0=scale_ue8m0,
        )

    assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
    assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)
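Note that patching vllm.platforms.current_platform.is_cuda to return False forces the Python wrapper onto its Triton path, so the same call yields a reference output to compare against the new CUDA kernel on identical inputs.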

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 13 additions & 4 deletions
@@ -366,6 +366,7 @@ def per_token_group_quant_fp8(
     dtype: Optional[torch.dtype] = None,
     column_major_scales: bool = False,
     out_q: Optional[torch.Tensor] = None,
+    use_ue8m0: bool = is_blackwell_deep_gemm_used(),
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
@@ -397,8 +398,7 @@ def per_token_group_quant_fp8(
     if x_q is None:
         x_q = torch.empty_like(x, device=x.device, dtype=dtype)

-    M = x.numel() // group_size
-    N = group_size
+    # Allocate the scale tensor in either row- or column-major format.
     if column_major_scales:
         shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
         x_s = torch.empty(shape, device=x.device,
@@ -407,6 +407,15 @@ def per_token_group_quant_fp8(
         shape = x.shape[:-1] + (x.shape[-1] // group_size, )
         x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

+    # prefer CUDA kernel if available
+    if current_platform.is_cuda() and x.is_contiguous():
+        torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps,
+                                               fp8_min, fp8_max, use_ue8m0)
+        return x_q, x_s
+
+    # TRITON FALLBACK
+    M = x.numel() // group_size
+    N = group_size
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
@@ -423,7 +432,7 @@ def per_token_group_quant_fp8(
         eps,
         fp8_min=fp8_min,
         fp8_max=fp8_max,
-        use_ue8m0=is_blackwell_deep_gemm_used(),
+        use_ue8m0=use_ue8m0,
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=num_stages,
@@ -439,7 +448,7 @@ def per_token_group_quant_fp8(
         eps,
         fp8_min=fp8_min,
         fp8_max=fp8_max,
-        use_ue8m0=is_blackwell_deep_gemm_used(),
+        use_ue8m0=use_ue8m0,
         BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=num_stages,
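For readers who want the math without Triton or CUDA, the sketch below is a rough pure-PyTorch equivalent of what both kernels compute (per-group absmax, scale = absmax / fp8_max with optional UE8M0 rounding, clamp, cast). The function name and structure are my own, not part of this change:

import torch

def per_token_group_quant_ref(x: torch.Tensor, group_size: int,
                              eps: float = 1e-10, use_ue8m0: bool = False):
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max
    fp8_min = torch.finfo(fp8).min

    # View the last dimension as (groups_per_token, group_size).
    g = x.float().view(*x.shape[:-1], x.shape[-1] // group_size, group_size)
    absmax = g.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = absmax / fp8_max
    if use_ue8m0:
        scale = torch.exp2(torch.ceil(torch.log2(scale)))
    q = (g / scale).clamp(fp8_min, fp8_max).to(fp8)
    return q.view_as(x), scale.squeeze(-1)

x = torch.randn(16, 256, dtype=torch.bfloat16) * 8
q_ref, s_ref = per_token_group_quant_ref(x, group_size=128)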
