
add fp8 quantization kernels #12


Open · wants to merge 6 commits into main
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -151,6 +151,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL")
"csrc/xpu/activation.cpp"
"csrc/xpu/pos_encoding_kernels.cpp"
"csrc/xpu/torch_bindings.cpp"
"csrc/xpu/quantization/fp8/fp8_quant.cpp"
)
include_directories("/usr/include")
set(CMPLR_ROOT $ENV{CMPLR_ROOT})
2 changes: 1 addition & 1 deletion csrc/xpu/cache.cpp
@@ -5,7 +5,7 @@
#include <string>

#include "dispatch_utils.h"
#include "quantization/fp8/quant_utils.hpp"
#include "quantization/fp8/quant_utils.h"
#include "utils.h"

namespace vllm {
6 changes: 4 additions & 2 deletions csrc/xpu/dispatch_utils.h
@@ -21,11 +21,13 @@
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)

#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)

// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
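For reference, a minimal sketch of how the two dispatch macros are meant to nest; it mirrors the usage in csrc/xpu/quantization/fp8/fp8_quant.cpp below, and launch_kernel is a hypothetical placeholder for whatever launch sits inside. The outer macro binds the floating-point input dtype to scalar_t, while the FP8 macro binds the output dtype to fp8_t.

VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "example_scalar_type", [&] {
  VLLM_DISPATCH_FP8_TYPES(out.scalar_type(), "example_fp8_type", [&] {
    // Inside the inner lambda the input element type is scalar_t and the
    // FP8 output element type is fp8_t.
    launch_kernel<scalar_t, fp8_t>(out.data_ptr<fp8_t>(),
                                   input.data_ptr<scalar_t>());
  });
});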
10 changes: 10 additions & 0 deletions csrc/xpu/ops.h
@@ -26,3 +26,13 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale);

void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale);

void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scale);

void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales,
std::optional<at::Tensor> const& scale_ub);
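A minimal host-side sketch of how these entry points might be exercised; it is not taken from this PR, and the shapes, dtypes, and the zero-initialized scale tensor are illustrative assumptions (zero-init seems natural because the dynamic path raises the scale with an atomic fetch_max, but nothing in this diff enforces it).

#include <torch/torch.h>
#include "ops.h"  // the XPU declarations above

void example() {
  // Illustrative shapes and dtypes only.
  auto input = torch::randn({16, 128},
      torch::dtype(torch::kFloat16).device(torch::kXPU));
  auto out = torch::empty(input.sizes(),
      input.options().dtype(at::ScalarType::Float8_e4m3fn));
  // Assumed caller responsibility: start the scale at zero so the kernel's
  // atomic fetch_max computes the true maximum.
  auto scale = torch::zeros({1},
      torch::dtype(torch::kFloat32).device(torch::kXPU));

  dynamic_scaled_fp8_quant(out, input, scale);  // fills out and scale
}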
224 changes: 224 additions & 0 deletions csrc/xpu/quantization/fp8/fp8_quant.cpp
@@ -0,0 +1,224 @@
#include <ATen/ATen.h>
#include <ATen/DeviceGuard.h>
#include <ATen/xpu/XPUContext.h>

#include <sycl/sycl.hpp>

#include "xpu/dispatch_utils.h"
#include "xpu/ops.h"

#include "fp8_quant.h"
#include "quant_utils.h"

namespace vllm {

template <typename scalar_t, typename fp8_type>
class scaled_fp8_quant_kernel {
private:
fp8_type* out;
const scalar_t* input;
const float* scale;
int64_t num_elems;

public:
scaled_fp8_quant_kernel(fp8_type* out_, const scalar_t* input_,
const float* scale_, int64_t num_elems_)
: out(out_), input(input_), scale(scale_), num_elems(num_elems_) {}
void operator()(sycl::nd_item<1> item) const {
int tid = item.get_global_linear_id();

// Invert the scale so that we can use multiplications to avoid expensive
// division.
const float inverted_scale = 1.0f / (*scale);
fp8::ConvertWithScaleOp<true, fp8_type> op{inverted_scale};
fp8::scaled_convert_vec(input, out, num_elems, tid,
item.get_local_range(0) * item.get_group_range(0),
op);
}
};

template <typename scalar_t, typename fp8_type>
class dynamic_per_token_scaled_fp8_quant_kernel {
private:
fp8_type* out;
float* scale;
scalar_t const* input;
float const* scale_ub;
const int hidden_size;

public:
dynamic_per_token_scaled_fp8_quant_kernel(fp8_type* out_, float* scale_,
scalar_t const* input_,
float const* scale_ub_,
const int hidden_size_)
: out(out_),
scale(scale_),
input(input_),
scale_ub(scale_ub_),
hidden_size(hidden_size_) {}

void operator()(sycl::nd_item<1> item) const {
int const tid = item.get_local_id(0);
int const token_idx = item.get_group(0);

// Use int64 to avoid overflowing an int32 when calculating this offset
int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
scalar_t const* token_input = &input[offset];
fp8_type* token_output = &out[offset];

// For vectorization, token_input and token_output need to stay 8-byte and
// 4-byte aligned respectively; for contiguous tensors this holds for every
// token whenever hidden_size is a multiple of 4.
bool const can_vectorize = hidden_size % 4 == 0;

float absmax_val = 0.0f;
if (can_vectorize) {
absmax_val = thread_max_vec(token_input, hidden_size, tid,
item.get_local_range(0));
} else {
for (int i = tid; i < hidden_size; i += item.get_local_range(0)) {
float const x = static_cast<float>(token_input[i]);
absmax_val = sycl::max(absmax_val, sycl::fabs(x));
}
}

float const block_absmax_val_maybe = sycl::reduce_over_group(
item.get_group(), absmax_val, sycl::maximum<float>());
// Group-local (work-group shared) storage for this token's scale.
auto& token_scale =
*sycl::ext::oneapi::group_local_memory_for_overwrite<float[1]>(
item.get_group());
if (tid == 0) {
if (scale_ub) {
token_scale[0] = sycl::min(block_absmax_val_maybe, *scale_ub);
} else {
token_scale[0] = block_absmax_val_maybe;
}
// Turn the absmax into a quantization scale, clamped below by the
// minimum scaling factor.
token_scale[0] =
sycl::max(token_scale[0] / fp8::quant_type_max_v<fp8_type>,
fp8::min_scaling_factor<fp8_type>::val());
scale[token_idx] = token_scale[0];
}
group_barrier(item.get_group());

// Invert the scale once so the conversion below multiplies instead of
// dividing.
const float inverted_scale = 1.0f / (token_scale[0]);
if (can_vectorize) {
fp8::ConvertWithScaleOp<true, fp8_type> op{inverted_scale};
fp8::scaled_convert_vec(token_input, token_output, hidden_size, tid,
item.get_local_range(0), op);
} else {
for (int i = tid; i < hidden_size; i += item.get_local_range(0)) {
fp8::ConvertWithScaleOp<true, fp8_type> op{inverted_scale};
op(token_output[i], token_input[i]);
}
}
}
};

} // namespace vllm

void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor const& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
sycl::range<1> grid(num_tokens);
sycl::range<1> block(1024);
at::Device curDevice = at::Device(at::kXPU, at::xpu::current_device());
at::DeviceGuard device_guard(curDevice);

auto stream = at::xpu::getCurrentXPUStream().queue();
Collaborator: maybe can just call auto& queue = vllm::xpu::vllmGetQueue()

Collaborator: kernel should not depend on vllm api or libs

VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
// Launch the kernel
stream.submit([&](sycl::handler& cgh) {
auto kernel = vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
kernel);
});
});
});
}

void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
sycl::range<1> grid(num_tokens);
sycl::range<1> block(1024);
at::Device curDevice = at::Device(at::kXPU, at::xpu::current_device());
at::DeviceGuard device_guard(curDevice);

auto stream = at::xpu::getCurrentXPUStream().queue();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
// Launch the kernel
stream.submit([&](sycl::handler& cgh) {
auto max_reduce_kernel =
vllm::segmented_max_reduction<scalar_t, fp8_t>(
scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
num_elems);
cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
max_reduce_kernel);
});
stream.submit([&](sycl::handler& cgh) {
auto kernel = vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), num_elems);
cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
kernel);
});
});
});
}

void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());

int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
sycl::range<1> grid(num_tokens);
sycl::range<1> block(std::min(hidden_size, 1024));

at::Device curDevice = at::Device(at::kXPU, at::xpu::current_device());
at::DeviceGuard device_guard(curDevice);

auto stream = at::xpu::getCurrentXPUStream().queue();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES(
out.scalar_type(),
"dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
// Launch the kernel
stream
.submit([&](sycl::handler& cgh) {
auto kernel =
vllm::dynamic_per_token_scaled_fp8_quant_kernel<
scalar_t, fp8_t>(
out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
input.data_ptr<scalar_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>()
: nullptr,
hidden_size);
cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
kernel);
})
.wait();
});
});
}
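To make the per-token math concrete, here is a scalar reference for what a single token goes through, written as a sketch rather than a drop-in test: kFp8Max = 448 matches the largest finite FP8 E4M3FN value, while min_scale stands in for fp8::min_scaling_factor<fp8_type>::val(), whose exact value lives in quant_utils.h and is not part of this diff.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

constexpr float kFp8Max = 448.0f;  // largest finite FP8 E4M3FN value

// One token: scale = max(absmax / kFp8Max, min_scale), then every element
// is multiplied by 1/scale. Values are kept as float here; the kernel
// additionally casts and saturates to the FP8 type, and can clamp absmax
// with an optional scale_ub first.
std::pair<std::vector<float>, float> quantize_token_ref(
    const std::vector<float>& token, float min_scale) {
  float absmax = 0.0f;
  for (float x : token) absmax = std::max(absmax, std::fabs(x));
  const float scale = std::max(absmax / kFp8Max, min_scale);
  const float inv = 1.0f / scale;
  std::vector<float> q(token.size());
  for (std::size_t i = 0; i < token.size(); ++i) {
    q[i] = std::clamp(token[i] * inv, -kFp8Max, kFp8Max);
  }
  return {q, scale};
}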
100 changes: 100 additions & 0 deletions csrc/xpu/quantization/fp8/fp8_quant.h
@@ -0,0 +1,100 @@
#pragma once

#include <sycl/sycl.hpp>
#include <cmath>

#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>

#include "quant_utils.h"

using namespace at;

namespace vllm {

template <typename scalar_t>
inline float thread_max_vec(scalar_t const* input, int64_t const num_elems,
int const tid, int const step) {
// Vectorized input/output to better utilize memory bandwidth.
using vec4_t = fp8::vec4_t<scalar_t>;
vec4_t const* vectorized_in = reinterpret_cast<vec4_t const*>(input);

int64_t const num_vec_elems = num_elems >> 2;
float absmax_val = 0.0f;

#pragma unroll 4
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t in_vec = vectorized_in[i];
absmax_val =
sycl::max(absmax_val, sycl::fabs(static_cast<float>(in_vec.x)));
absmax_val =
sycl::max(absmax_val, sycl::fabs(static_cast<float>(in_vec.y)));
absmax_val =
sycl::max(absmax_val, sycl::fabs(static_cast<float>(in_vec.z)));
absmax_val =
sycl::max(absmax_val, sycl::fabs(static_cast<float>(in_vec.w)));
}

// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
absmax_val =
sycl::max(absmax_val, sycl::fabs(static_cast<float>(input[i])));
}

return absmax_val;
}

template <typename scalar_t, typename fp8_type>
class segmented_max_reduction {
private:
float* scale;
const scalar_t* input;
int64_t num_elems;

public:
segmented_max_reduction(float* scale_, const scalar_t* input_,
int64_t num_elems_)
: scale(scale_), input(input_), num_elems(num_elems_) {}
void operator()(sycl::nd_item<1> item) const {
auto& cache =
*sycl::ext::oneapi::group_local_memory_for_overwrite<float[1024]>(
item.get_group());
int64_t i = item.get_global_linear_id();

// First store the maximum of all values processed by the current
// thread in cache[item.get_local_id(0)]
float tmp = 0.0;
while (i < num_elems) {
float x = static_cast<float>(input[i]);
tmp = sycl::max(tmp, sycl::fabs(x));
i += item.get_local_range(0) * item.get_group_range(0);
}
cache[item.get_local_id(0)] = tmp;

group_barrier(item.get_group());

// Now perform parallel reduction within the thread block
int ib = item.get_local_range(0) / 2;
while (ib != 0) {
if (item.get_local_id(0) < ib &&
cache[item.get_local_id(0) + ib] > cache[item.get_local_id(0)]) {
cache[item.get_local_id(0)] = cache[item.get_local_id(0) + ib];
}
group_barrier(item.get_group());
ib /= 2;
}
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
// TODO: Do we need if statement?
if (item.get_local_id(0) == 0) {
using atomic_t =
sycl::atomic_ref<float, sycl::memory_order::relaxed,
sycl::memory_scope::device,
sycl::access::address_space::global_space>;
atomic_t atomic_max(*scale);
atomic_max.fetch_max(cache[0] / fp8::quant_type_max_v<fp8_type>);
}
}
};

} // namespace vllm
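As a side note on the halving loop in segmented_max_reduction, here is a serial walk-through of the same pattern on a toy 8-lane block; it only illustrates the indexing, whereas the real kernel separates the rounds with group_barrier() and finishes with a device-scope atomic fetch_max.

#include <algorithm>
#include <cstdio>

int main() {
  // cache[] plays the role of the group-local buffer; each lane has already
  // stored the absolute maximum of the elements it visited.
  float cache[8] = {0.5f, 3.0f, 1.5f, 0.25f, 2.0f, 0.75f, 4.0f, 1.0f};
  for (int ib = 8 / 2; ib != 0; ib /= 2) {
    for (int i = 0; i < ib; ++i) {  // one round of the work-group
      cache[i] = std::max(cache[i], cache[i + ib]);
    }
  }
  std::printf("block absmax = %f\n", cache[0]);  // prints 4.000000
  return 0;
}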