@@ -1,7 +1,8 @@
 #include "common.cuh"
 #include "dispatch_utils.h"
-
+#include "../vectorization_utils.cuh"
 #include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/Exceptions.h>
 
 #ifndef USE_ROCM
   #include <cub/cub.cuh>
@@ -12,74 +13,127 @@
 namespace vllm {
 
 template <typename scalar_t, typename fp8_type>
-__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
-                                        const scalar_t* __restrict__ input,
-                                        const float* __restrict__ scale,
-                                        int64_t num_elems) {
-  int tid = blockDim.x * blockIdx.x + threadIdx.x;
-
-  // Invert the scale so that we can use multiplications to avoid expensive
-  // division.
-  const float inverted_scale = 1.0f / (*scale);
-  scaled_fp8_conversion_vec<scalar_t, true>(
-      out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
+__global__ void scaled_fp8_quant_kernel_strided(
+    fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
+    const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
+    int64_t out_row_stride) {
+  const int64_t token_idx = blockIdx.x;  // one token per block
+  const int tid = threadIdx.x;
+
+  const scalar_t* token_in = input + token_idx * in_row_stride;
+  fp8_type* token_out = out + token_idx * out_row_stride;
+
+  const float inv_scale = 1.0f / (*scale);
+
+  vectorize_with_alignment<16>(
+      token_in, token_out, hidden_size, tid, blockDim.x,
+      [=] __device__(fp8_type& dst, const scalar_t& src) {
+        dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
+                                                    inv_scale);
+      });
 }
 
 template <typename scalar_t, typename fp8_type>
-__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
-    fp8_type* __restrict__ out, float* __restrict__ scale,
-    scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
-    const int hidden_size) {
-  int const tid = threadIdx.x;
-  int const token_idx = blockIdx.x;
+__global__ void segmented_max_reduction_strided(
+    float* __restrict__ scale, const scalar_t* __restrict__ input,
+    int hidden_size, int64_t in_row_stride, int64_t num_tokens) {
+  __shared__ float cache[256];
+  const int tid = threadIdx.x;
+  int64_t token_idx = blockIdx.x;
+
+  // one block per token. Guard in case gridDim.x > num_tokens.
+  if (token_idx >= num_tokens) {
+    return;
+  }
 
-  // Use int64 to avoid overflowing an int32 when calculating this offset
-  int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
-  scalar_t const* __restrict__ token_input = &input[offset];
-  fp8_type* __restrict__ token_output = &out[offset];
-
-  // For vectorization, token_input and token_output pointers need to be
-  // aligned at 32-byte and 16-byte addresses respectively.
-  bool const can_vectorize = hidden_size % 16 == 0;
-
-  float absmax_val = 0.0f;
-  if (can_vectorize) {
-    absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
-  } else {
-    for (int i = tid; i < hidden_size; i += blockDim.x) {
-      float const x = static_cast<float>(token_input[i]);
-      absmax_val = fmaxf(absmax_val, fabsf(x));
+  const scalar_t* row_ptr = input + token_idx * in_row_stride;
+
+  // each thread scans elements of the row in a strided fashion.
+  float thread_max = 0.0f;
+  for (int e = tid; e < hidden_size; e += blockDim.x) {
+    float v = fabsf(static_cast<float>(row_ptr[e]));
+    thread_max = fmaxf(thread_max, v);
+  }
+
+  cache[tid] = thread_max;
+  __syncthreads();
+
+  // parallel reduction to find row max.
+  for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
+    if (tid < offset) {
+      cache[tid] = fmaxf(cache[tid], cache[tid + offset]);
     }
+    __syncthreads();
   }
 
+  // thread 0 updates global scale (per-tensor) atomically.
+  if (tid == 0) {
+    atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
+  }
+}
+
+template <typename scalar_t, typename fp8_type>
+__global__ void scaled_fp8_quant_kernel_strided_dynamic(
+    fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
+    const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
+    int64_t out_row_stride) {
+  const int64_t token_idx = blockIdx.x;
+  const int tid = threadIdx.x;
+
+  const scalar_t* token_in = input + token_idx * in_row_stride;
+  fp8_type* token_out = out + token_idx * out_row_stride;
+
+  const float reciprocal_scale = 1.0f / (*scale);
+  vectorize_with_alignment<16>(
+      token_in, token_out, hidden_size, tid, blockDim.x,
+      [=] __device__(fp8_type& dst, const scalar_t& src) {
+        dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
+                                                    reciprocal_scale);
+      });
+}
+
+template <typename scalar_t, typename fp8_type>
+__global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
+    fp8_type* __restrict__ out, float* __restrict__ scale,
+    const scalar_t* __restrict__ input, const float* __restrict__ scale_ub,
+    int hidden_size, int64_t in_row_stride, int64_t out_row_stride) {
+  const int64_t token_idx = blockIdx.x;
+  const int tid = threadIdx.x;
+
+  // Use int64 to avoid overflowing an int32 when calculating this offset
+  int64_t in_offset = static_cast<int64_t>(token_idx) * in_row_stride;
+  int64_t out_offset = static_cast<int64_t>(token_idx) * out_row_stride;
+  const scalar_t* token_in = input + in_offset;
+  fp8_type* token_out = out + out_offset;
+
+  // 1) per-token absmax
+  float absmax_val = 0.f;
+  vectorize_read_with_alignment<16>(
+      token_in, hidden_size, tid, blockDim.x, [&] __device__(scalar_t v) {
+        absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(v)));
+      });
+
   using BlockReduce = cub::BlockReduce<float, 256>;
-  __shared__ typename BlockReduce::TempStorage reduceStorage;
-  float const block_absmax_val_maybe =
-      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
+  __shared__ typename BlockReduce::TempStorage tmp;
+  const float block_max =
+      BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x);
+
   __shared__ float token_scale;
   if (tid == 0) {
-    if (scale_ub) {
-      token_scale = fminf(block_absmax_val_maybe, *scale_ub);
-    } else {
-      token_scale = block_absmax_val_maybe;
-    }
-    // token scale computation
+    token_scale = scale_ub ? fminf(block_max, *scale_ub) : block_max;
     token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
                         min_scaling_factor<fp8_type>::val());
     scale[token_idx] = token_scale;
   }
   __syncthreads();
 
-  // Note that we don't use inverted scales so we can match FBGemm impl.
-  if (can_vectorize) {
-    scaled_fp8_conversion_vec<scalar_t, false>(
-        token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
-  } else {
-    for (int i = tid; i < hidden_size; i += blockDim.x) {
-      token_output[i] = scaled_fp8_conversion<false, fp8_type>(
-          static_cast<float>(token_input[i]), token_scale);
-    }
-  }
+  // 2) quantize
+  vectorize_with_alignment<16>(
+      token_in, token_out, hidden_size, tid, blockDim.x,
+      [=] __device__(fp8_type& dst, const scalar_t& src) {
+        dst = scaled_fp8_conversion<false, fp8_type>(static_cast<float>(src),
+                                                     token_scale);
+      });
 }
 
 } // namespace vllm
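Note: vectorize_with_alignment<16> and vectorize_read_with_alignment<16> come from ../vectorization_utils.cuh and are not shown in this diff. As a rough mental model (an assumption, not the actual helper), they apply a per-element functor to one row while the block's threads stride over the columns, switching to 16-byte vector loads/stores when alignment allows. A scalar-only sketch of that traversal pattern:

// Illustrative sketch only -- the real helper in vectorization_utils.cuh
// additionally packs 16-byte aligned chunks into vector loads/stores.
template <typename OutT, typename InT, typename Op>
__device__ void row_apply_scalar(const InT* __restrict__ in,
                                 OutT* __restrict__ out, int n, int tid,
                                 int stride, Op op) {
  // each thread handles columns tid, tid + stride, tid + 2*stride, ...
  for (int i = tid; i < n; i += stride) {
    op(out[i], in[i]);  // e.g. dst = scaled_fp8_conversion<...>(src, scale)
  }
}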
@@ -88,23 +142,31 @@ void static_scaled_fp8_quant(torch::Tensor& out,  // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor const& scale)  // [1]
 {
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(out.is_contiguous());
-  int const block_size = 256;
-  int const num_tokens = input.numel() / input.size(-1);
-  int const num_elems = input.numel();
-  dim3 const grid(num_tokens);
-  dim3 const block(block_size);
+  TORCH_CHECK(input.stride(-1) == 1,
+              "last dimension of input must be contiguous");
+  TORCH_CHECK(out.stride(-1) == 1,
+              "last dimension of output must be contiguous");
+
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int block_size = 256;
+  dim3 grid(num_tokens);
+  dim3 block(block_size);
+
+  const int64_t in_row_stride = input.stride(-2);
+  const int64_t out_row_stride = out.stride(-2);
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
         VLLM_DISPATCH_FP8_TYPES(
             out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
-              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+              vllm::scaled_fp8_quant_kernel_strided<scalar_t, fp8_t>
                   <<<grid, block, 0, stream>>>(
                       out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
-                      scale.data_ptr<float>(), num_elems);
+                      scale.data_ptr<float>(), hidden_size, in_row_stride,
+                      out_row_stride);
             });
       });
 }
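For reference, the element-wise conversion the lambdas delegate to is roughly the following. This is a sketch under the assumption that the boolean template argument selects between multiplying by a precomputed reciprocal and dividing by the scale itself; the real scaled_fp8_conversion lives in common.cuh and may differ in detail:

// Hedged sketch of scaled_fp8_conversion<is_scale_inverted, fp8_type>:
// scale, saturate to the FP8 range, then convert. Not the actual helper.
template <bool is_scale_inverted, typename fp8_type>
__device__ __forceinline__ fp8_type quantize_one_sketch(float x, float scale) {
  float r = is_scale_inverted ? x * scale : x / scale;
  const float fp8_max = quant_type_max_v<fp8_type>;  // e.g. 448.f for e4m3
  r = fminf(fmaxf(r, -fp8_max), fp8_max);            // saturate
  return static_cast<fp8_type>(r);
}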
@@ -113,27 +175,42 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out,  // [..., d]
                               torch::Tensor const& input,  // [..., d]
                               torch::Tensor& scale)  // [1]
 {
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(out.is_contiguous());
-  int const block_size = 256;
-  int const num_tokens = input.numel() / input.size(-1);
-  int const num_elems = input.numel();
-  dim3 const grid(num_tokens);
-  dim3 const block(block_size);
+  TORCH_CHECK(input.stride(-1) == 1,
+              "last dimension of input must be contiguous");
+  TORCH_CHECK(out.stride(-1) == 1,
+              "last dimension of output must be contiguous");
+
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int block_size = 256;
+  dim3 grid(num_tokens);
+  dim3 block(block_size);
+
+  const int64_t in_row_stride = input.stride(-2);
+  const int64_t out_row_stride = out.stride(-2);
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // scale tensor should be initialised to <=0 before reduction
+  AT_CUDA_CHECK(
+      cudaMemsetAsync(scale.data_ptr<float>(), 0, sizeof(float), stream));
+
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
         VLLM_DISPATCH_FP8_TYPES(
             out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
-              vllm::segmented_max_reduction<scalar_t, fp8_t>
-                  <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
-                                               input.data_ptr<scalar_t>(),
-                                               num_elems);
-              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+              vllm::segmented_max_reduction_strided<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
+                      hidden_size, in_row_stride,
+                      static_cast<int64_t>(num_tokens));
+
+              vllm::scaled_fp8_quant_kernel_strided_dynamic<scalar_t, fp8_t>
                   <<<grid, block, 0, stream>>>(
                       out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
-                      scale.data_ptr<float>(), num_elems);
+                      scale.data_ptr<float>(), hidden_size, in_row_stride,
+                      out_row_stride);
             });
       });
 }
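The memset above matters because atomicMaxFloat, used by segmented_max_reduction_strided and defined elsewhere in the vLLM headers (not in this diff), only has to handle non-negative values here: absmax / quant_type_max_v is >= 0, so starting the accumulator at 0 is a valid identity. One common way such a float atomic max is built (an assumption about the helper, not its actual definition) is to reuse the integer atomicMax, since the IEEE-754 bit patterns of non-negative floats are monotonic when compared as signed integers:

// Hedged sketch: float atomic max for non-negative values via integer atomicMax.
// Valid only when *addr and val are >= 0, which holds for absmax-based scales.
__device__ __forceinline__ float atomic_max_nonneg_float(float* addr,
                                                         float val) {
  int old = atomicMax(reinterpret_cast<int*>(addr), __float_as_int(val));
  return __int_as_float(old);  // previous value at *addr
}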
@@ -142,14 +219,19 @@ void dynamic_per_token_scaled_fp8_quant(
     torch::Tensor& out,          // [..., d]
     torch::Tensor const& input,  // [..., d]
     torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(input.stride(-1) == 1,
+              "last dimension of input must be contiguous");
+  TORCH_CHECK(out.stride(-1) == 1,
+              "last dimension of output must be contiguous");
 
-  int const hidden_size = input.size(-1);
-  int const num_tokens = input.numel() / hidden_size;
-  int const block_size = 256;
-  dim3 const grid(num_tokens);
-  dim3 const block(std::min(hidden_size, block_size));
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int block_size = 256;
+  dim3 grid(num_tokens);
+  dim3 block(std::min(hidden_size, block_size));
+
+  const int64_t in_row_stride = input.stride(-2);
+  const int64_t out_row_stride = out.stride(-2);
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -159,13 +241,12 @@ void dynamic_per_token_scaled_fp8_quant(
       VLLM_DISPATCH_FP8_TYPES(
           out.scalar_type(),
           "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
-            vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
-                <<<grid, block, 0, stream>>>(
-                    out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
-                    input.data_ptr<scalar_t>(),
-                    scale_ub.has_value() ? scale_ub->data_ptr<float>()
-                                         : nullptr,
-                    hidden_size);
+            vllm::dynamic_per_token_scaled_fp8_quant_kernel_strided<
+                scalar_t, fp8_t><<<grid, block, 0, stream>>>(
+                out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
+                input.data_ptr<scalar_t>(),
+                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
+                hidden_size, in_row_stride, out_row_stride);
           });
       });
 }
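A hypothetical host-side illustration of the case these changes enable: a tensor narrowed along the last dimension keeps stride(-1) == 1 but has a row stride larger than hidden_size, so it previously failed the is_contiguous() checks and needed a copy before quantization. The names below are illustrative and not taken from the vLLM test suite:

// Sketch: build a non-contiguous-but-row-contiguous input for the quant ops.
#include <torch/torch.h>

torch::Tensor make_strided_input(int64_t num_tokens, int64_t hidden_size) {
  auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);
  // Row stride of the slice is 2 * hidden_size while stride(-1) stays 1.
  auto full = torch::randn({num_tokens, 2 * hidden_size}, opts);
  return full.narrow(/*dim=*/1, /*start=*/0, /*length=*/hidden_size);
}

With the strided kernels, out can stay an ordinary contiguous fp8 tensor of shape {num_tokens, hidden_size} and the ops above consume such an input directly, without a .contiguous() materialisation.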