
Commit 1b96119

baodiizufangzhu authored and committed
add fp8 quant kernels
Signed-off-by: baodii <[email protected]>
1 parent 6f1df4f commit 1b96119

File tree

10 files changed, +703 -0 lines changed


CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -151,6 +151,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL")
     "csrc/xpu/activation.cpp"
     "csrc/xpu/pos_encoding_kernels.cpp"
     "csrc/xpu/torch_bindings.cpp"
+    "csrc/xpu/quantization/fp8/fp8_quant.cpp"
   )
   include_directories("/usr/include")
   set(CMPLR_ROOT $ENV{CMPLR_ROOT})

csrc/xpu/ops.h

Lines changed: 10 additions & 0 deletions
@@ -26,3 +26,13 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                              torch::Tensor& slot_mapping,
                              const std::string& kv_cache_dtype,
                              torch::Tensor& k_scale, torch::Tensor& v_scale);
+
+void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
+                             torch::Tensor const& scale);
+
+void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
+                              torch::Tensor& scale);
+
+void dynamic_per_token_scaled_fp8_quant(
+    torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales,
+    std::optional<at::Tensor> const& scale_ub);
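The semantics of static_scaled_fp8_quant are easiest to see in scalar form. The sketch below is an illustrative restatement, not part of the commit: FP8_MAX is an assumed e4m3-style range standing in for the quant_type_max_v constant the kernels further down use, and the clamping behaviour is an assumption about their conversion helper.

    #include <algorithm>
    #include <cmath>

    // Illustrative only: FP8_MAX approximates quant_type_max_v<fp8_type> for
    // an e4m3-style fp8 format.
    constexpr float FP8_MAX = 448.0f;

    // static_scaled_fp8_quant: quantize with a caller-provided per-tensor
    // scale; the kernels multiply by a precomputed 1/scale instead of dividing.
    inline float quantize_with_scale(float x, float scale) {
      float r = x / scale;
      return std::clamp(r, -FP8_MAX, FP8_MAX);  // then narrow to fp8
    }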
csrc/xpu/quantization/fp8/fp8_quant.cpp

Lines changed: 255 additions & 0 deletions

@@ -0,0 +1,255 @@
#include <ATen/ATen.h>
#include <ATen/DeviceGuard.h>
#include <ATen/xpu/XPUContext.h>

#include <sycl/sycl.hpp>

#include "xpu/dispatch_utils.h"

#include "fp8_quant.h"
#include "utils.h"

namespace vllm {

template <typename scalar_t, typename fp8_type>
class scaled_fp8_quant_kernel {
 private:
  fp8_type* out;
  const scalar_t* input;
  const float* scale;
  int64_t num_elems;

 public:
  scaled_fp8_quant_kernel(
      fp8_type* out_,
      const scalar_t* input_,
      const float* scale_,
      int64_t num_elems_)
      : out(out_), input(input_), scale(scale_), num_elems(num_elems_) {}

  void operator()(sycl::nd_item<1> item) const {
    int tid = item.get_global_linear_id();

    // Invert the scale so that we can use multiplications to avoid expensive
    // division.
    const float inverted_scale = 1.0f / (*scale);
    scaled_fp8_conversion_vec<scalar_t, true>(
        out,
        input,
        inverted_scale,
        num_elems,
        tid,
        item.get_local_range(0) * item.get_group_range(0));
  }
};
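scaled_fp8_conversion_vec presumably comes from the included fp8_quant.h header, which is not part of this excerpt. Judging from the arguments passed above (a starting index and a stride of local_range * group_range), it walks the tensor in a grid-stride fashion; a hypothetical scalar stand-in, with the clamping cast treated as an assumption about the helper, would look roughly like this:

    #include <cstdint>

    // Hypothetical scalar stand-in for scaled_fp8_conversion_vec (illustrative
    // only): each work-item starts at its global id and strides by
    // local_range * group_range until the whole tensor is covered.
    template <typename scalar_t, typename fp8_type>
    void grid_stride_quant(fp8_type* out, const scalar_t* in,
                           float inverted_scale, int64_t num_elems,
                           int64_t tid, int64_t stride) {
      for (int64_t i = tid; i < num_elems; i += stride) {
        // assumed behaviour: scale, clamp to the fp8 range, then narrow
        out[i] =
            static_cast<fp8_type>(static_cast<float>(in[i]) * inverted_scale);
      }
    }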

template <typename scalar_t, typename fp8_type>
class dynamic_per_token_scaled_fp8_quant_kernel {
 private:
  fp8_type* out;
  float* scale;
  scalar_t const* input;
  float const* scale_ub;
  const int hidden_size;

 public:
  dynamic_per_token_scaled_fp8_quant_kernel(
      fp8_type* out_,
      float* scale_,
      scalar_t const* input_,
      float const* scale_ub_,
      const int hidden_size_)
      : out(out_),
        scale(scale_),
        input(input_),
        scale_ub(scale_ub_),
        hidden_size(hidden_size_) {}

  void operator()(sycl::nd_item<1> item) const {
    int const tid = item.get_local_id(0);
    int const token_idx = item.get_group(0);

    // sycl::ext::oneapi::experimental::printf(
    //     "token_idx: %d, tid: %d, hidden_size: %d, group_range: %d\n",
    //     token_idx,
    //     tid,
    //     hidden_size,
    //     item.get_local_range(0));

    // Use int64 to avoid overflowing an int32 when calculating this offset
    int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
    scalar_t const* token_input = &input[offset];
    fp8_type* token_output = &out[offset];

    // For vectorization, token_input and token_output pointers need to be
    // aligned at 8-byte and 4-byte addresses respectively.
    bool const can_vectorize = hidden_size % 4 == 0;

    float absmax_val = 0.0f;
    if (can_vectorize) {
      absmax_val = thread_max_vec(
          token_input, hidden_size, tid, item.get_local_range(0));
    } else {
      for (int i = tid; i < hidden_size; i += item.get_local_range(0)) {
        float const x = static_cast<float>(token_input[i]);
        absmax_val = sycl::max(absmax_val, sycl::fabs(x));
      }
    }

    float const block_absmax_val_maybe = sycl::reduce_over_group(
        item.get_group(), absmax_val, sycl::maximum<float>());
    // __shared__ float token_scale;
    auto& token_scale =
        *sycl::ext::oneapi::group_local_memory_for_overwrite<float[1]>(
            item.get_group());
    if (tid == 0) {
      if (scale_ub) {
        token_scale[0] = sycl::min(block_absmax_val_maybe, *scale_ub);
      } else {
        token_scale[0] = block_absmax_val_maybe;
      }
      // token scale computation
      token_scale[0] = sycl::max(
          token_scale[0] / quant_type_max_v<fp8_type>,
          min_scaling_factor<fp8_type>::val());
      scale[token_idx] = token_scale[0];
    }
    group_barrier(item.get_group());

    // Note that we don't use inverted scales so we can match FBGemm impl.
    const float inverted_scale = 1.0f / (token_scale[0]);
    if (can_vectorize) {
      scaled_fp8_conversion_vec<scalar_t, true>(
          token_output,
          token_input,
          inverted_scale,
          hidden_size,
          tid,
          item.get_local_range(0));
    } else {
      for (int i = tid; i < hidden_size; i += item.get_local_range(0)) {
        token_output[i] = scaled_fp8_conversion<true, fp8_type>(
            static_cast<float>(token_input[i]), inverted_scale, tid, token_idx);
      }
    }
  }
};

}  // namespace vllm
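The per-token scale that the kernel above writes into scale[token_idx] reduces to the following host-side restatement. This is illustrative only: fp8_max and min_scale stand in for quant_type_max_v<fp8_type> and min_scaling_factor<fp8_type>::val(), which come from the shared fp8 headers and are not shown in this diff.

    #include <algorithm>

    // token_absmax is the group-wide reduction result; scale_ub is optional.
    inline float per_token_scale(float token_absmax, const float* scale_ub,
                                 float fp8_max, float min_scale) {
      float capped = scale_ub ? std::min(token_absmax, *scale_ub) : token_absmax;
      return std::max(capped / fp8_max, min_scale);  // keep the scale above 0
    }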

void static_scaled_fp8_quant(
    torch::Tensor& out,          // [..., d]
    torch::Tensor const& input,  // [..., d]
    torch::Tensor const& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  sycl::range<1> grid(num_tokens);
  sycl::range<1> block(1024);
  at::Device curDevice = at::Device(at::kXPU, at::xpu::current_device());
  at::DeviceGuard device_guard(curDevice);

  auto stream = at::xpu::getCurrentXPUStream().queue();
  // TODO: change name?
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
              // Launch the kernel
              stream.submit([&](sycl::handler& cgh) {
                auto kernel = vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>(
                    out.data_ptr<fp8_t>(),
                    input.data_ptr<scalar_t>(),
                    scale.data_ptr<float>(),
                    num_elems);
                cgh.parallel_for(
                    sycl::nd_range<1>(grid * block, block), kernel);
              });
            });
      });
}
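For reference, the launch geometry above uses one work-group per token with a fixed block of 1024 work-items, and the grid-stride loop inside the kernel covers all num_elems elements regardless of the exact shapes. A worked example with made-up sizes (not taken from the diff):

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Assumed shapes, for illustration only.
      int64_t num_tokens = 8, hidden_size = 4096;
      int64_t num_elems = num_tokens * hidden_size;  // 32768
      int64_t global_size = num_tokens * 1024;       // nd_range global size: 8192
      // Each work-item then handles num_elems / global_size = 4 elements.
      std::printf("%lld elements per work-item\n",
                  static_cast<long long>(num_elems / global_size));
      return 0;
    }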

void dynamic_scaled_fp8_quant(
    torch::Tensor& out,          // [..., d]
    torch::Tensor const& input,  // [..., d]
    torch::Tensor& scale)        // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  sycl::range<1> grid(num_tokens);
  sycl::range<1> block(1024);
  at::Device curDevice = at::Device(at::kXPU, at::xpu::current_device());
  at::DeviceGuard device_guard(curDevice);

  auto stream = at::xpu::getCurrentXPUStream().queue();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
              // Launch the kernels: first the absmax reduction that produces
              // the per-tensor scale, then the quantization pass that uses it.
              stream.submit([&](sycl::handler& cgh) {
                auto max_reduce_kernel =
                    vllm::segmented_max_reduction<scalar_t, fp8_t>(
                        scale.data_ptr<float>(),
                        input.data_ptr<scalar_t>(),
                        num_elems);
                cgh.parallel_for(
                    sycl::nd_range<1>(grid * block, block), max_reduce_kernel);
              });
              stream.submit([&](sycl::handler& cgh) {
                auto kernel = vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>(
                    out.data_ptr<fp8_t>(),
                    input.data_ptr<scalar_t>(),
                    scale.data_ptr<float>(),
                    num_elems);
                cgh.parallel_for(
                    sycl::nd_range<1>(grid * block, block), kernel);
              });
            });
      });
}
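dynamic_scaled_fp8_quant splits the work across two submissions on the same queue: segmented_max_reduction (defined in the shared fp8 headers, not shown in this diff) fills in the per-tensor scale, and scaled_fp8_quant_kernel then consumes it. Restated serially as an illustrative sketch; the absmax / fp8_max convention is assumed, since the reduction helper's definition is not part of this excerpt:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Illustrative serial restatement of the two-pass dynamic quantization.
    template <typename scalar_t>
    float dynamic_per_tensor_scale(const scalar_t* input, int64_t num_elems,
                                   float fp8_max) {
      // Pass 1 (segmented_max_reduction): global absmax -> scale.
      float absmax = 0.0f;
      for (int64_t i = 0; i < num_elems; ++i)
        absmax = std::max(absmax, std::fabs(static_cast<float>(input[i])));
      return absmax / fp8_max;
      // Pass 2 (scaled_fp8_quant_kernel) multiplies every element by 1/scale.
    }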

void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out,          // [..., d]
    torch::Tensor const& input,  // [..., d]
    torch::Tensor& scales,
    std::optional<at::Tensor> const& scale_ub) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  sycl::range<1> grid(num_tokens);
  sycl::range<1> block(std::min(hidden_size, 1024));

  at::Device curDevice = at::Device(at::kXPU, at::xpu::current_device());
  at::DeviceGuard device_guard(curDevice);

  auto stream = at::xpu::getCurrentXPUStream().queue();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(),
      "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type",
      [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(),
            "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type",
            [&] {
              // Launch the kernel
              stream
                  .submit([&](sycl::handler& cgh) {
                    auto kernel =
                        vllm::dynamic_per_token_scaled_fp8_quant_kernel<
                            scalar_t,
                            fp8_t>(
                            out.data_ptr<fp8_t>(),
                            scales.data_ptr<float>(),
                            input.data_ptr<scalar_t>(),
                            scale_ub.has_value() ? scale_ub->data_ptr<float>()
                                                 : nullptr,
                            hidden_size);
                    cgh.parallel_for(
                        sycl::nd_range<1>(grid * block, block), kernel);
                  })
                  .wait();
            });
      });
}
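A hypothetical C++ caller for the entry points declared in csrc/xpu/ops.h might look like the following. None of this is part of the commit: it assumes an XPU-enabled PyTorch build with fp8 dtype support (torch::kFloat8_e4m3fn), an include path that exposes the ops.h declarations, and made-up tensor shapes.

    #include <optional>
    #include <torch/torch.h>

    #include "xpu/ops.h"  // assumed include path for the declarations above

    void example_usage() {
      auto opts = torch::TensorOptions().device(torch::kXPU);
      auto input = torch::randn({8, 4096}, opts.dtype(torch::kFloat16));
      auto out = torch::empty_like(input, opts.dtype(torch::kFloat8_e4m3fn));

      // Static: the caller supplies a single per-tensor scale.
      auto scale = torch::ones({1}, opts.dtype(torch::kFloat32));
      static_scaled_fp8_quant(out, input, scale);

      // Dynamic per-token: one scale per token, with no upper bound.
      auto scales = torch::empty({8, 1}, opts.dtype(torch::kFloat32));
      dynamic_per_token_scaled_fp8_quant(out, input, scales, std::nullopt);
    }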

0 commit comments