 #include "xpu/ops.h"
 
 #include "fp8_quant.h"
-#include "utils.h"
-
+#include "quant_utils.h"
 
 namespace vllm {
 
@@ -22,25 +21,19 @@ class scaled_fp8_quant_kernel {
   int64_t num_elems;
 
  public:
-  scaled_fp8_quant_kernel(
-      fp8_type* out_,
-      const scalar_t* input_,
-      const float* scale_,
-      int64_t num_elems_)
+  scaled_fp8_quant_kernel(fp8_type* out_, const scalar_t* input_,
+                          const float* scale_, int64_t num_elems_)
       : out(out_), input(input_), scale(scale_), num_elems(num_elems_) {}
   void operator()(sycl::nd_item<1> item) const {
     int tid = item.get_global_linear_id();
 
     // Invert the scale so that we can use multiplications to avoid expensive
     // division.
     const float inverted_scale = 1.0f / (*scale);
-    scaled_fp8_conversion_vec<scalar_t, true>(
-        out,
-        input,
-        inverted_scale,
-        num_elems,
-        tid,
-        item.get_local_range(0) * item.get_group_range(0));
+    fp8::ConvertWithScaleOp<true, fp8_type> op{inverted_scale};
+    fp8::scaled_convert_vec(input, out, num_elems, tid,
+                            item.get_local_range(0) * item.get_group_range(0),
+                            op);
   }
 };
 
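The hunk above swaps the hand-rolled scaled_fp8_conversion_vec call for helpers from the new quant_utils.h header. As a rough sketch of what those helpers might look like, inferred only from the call sites in this diff (the real header may well differ):

// Hypothetical sketch of the quant_utils.h helpers used above; names and
// signatures are guessed from the call sites, not taken from the real header.
namespace fp8 {

// Functor converting one element to fp8, optionally multiplying by a scale.
// `scale` is expected to already be the inverted scale (1 / s).
template <bool kApplyScale, typename fp8_type>
struct ConvertWithScaleOp {
  float scale;

  template <typename scalar_t>
  void operator()(fp8_type& dst, const scalar_t& src) const {
    float x = static_cast<float>(src);
    if constexpr (kApplyScale) {
      x *= scale;
    }
    dst = static_cast<fp8_type>(x);
  }
};

// Grid-stride elementwise loop applying `op`; the real helper presumably
// vectorizes loads/stores, which is omitted here.
template <typename scalar_t, typename fp8_type, typename Op>
void scaled_convert_vec(const scalar_t* in, fp8_type* out, int64_t num_elems,
                        int64_t tid, int64_t stride, const Op& op) {
  for (int64_t i = tid; i < num_elems; i += stride) {
    op(out[i], in[i]);
  }
}

}  // namespace fp8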
@@ -54,12 +47,10 @@ class dynamic_per_token_scaled_fp8_quant_kernel {
   const int hidden_size;
 
  public:
-  dynamic_per_token_scaled_fp8_quant_kernel(
-      fp8_type* out_,
-      float* scale_,
-      scalar_t const* input_,
-      float const* scale_ub_,
-      const int hidden_size_)
+  dynamic_per_token_scaled_fp8_quant_kernel(fp8_type* out_, float* scale_,
+                                            scalar_t const* input_,
+                                            float const* scale_ub_,
+                                            const int hidden_size_)
       : out(out_),
         scale(scale_),
         input(input_),
@@ -70,13 +61,6 @@ class dynamic_per_token_scaled_fp8_quant_kernel {
     int const tid = item.get_local_id(0);
     int const token_idx = item.get_group(0);
 
-    // sycl::ext::oneapi::experimental::printf(
-    //     "token_idx: %d, tid: %d, hidden_size: %d, group_range: %d\n",
-    //     token_idx,
-    //     tid,
-    //     hidden_size,
-    //     item.get_local_range(0));
-
     // Use int64 to avoid overflowing an int32 when calculating this offset
     int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
     scalar_t const* token_input = &input[offset];
@@ -88,8 +72,8 @@ class dynamic_per_token_scaled_fp8_quant_kernel {
 
     float absmax_val = 0.0f;
     if (can_vectorize) {
-      absmax_val = thread_max_vec(
-          token_input, hidden_size, tid, item.get_local_range(0));
+      absmax_val = thread_max_vec(token_input, hidden_size, tid,
+                                  item.get_local_range(0));
     } else {
       for (int i = tid; i < hidden_size; i += item.get_local_range(0)) {
         float const x = static_cast<float>(token_input[i]);
@@ -110,38 +94,33 @@ class dynamic_per_token_scaled_fp8_quant_kernel {
         token_scale[0] = block_absmax_val_maybe;
       }
       // token scale computation
-      token_scale[0] = sycl::max(
-          token_scale[0] / quant_type_max_v<fp8_type>,
-          min_scaling_factor<fp8_type>::val());
+      token_scale[0] =
+          sycl::max(token_scale[0] / fp8::quant_type_max_v<fp8_type>,
+                    fp8::min_scaling_factor<fp8_type>::val());
       scale[token_idx] = token_scale[0];
     }
     group_barrier(item.get_group());
 
     // Note that we don't use inverted scales so we can match FBGemm impl.
     const float inverted_scale = 1.0f / (token_scale[0]);
     if (can_vectorize) {
-      scaled_fp8_conversion_vec<scalar_t, true>(
-          token_output,
-          token_input,
-          inverted_scale,
-          hidden_size,
-          tid,
-          item.get_local_range(0));
+      fp8::ConvertWithScaleOp<true, fp8_type> op{inverted_scale};
+      fp8::scaled_convert_vec(token_input, token_output, hidden_size, tid,
+                              item.get_local_range(0), op);
     } else {
       for (int i = tid; i < hidden_size; i += item.get_local_range(0)) {
-        token_output[i] = scaled_fp8_conversion<true, fp8_type>(
-            static_cast<float>(token_input[i]), inverted_scale, tid, token_idx);
+        fp8::ConvertWithScaleOp<true, fp8_type> op{inverted_scale};
+        op(token_output[i], token_input[i]);
       }
     }
   }
 };
 
-} // namespace vllm
+}  // namespace vllm
 
-void static_scaled_fp8_quant(
-    torch::Tensor& out, // [..., d]
-    torch::Tensor const& input, // [..., d]
-    torch::Tensor const& scale) // [1]
+void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
+                             torch::Tensor const& input,  // [..., d]
+                             torch::Tensor const& scale)  // [1]
 {
   int64_t num_tokens = input.numel() / input.size(-1);
   int64_t num_elems = input.numel();
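For reference, the token-scale line in the hunk above computes scale = max(absmax / quant_type_max, min_scaling_factor), and each element is then multiplied by 1 / scale before the fp8 cast. A tiny standalone illustration with assumed numbers (448.f as the float8_e4m3fn finite max; the lower-bound guard is only an assumption about how min_scaling_factor may be defined):

// Worked example of the per-token scale formula; all values are illustrative.
#include <algorithm>
#include <cstdio>

int main() {
  const float absmax = 3.2f;    // assumed per-token max(|x|)
  const float fp8_max = 448.f;  // finite max of float8_e4m3fn
  const float min_scale = 1.0f / (fp8_max * 512.f);  // assumed guard value
  const float scale = std::max(absmax / fp8_max, min_scale);
  const float inverted_scale = 1.0f / scale;
  // The largest element of the token maps to the top of the fp8 range.
  std::printf("scale = %g, absmax * inverted_scale = %g\n", scale,
              absmax * inverted_scale);
  return 0;
}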
@@ -158,21 +137,18 @@ void static_scaled_fp8_quant(
         // Launch the kernel
         stream.submit([&](sycl::handler& cgh) {
           auto kernel = vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>(
-              out.data_ptr<fp8_t>(),
-              input.data_ptr<scalar_t>(),
-              scale.data_ptr<float>(),
-              num_elems);
-          cgh.parallel_for(
-              sycl::nd_range<1>(grid * block, block), kernel);
+              out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+              scale.data_ptr<float>(), num_elems);
+          cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
+                           kernel);
         });
       });
   });
 }
 
-void dynamic_scaled_fp8_quant(
-    torch::Tensor& out, // [..., d]
-    torch::Tensor const& input, // [..., d]
-    torch::Tensor& scale) // [1]
+void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
+                              torch::Tensor const& input,  // [..., d]
+                              torch::Tensor& scale)        // [1]
 {
   int64_t num_tokens = input.numel() / input.size(-1);
   int64_t num_elems = input.numel();
@@ -190,30 +166,26 @@ void dynamic_scaled_fp8_quant(
         stream.submit([&](sycl::handler& cgh) {
           auto max_reduce_kernel =
               vllm::segmented_max_reduction<scalar_t, fp8_t>(
-                  scale.data_ptr<float>(),
-                  input.data_ptr<scalar_t>(),
+                  scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
                   num_elems);
-          cgh.parallel_for(
-              sycl::nd_range<1>(grid * block, block), max_reduce_kernel);
+          cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
+                           max_reduce_kernel);
         });
         stream.submit([&](sycl::handler& cgh) {
           auto kernel = vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>(
-              out.data_ptr<fp8_t>(),
-              input.data_ptr<scalar_t>(),
-              scale.data_ptr<float>(),
-              num_elems);
-          cgh.parallel_for(
-              sycl::nd_range<1>(grid * block, block), kernel);
+              out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+              scale.data_ptr<float>(), num_elems);
+          cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
+                           kernel);
         });
       });
   });
 }
 
 void dynamic_per_token_scaled_fp8_quant(
-    torch::Tensor& out, // [..., d]
-    torch::Tensor const& input, // [..., d]
-    torch::Tensor& scales,
-    std::optional<at::Tensor> const& scale_ub) {
+    torch::Tensor& out,          // [..., d]
+    torch::Tensor const& input,  // [..., d]
+    torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
 
@@ -228,26 +200,23 @@ void dynamic_per_token_scaled_fp8_quant(
   auto stream = at::xpu::getCurrentXPUStream().queue();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(),
-      "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type",
-      [&] {
+      "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
         VLLM_DISPATCH_FP8_TYPES(
             out.scalar_type(),
-            "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type",
-            [&] {
+            "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
               // Launch the kernel
               stream
                   .submit([&](sycl::handler& cgh) {
-                    auto kernel = vllm::dynamic_per_token_scaled_fp8_quant_kernel<
-                        scalar_t,
-                        fp8_t>(
-                        out.data_ptr<fp8_t>(),
-                        scales.data_ptr<float>(),
-                        input.data_ptr<scalar_t>(),
-                        scale_ub.has_value() ? scale_ub->data_ptr<float>()
-                                             : nullptr,
-                        hidden_size);
-                    cgh.parallel_for(
-                        sycl::nd_range<1>(grid * block, block), kernel);
+                    auto kernel =
+                        vllm::dynamic_per_token_scaled_fp8_quant_kernel<
+                            scalar_t, fp8_t>(
+                            out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
+                            input.data_ptr<scalar_t>(),
+                            scale_ub.has_value() ? scale_ub->data_ptr<float>()
+                                                 : nullptr,
+                            hidden_size);
+                    cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
+                                     kernel);
                   })
                   .wait();
         });
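A hypothetical host-side call of dynamic_per_token_scaled_fp8_quant as declared above; the shapes, dtypes, and XPU device placement are illustrative assumptions, not taken from this diff:

// Sketch only: exercises the op's signature from the diff with assumed inputs.
#include <optional>
#include <torch/torch.h>

// Declared by the extension; definition is in the diff above.
void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales,
    std::optional<at::Tensor> const& scale_ub);

void example() {
  const int64_t num_tokens = 4, hidden_size = 128;
  auto input = torch::randn({num_tokens, hidden_size},
                            torch::dtype(torch::kHalf).device(torch::kXPU));
  auto out = torch::empty_like(
      input, input.options().dtype(at::ScalarType::Float8_e4m3fn));
  auto scales = torch::empty({num_tokens, 1},
                             torch::dtype(torch::kFloat).device(torch::kXPU));
  // No upper bound on the per-token scale in this example.
  dynamic_per_token_scaled_fp8_quant(out, input, scales, std::nullopt);
}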