
Commit 4771df7

yewentao256 and mgoin authored
[Feature] Non-contiguous Support for FP8 Quantization (#21961)
Signed-off-by: yewentao256 <[email protected]>
Co-authored-by: mgoin <[email protected]>
1 parent 05fae02 commit 4771df7
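
Context for the change: "non-contiguous" here means tensors whose last (hidden) dimension is still dense (stride(-1) == 1) but whose rows are not packed back to back, for example a sliced view of a larger activation buffer. The host wrappers in this commit therefore drop the blanket is_contiguous() checks, verify only the last dimension, and forward the row strides (stride(-2)) to stride-aware kernels. A minimal libtorch sketch of such an input follows; it is illustrative only and not part of the commit (the real ops expect CUDA tensors, a CPU tensor is used here just to show the strides):

#include <torch/torch.h>
#include <iostream>

int main() {
  // Row-major buffer, then a column slice: each row stays dense,
  // but consecutive rows are 4096 elements apart instead of 2048.
  torch::Tensor full = torch::randn({8, 4096});
  torch::Tensor view = full.slice(/*dim=*/1, /*start=*/0, /*end=*/2048);

  std::cout << view.is_contiguous() << "\n";  // 0: the old checks rejected this
  std::cout << view.stride(-1) << "\n";       // 1: passes the new TORCH_CHECKs
  std::cout << view.stride(-2) << "\n";       // 4096: becomes in_row_stride
  return 0;
}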

File tree: 4 files changed, +207 -201 lines changed


csrc/quantization/fp8/common.cu

Lines changed: 170 additions & 89 deletions
@@ -1,7 +1,8 @@
 #include "common.cuh"
 #include "dispatch_utils.h"
-
+#include "../vectorization_utils.cuh"
 #include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/Exceptions.h>
 
 #ifndef USE_ROCM
 #include <cub/cub.cuh>
@@ -12,74 +13,127 @@
 namespace vllm {
 
 template <typename scalar_t, typename fp8_type>
-__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
-                                        const scalar_t* __restrict__ input,
-                                        const float* __restrict__ scale,
-                                        int64_t num_elems) {
-  int tid = blockDim.x * blockIdx.x + threadIdx.x;
-
-  // Invert the scale so that we can use multiplications to avoid expensive
-  // division.
-  const float inverted_scale = 1.0f / (*scale);
-  scaled_fp8_conversion_vec<scalar_t, true>(
-      out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
+__global__ void scaled_fp8_quant_kernel_strided(
+    fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
+    const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
+    int64_t out_row_stride) {
+  const int64_t token_idx = blockIdx.x;  // one token per block
+  const int tid = threadIdx.x;
+
+  const scalar_t* token_in = input + token_idx * in_row_stride;
+  fp8_type* token_out = out + token_idx * out_row_stride;
+
+  const float inv_scale = 1.0f / (*scale);
+
+  vectorize_with_alignment<16>(
+      token_in, token_out, hidden_size, tid, blockDim.x,
+      [=] __device__(fp8_type & dst, const scalar_t& src) {
+        dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
                                                     inv_scale);
+      });
 }
 
 template <typename scalar_t, typename fp8_type>
-__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
-    fp8_type* __restrict__ out, float* __restrict__ scale,
-    scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
-    const int hidden_size) {
-  int const tid = threadIdx.x;
-  int const token_idx = blockIdx.x;
+__global__ void segmented_max_reduction_strided(
+    float* __restrict__ scale, const scalar_t* __restrict__ input,
+    int hidden_size, int64_t in_row_stride, int64_t num_tokens) {
+  __shared__ float cache[256];
+  const int tid = threadIdx.x;
+  int64_t token_idx = blockIdx.x;
+
+  // one block per token. Guard in case gridDim.x > num_tokens.
+  if (token_idx >= num_tokens) {
+    return;
+  }
 
-  // Use int64 to avoid overflowing an int32 when calculating this offset
-  int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
-  scalar_t const* __restrict__ token_input = &input[offset];
-  fp8_type* __restrict__ token_output = &out[offset];
-
-  // For vectorization, token_input and token_output pointers need to be
-  // aligned at 32-byte and 16-byte addresses respectively.
-  bool const can_vectorize = hidden_size % 16 == 0;
-
-  float absmax_val = 0.0f;
-  if (can_vectorize) {
-    absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
-  } else {
-    for (int i = tid; i < hidden_size; i += blockDim.x) {
-      float const x = static_cast<float>(token_input[i]);
-      absmax_val = fmaxf(absmax_val, fabsf(x));
+  const scalar_t* row_ptr = input + token_idx * in_row_stride;
+
+  // each thread scans elements of the row in a strided fashion.
+  float thread_max = 0.0f;
+  for (int e = tid; e < hidden_size; e += blockDim.x) {
+    float v = fabsf(static_cast<float>(row_ptr[e]));
+    thread_max = fmaxf(thread_max, v);
+  }
+
+  cache[tid] = thread_max;
+  __syncthreads();
+
+  // parallel reduction to find row max.
+  for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
+    if (tid < offset) {
+      cache[tid] = fmaxf(cache[tid], cache[tid + offset]);
     }
+    __syncthreads();
   }
 
+  // thread 0 updates global scale (per-tensor) atomically.
+  if (tid == 0) {
+    atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
+  }
+}
+
+template <typename scalar_t, typename fp8_type>
+__global__ void scaled_fp8_quant_kernel_strided_dynamic(
+    fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
+    const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
+    int64_t out_row_stride) {
+  const int64_t token_idx = blockIdx.x;
+  const int tid = threadIdx.x;
+
+  const scalar_t* token_in = input + token_idx * in_row_stride;
+  fp8_type* token_out = out + token_idx * out_row_stride;
+
+  const float reciprocal_scale = 1.0f / (*scale);
+  vectorize_with_alignment<16>(
+      token_in, token_out, hidden_size, tid, blockDim.x,
+      [=] __device__(fp8_type & dst, const scalar_t& src) {
+        dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
+                                                    reciprocal_scale);
+      });
+}
+
+template <typename scalar_t, typename fp8_type>
+__global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
+    fp8_type* __restrict__ out, float* __restrict__ scale,
+    const scalar_t* __restrict__ input, const float* __restrict__ scale_ub,
+    int hidden_size, int64_t in_row_stride, int64_t out_row_stride) {
+  const int64_t token_idx = blockIdx.x;
+  const int tid = threadIdx.x;
+
+  // Use int64 to avoid overflowing an int32 when calculating this offset
+  int64_t in_offset = static_cast<int64_t>(token_idx) * in_row_stride;
+  int64_t out_offset = static_cast<int64_t>(token_idx) * out_row_stride;
+  const scalar_t* token_in = input + in_offset;
+  fp8_type* token_out = out + out_offset;
+
+  // 1) per-token absmax
+  float absmax_val = 0.f;
+  vectorize_read_with_alignment<16>(
+      token_in, hidden_size, tid, blockDim.x, [&] __device__(scalar_t v) {
+        absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(v)));
+      });
+
   using BlockReduce = cub::BlockReduce<float, 256>;
-  __shared__ typename BlockReduce::TempStorage reduceStorage;
-  float const block_absmax_val_maybe =
-      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
+  __shared__ typename BlockReduce::TempStorage tmp;
+  const float block_max =
+      BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x);
+
   __shared__ float token_scale;
   if (tid == 0) {
-    if (scale_ub) {
-      token_scale = fminf(block_absmax_val_maybe, *scale_ub);
-    } else {
-      token_scale = block_absmax_val_maybe;
-    }
-    // token scale computation
+    token_scale = scale_ub ? fminf(block_max, *scale_ub) : block_max;
     token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
                         min_scaling_factor<fp8_type>::val());
     scale[token_idx] = token_scale;
   }
   __syncthreads();
 
-  // Note that we don't use inverted scales so we can match FBGemm impl.
-  if (can_vectorize) {
-    scaled_fp8_conversion_vec<scalar_t, false>(
-        token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
-  } else {
-    for (int i = tid; i < hidden_size; i += blockDim.x) {
-      token_output[i] = scaled_fp8_conversion<false, fp8_type>(
-          static_cast<float>(token_input[i]), token_scale);
-    }
-  }
+  // 2) quantize
+  vectorize_with_alignment<16>(
+      token_in, token_out, hidden_size, tid, blockDim.x,
+      [=] __device__(fp8_type & dst, const scalar_t& src) {
+        dst = scaled_fp8_conversion<false, fp8_type>(static_cast<float>(src),
                                                      token_scale);
+      });
 }
 
 } // namespace vllm
@@ -88,23 +142,31 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor const& scale)  // [1]
 {
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(out.is_contiguous());
-  int const block_size = 256;
-  int const num_tokens = input.numel() / input.size(-1);
-  int const num_elems = input.numel();
-  dim3 const grid(num_tokens);
-  dim3 const block(block_size);
+  TORCH_CHECK(input.stride(-1) == 1,
+              "last dimension of input must be contiguous");
+  TORCH_CHECK(out.stride(-1) == 1,
+              "last dimension of output must be contiguous");
+
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int block_size = 256;
+  dim3 grid(num_tokens);
+  dim3 block(block_size);
+
+  const int64_t in_row_stride = input.stride(-2);
+  const int64_t out_row_stride = out.stride(-2);
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
-              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+              vllm::scaled_fp8_quant_kernel_strided<scalar_t, fp8_t>
                  <<<grid, block, 0, stream>>>(
                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
-                      scale.data_ptr<float>(), num_elems);
+                      scale.data_ptr<float>(), hidden_size, in_row_stride,
+                      out_row_stride);
            });
      });
 }
@@ -113,27 +175,42 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
                               torch::Tensor const& input,  // [..., d]
                               torch::Tensor& scale)        // [1]
 {
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(out.is_contiguous());
-  int const block_size = 256;
-  int const num_tokens = input.numel() / input.size(-1);
-  int const num_elems = input.numel();
-  dim3 const grid(num_tokens);
-  dim3 const block(block_size);
+  TORCH_CHECK(input.stride(-1) == 1,
+              "last dimension of input must be contiguous");
+  TORCH_CHECK(out.stride(-1) == 1,
+              "last dimension of output must be contiguous");
+
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int block_size = 256;
+  dim3 grid(num_tokens);
+  dim3 block(block_size);
+
+  const int64_t in_row_stride = input.stride(-2);
+  const int64_t out_row_stride = out.stride(-2);
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // scale tensor should be initialised to <=0 before reduction
+  AT_CUDA_CHECK(
+      cudaMemsetAsync(scale.data_ptr<float>(), 0, sizeof(float), stream));
+
   VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
-              vllm::segmented_max_reduction<scalar_t, fp8_t>
-                  <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
-                                               input.data_ptr<scalar_t>(),
-                                               num_elems);
-              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+              vllm::segmented_max_reduction_strided<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
+                      hidden_size, in_row_stride,
+                      static_cast<int64_t>(num_tokens));
+
+              vllm::scaled_fp8_quant_kernel_strided_dynamic<scalar_t, fp8_t>
                  <<<grid, block, 0, stream>>>(
                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
-                      scale.data_ptr<float>(), num_elems);
+                      scale.data_ptr<float>(), hidden_size, in_row_stride,
+                      out_row_stride);
            });
      });
 }
@@ -142,14 +219,19 @@ void dynamic_per_token_scaled_fp8_quant(
     torch::Tensor& out,          // [..., d]
     torch::Tensor const& input,  // [..., d]
     torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(input.stride(-1) == 1,
+              "last dimension of input must be contiguous");
+  TORCH_CHECK(out.stride(-1) == 1,
+              "last dimension of output must be contiguous");
 
-  int const hidden_size = input.size(-1);
-  int const num_tokens = input.numel() / hidden_size;
-  int const block_size = 256;
-  dim3 const grid(num_tokens);
-  dim3 const block(std::min(hidden_size, block_size));
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int block_size = 256;
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, block_size));
+
+  const int64_t in_row_stride = input.stride(-2);
+  const int64_t out_row_stride = out.stride(-2);
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -159,13 +241,12 @@ void dynamic_per_token_scaled_fp8_quant(
       VLLM_DISPATCH_FP8_TYPES(
          out.scalar_type(),
          "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
-            vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
-                <<<grid, block, 0, stream>>>(
-                    out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
-                    input.data_ptr<scalar_t>(),
-                    scale_ub.has_value() ? scale_ub->data_ptr<float>()
-                                         : nullptr,
-                    hidden_size);
+            vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided<
+                scalar_t, fp8_t><<<grid, block, 0, stream>>>(
+                out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
+                input.data_ptr<scalar_t>(),
+                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
+                hidden_size, in_row_stride, out_row_stride);
          });
      });
 }
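
Note on the kernel structure above: every new kernel launches one block per token (grid = num_tokens, block = 256), derives the row base pointer from the row stride rather than from hidden_size, and then walks the dense last dimension, with vectorize_with_alignment<16> / vectorize_read_with_alignment<16> covering the 16-byte vector path plus the scalar tail. The following is a stripped-down, self-contained CUDA sketch of that addressing scheme only; it uses a plain scalar loop and a stand-in fp8 conversion, not the commit's scaled_fp8_conversion<> or dtype dispatch.

#include <cuda_fp8.h>
#include <cstdint>

// Sketch of the strided addressing used by the new kernels: one block per
// token/row, base pointer = row index * row stride (in elements), threads
// then stride over the dense last dimension. Scalar loop only; the real
// kernels vectorize via vectorize_with_alignment<16>.
__global__ void quant_rows_strided(__nv_fp8_e4m3* __restrict__ out,
                                   const float* __restrict__ in,
                                   const float* __restrict__ scale,
                                   int hidden_size, int64_t in_row_stride,
                                   int64_t out_row_stride) {
  const int64_t row = blockIdx.x;  // one token per block
  const float* row_in = in + row * in_row_stride;
  __nv_fp8_e4m3* row_out = out + row * out_row_stride;

  const float inv_scale = 1.0f / (*scale);  // multiply instead of divide
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    // The fp8 constructor converts (with saturation to finite values);
    // the commit's kernels instead clamp through scaled_fp8_conversion<>.
    row_out[i] = __nv_fp8_e4m3(row_in[i] * inv_scale);
  }
}

A launch mirroring the host wrappers would look like quant_rows_strided<<<num_tokens, 256, 0, stream>>>(out, in, scale, hidden_size, in_row_stride, out_row_stride).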

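One detail worth calling out in the dynamic per-tensor path: dynamic_scaled_fp8_quant now zeroes the scale with cudaMemsetAsync before launching segmented_max_reduction_strided, because each block folds its row maximum into that single float via atomicMaxFloat. An atomic max only yields the true maximum when the destination starts no larger than any contribution, and the contributions (absolute values divided by the fp8 max) are non-negative, so 0 is a safe identity; a stale value larger than every row maximum would otherwise survive the reduction. A common way to build such a float atomic max reuses the integer atomicMax on the raw bits, which is order-preserving for non-negative IEEE-754 values. This is an assumed sketch for illustration, not necessarily vLLM's exact atomicMaxFloat helper:

#include <cuda_runtime.h>

// Assumed sketch: for non-negative floats the IEEE-754 bit pattern is
// monotonic when reinterpreted as a signed int, so integer atomicMax
// computes the float maximum. Negative inputs would need extra handling.
__device__ __forceinline__ float atomic_max_nonneg(float* addr, float value) {
  int old_bits = atomicMax(reinterpret_cast<int*>(addr), __float_as_int(value));
  return __int_as_float(old_bits);
}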