-
Notifications
You must be signed in to change notification settings - Fork 7
add fp8 quantization kernels #12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
baodii
wants to merge
6
commits into
vllm-project:main
Choose a base branch
from
baodii:baodi/fp8_kernels
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+740
−36
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
1b96119
add fp8 quant kernels
baodii 3a318ff
add fp8_e5m2 support and fixing UT
baodii dcf2d5e
add per-token quantization
baodii 56ebb81
change per-token UT assert_close by percentage of mismatch since fp8 …
baodii 5bbb0cc
merge redundant files
zufangzhu 059a372
remove redefine struct
zufangzhu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
#include <ATen/ATen.h> | ||
#include <ATen/DeviceGuard.h> | ||
#include <ATen/xpu/XPUContext.h> | ||
|
||
#include <sycl/sycl.hpp> | ||
|
||
#include "xpu/dispatch_utils.h" | ||
#include "xpu/ops.h" | ||
|
||
#include "fp8_quant.h" | ||
#include "quant_utils.h" | ||
|
||
namespace vllm { | ||
|
||
template <typename scalar_t, typename fp8_type> | ||
class scaled_fp8_quant_kernel { | ||
private: | ||
fp8_type* out; | ||
const scalar_t* input; | ||
const float* scale; | ||
int64_t num_elems; | ||
|
||
public: | ||
scaled_fp8_quant_kernel(fp8_type* out_, const scalar_t* input_, | ||
const float* scale_, int64_t num_elems_) | ||
: out(out_), input(input_), scale(scale_), num_elems(num_elems_) {} | ||
void operator()(sycl::nd_item<1> item) const { | ||
int tid = item.get_global_linear_id(); | ||
|
||
// Invert the scale so that we can use multiplications to avoid expensive | ||
// division. | ||
const float inverted_scale = 1.0f / (*scale); | ||
fp8::ConvertWithScaleOp<true, fp8_type> op{inverted_scale}; | ||
fp8::scaled_convert_vec(input, out, num_elems, tid, | ||
item.get_local_range(0) * item.get_group_range(0), | ||
op); | ||
} | ||
}; | ||
|
||
// One work-group per token row: computes the row's abs-max, derives a
// per-token scale (optionally clamped by scale_ub), publishes it to
// `scale[token_idx]`, then quantizes the row with that scale.
template <typename scalar_t, typename fp8_type>
class dynamic_per_token_scaled_fp8_quant_kernel {
 private:
  fp8_type* out;           // quantized output, same layout as input
  float* scale;            // per-token scales, one float per token
  scalar_t const* input;   // source values
  float const* scale_ub;   // optional upper bound on abs-max (nullptr if unused)
  const int hidden_size;   // elements per token row

 public:
  dynamic_per_token_scaled_fp8_quant_kernel(fp8_type* out_, float* scale_,
                                            scalar_t const* input_,
                                            float const* scale_ub_,
                                            const int hidden_size_)
      : out(out_),
        scale(scale_),
        input(input_),
        scale_ub(scale_ub_),
        hidden_size(hidden_size_) {}

  void operator()(sycl::nd_item<1> item) const {
    int const tid = item.get_local_id(0);
    int const token_idx = item.get_group(0);

    // Use int64 to avoid overflowing an int32 when calculating this offset
    int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
    scalar_t const* token_input = &input[offset];
    fp8_type* token_output = &out[offset];

    // Only the element-count divisibility is checked here.
    // NOTE(review): the original comment claimed 8-/4-byte pointer
    // alignment is required for the vector path, but no alignment check
    // is performed — confirm `input`/`out` rows are suitably aligned
    // (contiguous tensors with hidden_size % 4 == 0 imply it).
    bool const can_vectorize = hidden_size % 4 == 0;

    // Phase 1: each work-item computes a partial abs-max over its slice.
    float absmax_val = 0.0f;
    if (can_vectorize) {
      absmax_val = thread_max_vec(token_input, hidden_size, tid,
                                  item.get_local_range(0));
    } else {
      for (int i = tid; i < hidden_size; i += item.get_local_range(0)) {
        float const x = static_cast<float>(token_input[i]);
        absmax_val = sycl::max(absmax_val, sycl::fabs(x));
      }
    }

    // Reduce partial maxima across the work-group.
    float const block_absmax_val_maybe = sycl::reduce_over_group(
        item.get_group(), absmax_val, sycl::maximum<float>());
    // Work-group-local slot for the final token scale (SYCL equivalent
    // of the CUDA `__shared__ float token_scale`).
    auto& token_scale =
        *sycl::ext::oneapi::group_local_memory_for_overwrite<float[1]>(
            item.get_group());
    if (tid == 0) {
      if (scale_ub) {
        token_scale[0] = sycl::min(block_absmax_val_maybe, *scale_ub);
      } else {
        token_scale[0] = block_absmax_val_maybe;
      }
      // Map abs-max to a scale; the lower bound keeps the later
      // reciprocal finite for an all-zero row.
      token_scale[0] =
          sycl::max(token_scale[0] / fp8::quant_type_max_v<fp8_type>,
                    fp8::min_scaling_factor<fp8_type>::val());
      scale[token_idx] = token_scale[0];
    }
    // All work-items must see token_scale[0] before quantizing.
    group_barrier(item.get_group());

    // Phase 2: quantize the row. The scale is inverted once so the
    // per-element conversion is a multiply rather than a divide.
    const float inverted_scale = 1.0f / (token_scale[0]);
    if (can_vectorize) {
      fp8::ConvertWithScaleOp<true, fp8_type> op{inverted_scale};
      fp8::scaled_convert_vec(token_input, token_output, hidden_size, tid,
                              item.get_local_range(0), op);
    } else {
      for (int i = tid; i < hidden_size; i += item.get_local_range(0)) {
        fp8::ConvertWithScaleOp<true, fp8_type> op{inverted_scale};
        op(token_output[i], token_input[i]);
      }
    }
  }
};
|
||
} // namespace vllm | ||
|
||
// Static (per-tensor) fp8 quantization: every element of `input` is
// converted to fp8 using the single caller-provided `scale`.
void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                             torch::Tensor const& input,  // [..., d]
                             torch::Tensor const& scale)  // [1]
{
  int64_t num_elems = input.numel();
  int64_t num_tokens = input.numel() / input.size(-1);

  // One work-group of 1024 work-items per token; the kernel itself
  // strides across the whole flattened tensor.
  sycl::range<1> grid(num_tokens);
  sycl::range<1> block(1024);

  at::Device cur_device = at::Device(at::kXPU, at::xpu::current_device());
  at::DeviceGuard device_guard(cur_device);

  auto queue = at::xpu::getCurrentXPUStream().queue();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
              // Enqueue the conversion kernel.
              queue.submit([&](sycl::handler& cgh) {
                vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t> kernel(
                    out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
                    scale.data_ptr<float>(), num_elems);
                cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
                                 kernel);
              });
            });
      });
}
|
||
// Dynamic (per-tensor) fp8 quantization in two passes:
//   1. segmented_max_reduction writes the global scale (abs-max / fp8-max)
//      into `scale` via an atomic fetch_max,
//   2. scaled_fp8_quant_kernel quantizes using that scale.
// NOTE(review): pass 1 accumulates with fetch_max, so `scale` must be
// pre-initialized to a value <= the result (typically zero) by the
// caller — confirm against the Python-side op that allocates it.
// NOTE(review): pass 2 reads what pass 1 wrote; this is only correct if
// the XPU stream's queue executes submissions in order — confirm.
void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor& scale)        // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  // One work-group of 1024 work-items per token; kernels stride over
  // the flattened tensor.
  sycl::range<1> grid(num_tokens);
  sycl::range<1> block(1024);
  at::Device curDevice = at::Device(at::kXPU, at::xpu::current_device());
  at::DeviceGuard device_guard(curDevice);

  auto stream = at::xpu::getCurrentXPUStream().queue();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
              // Pass 1: reduce |input| to a single scale in `scale[0]`.
              stream.submit([&](sycl::handler& cgh) {
                auto max_reduce_kernel =
                    vllm::segmented_max_reduction<scalar_t, fp8_t>(
                        scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
                        num_elems);
                cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
                                 max_reduce_kernel);
              });
              // Pass 2: quantize with the freshly computed scale.
              stream.submit([&](sycl::handler& cgh) {
                auto kernel = vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>(
                    out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
                    scale.data_ptr<float>(), num_elems);
                cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
                                 kernel);
              });
            });
      });
}
|
||
// Dynamic per-token fp8 quantization: computes one scale per token row
// (optionally clamped by `scale_ub`) and quantizes each row with it.
void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out,          // [..., d]
    torch::Tensor const& input,  // [..., d]
    torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
  // Row-contiguous layout is required for per-row pointer arithmetic.
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;

  // One work-group per token; never launch more work-items than columns.
  sycl::range<1> grid(num_tokens);
  sycl::range<1> block(std::min(hidden_size, 1024));

  at::Device cur_device = at::Device(at::kXPU, at::xpu::current_device());
  at::DeviceGuard device_guard(cur_device);

  auto queue = at::xpu::getCurrentXPUStream().queue();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(),
      "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(),
            "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
              const float* ub_ptr = scale_ub.has_value()
                                        ? scale_ub->data_ptr<float>()
                                        : nullptr;
              // Enqueue and block until the quantization completes.
              queue
                  .submit([&](sycl::handler& cgh) {
                    vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t,
                                                                    fp8_t>
                        kernel(out.data_ptr<fp8_t>(),
                               scales.data_ptr<float>(),
                               input.data_ptr<scalar_t>(), ub_ptr,
                               hidden_size);
                    cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
                                     kernel);
                  })
                  .wait();
            });
      });
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
#pragma once | ||
|
||
#include <sycl/sycl.hpp> | ||
#include <cmath> | ||
|
||
#include <c10/util/Float8_e4m3fn.h> | ||
#include <c10/util/Float8_e5m2.h> | ||
|
||
#include "quant_utils.h" | ||
|
||
using namespace at; | ||
|
||
namespace vllm { | ||
|
||
template <typename scalar_t> | ||
inline float thread_max_vec(scalar_t const* input, int64_t const num_elems, | ||
int const tid, int const step) { | ||
// Vectorized input/output to better utilize memory bandwidth. | ||
using vec4_t = fp8::vec4_t<scalar_t>; | ||
vec4_t const* vectorized_in = reinterpret_cast<vec4_t const*>(input); | ||
|
||
int64_t const num_vec_elems = num_elems >> 2; | ||
float absmax_val = 0.0f; | ||
|
||
#pragma unroll 4 | ||
for (int64_t i = tid; i < num_vec_elems; i += step) { | ||
vec4_t in_vec = vectorized_in[i]; | ||
absmax_val = | ||
sycl::max(absmax_val, sycl::fabs(static_cast<float>(in_vec.x))); | ||
absmax_val = | ||
sycl::max(absmax_val, sycl::fabs(static_cast<float>(in_vec.y))); | ||
absmax_val = | ||
sycl::max(absmax_val, sycl::fabs(static_cast<float>(in_vec.z))); | ||
absmax_val = | ||
sycl::max(absmax_val, sycl::fabs(static_cast<float>(in_vec.w))); | ||
} | ||
|
||
// Handle the remaining elements if num_elems is not divisible by 4 | ||
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) { | ||
absmax_val = | ||
sycl::max(absmax_val, sycl::fabs(static_cast<float>(input[i]))); | ||
} | ||
|
||
return absmax_val; | ||
} | ||
|
||
// Reduces |input| to a single per-tensor scale: each work-group computes
// its block-local abs-max in local memory, then work-item 0 folds
// (block_max / fp8_max) into `*scale` with an atomic max.
// NOTE(review): the local cache is fixed at 1024 floats and the tree
// reduction halves the range each step, so this kernel assumes a
// power-of-two work-group size of at most 1024 (the launchers in this
// PR use exactly 1024) — confirm before reusing with other launches.
// NOTE(review): fetch_max accumulates, so `*scale` must be
// pre-initialized (e.g. zeroed) by the caller — confirm.
template <typename scalar_t, typename fp8_type>
class segmented_max_reduction {
 private:
  float* scale;           // output: single scale value (device pointer)
  const scalar_t* input;  // values to reduce
  int64_t num_elems;      // total element count

 public:
  segmented_max_reduction(float* scale_, const scalar_t* input_,
                          int64_t num_elems_)
      : scale(scale_), input(input_), num_elems(num_elems_) {}
  void operator()(sycl::nd_item<1> item) const {
    // Work-group-local scratch, one slot per work-item.
    auto& cache =
        *sycl::ext::oneapi::group_local_memory_for_overwrite<float[1024]>(
            item.get_group());
    int64_t i = item.get_global_linear_id();

    // First store maximum for all values processes by
    // the current thread in cache[item.get_local_id(0)]
    float tmp = 0.0;
    while (i < num_elems) {
      float x = static_cast<float>(input[i]);
      tmp = sycl::max(tmp, sycl::fabs(x));
      i += item.get_local_range(0) * item.get_group_range(0);
    }
    cache[item.get_local_id(0)] = tmp;

    // All partial maxima must be visible before the tree reduction.
    group_barrier(item.get_group());

    // Now perform parallel reduction within the thread block
    int ib = item.get_local_range(0) / 2;
    while (ib != 0) {
      if (item.get_local_id(0) < ib &&
          cache[item.get_local_id(0) + ib] > cache[item.get_local_id(0)]) {
        cache[item.get_local_id(0)] = cache[item.get_local_id(0) + ib];
      }
      group_barrier(item.get_group());
      ib /= 2;
    }
    // Finally, since cache[0] contains the maximum for this thread block,
    // atomically write the max to the target location
    // TODO: Do we need if statement?
    if (item.get_local_id(0) == 0) {
      using atomic_t =
          sycl::atomic_ref<float, sycl::memory_order::relaxed,
                           sycl::memory_scope::device,
                           sycl::access::address_space::global_space>;
      atomic_t atomic_max(*scale);
      // Store the scale (abs-max / representable fp8 max), not the raw max.
      atomic_max.fetch_max(cache[0] / fp8::quant_type_max_v<fp8_type>);
    }
  }
};
|
||
} // namespace vllm |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe can just call
auto& queue = vllm::xpu::vllmGetQueue()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kernel should not depend on vllm api or libs