diff --git a/CMakeLists.txt b/CMakeLists.txt
index e3c09e3..1f93e22 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -151,6 +151,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL")
     "csrc/xpu/activation.cpp"
     "csrc/xpu/pos_encoding_kernels.cpp"
     "csrc/xpu/torch_bindings.cpp"
+    "csrc/xpu/quantization/fp8/fp8_quant.cpp"
   )
   include_directories("/usr/include")
   set(CMPLR_ROOT $ENV{CMPLR_ROOT})
diff --git a/csrc/xpu/cache.cpp b/csrc/xpu/cache.cpp
index ab226ef..b802aeb 100644
--- a/csrc/xpu/cache.cpp
+++ b/csrc/xpu/cache.cpp
@@ -5,7 +5,7 @@
 #include
 #include "dispatch_utils.h"
-#include "quantization/fp8/quant_utils.hpp"
+#include "quantization/fp8/quant_utils.h"
 #include "utils.h"
 
 namespace vllm {
diff --git a/csrc/xpu/dispatch_utils.h b/csrc/xpu/dispatch_utils.h
index a9455e1..1849ed2 100644
--- a/csrc/xpu/dispatch_utils.h
+++ b/csrc/xpu/dispatch_utils.h
@@ -21,11 +21,13 @@
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
-#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
-  AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
+#define VLLM_DISPATCH_CASE_FP8_TYPES(...)                           \
+  AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)  \
+  AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)
 
 #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                     \
   AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)   \
   AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
 
 // When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
diff --git a/csrc/xpu/ops.h b/csrc/xpu/ops.h
index 8dd9b71..77ebbaf 100644
--- a/csrc/xpu/ops.h
+++ b/csrc/xpu/ops.h
@@ -26,3 +26,13 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                              torch::Tensor& slot_mapping,
                              const std::string& kv_cache_dtype,
                              torch::Tensor& k_scale, torch::Tensor& v_scale);
+
+void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
+                             torch::Tensor const& scale);
+
+void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
+                              torch::Tensor& scale);
+
+void dynamic_per_token_scaled_fp8_quant(
+    torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales,
+    std::optional<torch::Tensor> const& scale_ub);
diff --git a/csrc/xpu/quantization/fp8/fp8_quant.cpp b/csrc/xpu/quantization/fp8/fp8_quant.cpp
new file mode 100644
index 0000000..fae2438
--- /dev/null
+++ b/csrc/xpu/quantization/fp8/fp8_quant.cpp
@@ -0,0 +1,224 @@
+#include
+#include
+#include
+
+#include
+
+#include "xpu/dispatch_utils.h"
+#include "xpu/ops.h"
+
+#include "fp8_quant.h"
+#include "quant_utils.h"
+
+namespace vllm {
+
+template <typename scalar_t, typename fp8_type>
+class scaled_fp8_quant_kernel {
+ private:
+  fp8_type* out;
+  const scalar_t* input;
+  const float* scale;
+  int64_t num_elems;
+
+ public:
+  scaled_fp8_quant_kernel(fp8_type* out_, const scalar_t* input_,
+                          const float* scale_, int64_t num_elems_)
+      : out(out_), input(input_), scale(scale_), num_elems(num_elems_) {}
+  void operator()(sycl::nd_item<1> item) const {
+    int tid = item.get_global_linear_id();
+
+    // Invert the scale so that we can use multiplications to avoid expensive
+    // division.
+    const float inverted_scale = 1.0f / (*scale);
+    fp8::ConvertWithScaleOp<fp8_type, true> op{inverted_scale};
+    fp8::scaled_convert_vec(input, out, num_elems, tid,
+                            item.get_local_range(0) * item.get_group_range(0),
+                            op);
+  }
+};
+
+template <typename scalar_t, typename fp8_type>
+class dynamic_per_token_scaled_fp8_quant_kernel {
+ private:
+  fp8_type* out;
+  float* scale;
+  scalar_t const* input;
+  float const* scale_ub;
+  const int hidden_size;
+
+ public:
+  dynamic_per_token_scaled_fp8_quant_kernel(fp8_type* out_, float* scale_,
+                                            scalar_t const* input_,
+                                            float const* scale_ub_,
+                                            const int hidden_size_)
+      : out(out_),
+        scale(scale_),
+        input(input_),
+        scale_ub(scale_ub_),
+        hidden_size(hidden_size_) {}
+
+  void operator()(sycl::nd_item<1> item) const {
+    int const tid = item.get_local_id(0);
+    int const token_idx = item.get_group(0);
+
+    // Use int64 to avoid overflowing an int32 when calculating this offset
+    int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
+    scalar_t const* token_input = &input[offset];
+    fp8_type* token_output = &out[offset];
+
+    // For vectorization, token_input and token_output pointers need to be
+    // aligned at 8-byte and 4-byte addresses respectively.
+    bool const can_vectorize = hidden_size % 4 == 0;
+
+    float absmax_val = 0.0f;
+    if (can_vectorize) {
+      absmax_val = thread_max_vec(token_input, hidden_size, tid,
+                                  item.get_local_range(0));
+    } else {
+      for (int i = tid; i < hidden_size; i += item.get_local_range(0)) {
+        float const x = static_cast<float>(token_input[i]);
+        absmax_val = sycl::max(absmax_val, sycl::fabs(x));
+      }
+    }
+
+    float const block_absmax_val_maybe = sycl::reduce_over_group(
+        item.get_group(), absmax_val, sycl::maximum<float>());
+    // __shared__ float token_scale;
+    auto& token_scale =
+        *sycl::ext::oneapi::group_local_memory_for_overwrite<float[1]>(
+            item.get_group());
+    if (tid == 0) {
+      if (scale_ub) {
+        token_scale[0] = sycl::min(block_absmax_val_maybe, *scale_ub);
+      } else {
+        token_scale[0] = block_absmax_val_maybe;
+      }
+      // token scale computation
+      token_scale[0] =
+          sycl::max(token_scale[0] / fp8::quant_type_max_v<fp8_type>,
+                    fp8::min_scaling_factor<fp8_type>::val());
+      scale[token_idx] = token_scale[0];
+    }
+    group_barrier(item.get_group());
+
+    // Unlike the FBGemm impl., use an inverted scale here so that the
+    // conversion below can use a multiply instead of a division.
+    const float inverted_scale = 1.0f / (token_scale[0]);
+    if (can_vectorize) {
+      fp8::ConvertWithScaleOp<fp8_type, true> op{inverted_scale};
+      fp8::scaled_convert_vec(token_input, token_output, hidden_size, tid,
+                              item.get_local_range(0), op);
+    } else {
+      for (int i = tid; i < hidden_size; i += item.get_local_range(0)) {
+        fp8::ConvertWithScaleOp<fp8_type, true> op{inverted_scale};
+        op(token_output[i], token_input[i]);
+      }
+    }
+  }
+};
+
+}  // namespace vllm
+
+void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
+                             torch::Tensor const& input,  // [..., d]
+                             torch::Tensor const& scale)  // [1]
+{
+  int64_t num_tokens = input.numel() / input.size(-1);
+  int64_t num_elems = input.numel();
+  sycl::range<1> grid(num_tokens);
+  sycl::range<1> block(1024);
+  at::Device curDevice = at::Device(at::kXPU, at::xpu::current_device());
+  at::DeviceGuard device_guard(curDevice);
+
+  auto stream = at::xpu::getCurrentXPUStream().queue();
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
+              // Launch the kernel
+              stream.submit([&](sycl::handler& cgh) {
+                auto kernel = vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>(
+                    out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                    scale.data_ptr<float>(), num_elems);
+                cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
+                                 kernel);
+              });
+            });
+      });
+}
+
+void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
+                              torch::Tensor const& input,  // [..., d]
+                              torch::Tensor& scale)        // [1]
+{
+  int64_t num_tokens = input.numel() / input.size(-1);
+  int64_t num_elems = input.numel();
+  sycl::range<1> grid(num_tokens);
+  sycl::range<1> block(1024);
+  at::Device curDevice = at::Device(at::kXPU, at::xpu::current_device());
+  at::DeviceGuard device_guard(curDevice);
+
+  auto stream = at::xpu::getCurrentXPUStream().queue();
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
+              // Launch the kernel
+              stream.submit([&](sycl::handler& cgh) {
+                auto max_reduce_kernel =
+                    vllm::segmented_max_reduction<scalar_t, fp8_t>(
+                        scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
+                        num_elems);
+                cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
+                                 max_reduce_kernel);
+              });
+              stream.submit([&](sycl::handler& cgh) {
+                auto kernel = vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>(
+                    out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                    scale.data_ptr<float>(), num_elems);
+                cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
+                                 kernel);
+              });
+            });
+      });
+}
+
+void dynamic_per_token_scaled_fp8_quant(
+    torch::Tensor& out,          // [..., d]
+    torch::Tensor const& input,  // [..., d]
+    torch::Tensor& scales, std::optional<torch::Tensor> const& scale_ub) {
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(out.is_contiguous());
+
+  int const hidden_size = input.size(-1);
+  int const num_tokens = input.numel() / hidden_size;
+  sycl::range<1> grid(num_tokens);
+  sycl::range<1> block(std::min(hidden_size, 1024));
+
+  at::Device curDevice = at::Device(at::kXPU, at::xpu::current_device());
+  at::DeviceGuard device_guard(curDevice);
+
+  auto stream = at::xpu::getCurrentXPUStream().queue();
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(),
+      "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(),
+            "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
+              // Launch the kernel
+              stream
+                  .submit([&](sycl::handler& cgh) {
+                    auto kernel =
+                        vllm::dynamic_per_token_scaled_fp8_quant_kernel<
+                            scalar_t, fp8_t>(
+                            out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
+                            input.data_ptr<scalar_t>(),
+                            scale_ub.has_value() ? scale_ub->data_ptr<float>()
+                                                 : nullptr,
+                            hidden_size);
+                    cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
+                                     kernel);
+                  })
+                  .wait();
+            });
+      });
+}
diff --git a/csrc/xpu/quantization/fp8/fp8_quant.h b/csrc/xpu/quantization/fp8/fp8_quant.h
new file mode 100644
index 0000000..e11d5e0
--- /dev/null
+++ b/csrc/xpu/quantization/fp8/fp8_quant.h
@@ -0,0 +1,100 @@
+#pragma once
+
+#include
+#include
+
+#include
+#include
+
+#include "quant_utils.h"
+
+using namespace at;
+
+namespace vllm {
+
+template <typename scalar_t>
+inline float thread_max_vec(scalar_t const* input, int64_t const num_elems,
+                            int const tid, int const step) {
+  // Vectorized input/output to better utilize memory bandwidth.
+  using vec4_t = fp8::vec4_t<scalar_t>;
+  vec4_t const* vectorized_in = reinterpret_cast<vec4_t const*>(input);
+
+  int64_t const num_vec_elems = num_elems >> 2;
+  float absmax_val = 0.0f;
+
+#pragma unroll 4
+  for (int64_t i = tid; i < num_vec_elems; i += step) {
+    vec4_t in_vec = vectorized_in[i];
+    absmax_val =
+        sycl::max(absmax_val, sycl::fabs(static_cast<float>(in_vec.x)));
+    absmax_val =
+        sycl::max(absmax_val, sycl::fabs(static_cast<float>(in_vec.y)));
+    absmax_val =
+        sycl::max(absmax_val, sycl::fabs(static_cast<float>(in_vec.z)));
+    absmax_val =
+        sycl::max(absmax_val, sycl::fabs(static_cast<float>(in_vec.w)));
+  }
+
+  // Handle the remaining elements if num_elems is not divisible by 4
+  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
+    absmax_val =
+        sycl::max(absmax_val, sycl::fabs(static_cast<float>(input[i])));
+  }
+
+  return absmax_val;
+}
+
+template <typename scalar_t, typename fp8_type>
+class segmented_max_reduction {
+ private:
+  float* scale;
+  const scalar_t* input;
+  int64_t num_elems;
+
+ public:
+  segmented_max_reduction(float* scale_, const scalar_t* input_,
+                          int64_t num_elems_)
+      : scale(scale_), input(input_), num_elems(num_elems_) {}
+  void operator()(sycl::nd_item<1> item) const {
+    auto& cache =
+        *sycl::ext::oneapi::group_local_memory_for_overwrite<float[1024]>(
+            item.get_group());
+    int64_t i = item.get_global_linear_id();
+
+    // First store the maximum of all values processed by
+    // the current thread in cache[item.get_local_id(0)]
+    float tmp = 0.0;
+    while (i < num_elems) {
+      float x = static_cast<float>(input[i]);
+      tmp = sycl::max(tmp, sycl::fabs(x));
+      i += item.get_local_range(0) * item.get_group_range(0);
+    }
+    cache[item.get_local_id(0)] = tmp;
+
+    group_barrier(item.get_group());
+
+    // Now perform parallel reduction within the thread block
+    int ib = item.get_local_range(0) / 2;
+    while (ib != 0) {
+      if (item.get_local_id(0) < ib &&
+          cache[item.get_local_id(0) + ib] > cache[item.get_local_id(0)]) {
+        cache[item.get_local_id(0)] = cache[item.get_local_id(0) + ib];
+      }
+      group_barrier(item.get_group());
+      ib /= 2;
+    }
+    // Finally, since cache[0] contains the maximum for this thread block,
+    // atomically write the max to the target location
+    // TODO: Do we need the if statement?
+    if (item.get_local_id(0) == 0) {
+      using atomic_t =
+          sycl::atomic_ref<float, sycl::memory_order::relaxed,
+                           sycl::memory_scope::device,
+                           sycl::access::address_space::global_space>;
+      atomic_t atomic_max(*scale);
+      atomic_max.fetch_max(cache[0] / fp8::quant_type_max_v<fp8_type>);
+    }
+  }
+};
+
+}  // namespace vllm
diff --git a/csrc/xpu/quantization/fp8/quant_utils.hpp b/csrc/xpu/quantization/fp8/quant_utils.h
similarity index 69%
rename from csrc/xpu/quantization/fp8/quant_utils.hpp
rename to csrc/xpu/quantization/fp8/quant_utils.h
index 701dcae..3e5fcd6 100644
--- a/csrc/xpu/quantization/fp8/quant_utils.hpp
+++ b/csrc/xpu/quantization/fp8/quant_utils.h
@@ -1,4 +1,9 @@
+#pragma once
+#include
+
 #include
+#include
+#include
 
 namespace vllm {
 
@@ -10,17 +15,44 @@ enum class Fp8KVCacheDataType {
 
 namespace fp8 {
 
-template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
-__inline__ Tout scaled_convert(const Tin& x, const float scale) {
-  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
-    return static_cast<Tout>(x / scale);
-  } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
-    return static_cast<Tout>(x / scale);
-  }
+template <typename scalar_t>
+struct alignas(8) vec4_t {
+  scalar_t x;
+  scalar_t y;
+  scalar_t z;
+  scalar_t w;
+};
 
-  assert(false);
-  return {};  // Squash missing return statement warning
-}
+template <typename dtype_t>
+struct alignas(4) dtypex4_t {
+  static_assert(std::is_same_v ||
+                    std::is_same_v ||
+                    std::is_same_v ||
+                    std::is_same_v ||
+                    std::is_same_v,
+                "Unsupported cache type for dtypex4_t");
+  dtype_t x;
+  dtype_t y;
+  dtype_t z;
+  dtype_t w;
+};
+
+template <typename T, typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
+                                                  std::is_same_v<T, c10::Float8_e5m2>>>
+struct quant_type_max {
+  static constexpr T val() { return std::numeric_limits<T>::max(); }
+};
+
+template <typename T>
+static constexpr T quant_type_max_v = quant_type_max<T>::val();
+
+template <typename T, typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
+                                                  std::is_same_v<T, c10::Float8_e5m2>>>
+struct min_scaling_factor {
+  static inline float val() { return 1.0f / (quant_type_max_v<T> * 512.0f); }
+};
 
 // Used by vectorization_utils to copy/convert one element
 template
@@ -31,40 +63,36 @@ struct CopyWithScaleOp {
     if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
       dst = static_cast(src);
     } else {
-      dst = fp8::scaled_convert(src, scale);
+      float x = (float)src / scale;
+      if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
+        dst = static_cast(x);
+      } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
+        dst = static_cast(x);
+      }
     }
   }
 };
 
-template <typename scalar_t>
-struct alignas(8) vec4_t {
-  scalar_t x;
-  scalar_t y;
-  scalar_t z;
-  scalar_t w;
-};
+// Convert a float value to an fp8 type with scaling.
+template <typename fp8_type, bool is_scale_inverted>
+struct ConvertWithScaleOp {
+  float scale;
 
-template <typename cache_t>
-struct alignas(4) cachex4_t {
-  static_assert(std::is_same_v ||
-                    std::is_same_v ||
-                    std::is_same_v ||
-                    std::is_same_v ||
-                    std::is_same_v,
-                "Unsupported cache type for cachex4_t");
-  cache_t x;
-  cache_t y;
-  cache_t z;
-  cache_t w;
+  inline void operator()(fp8_type& dst, float const src) const {
+    float x = is_scale_inverted ? (src * scale) : (src / scale);
+    const float fp8_max = static_cast<float>(quant_type_max_v<fp8_type>);
+    float r = sycl::fmax(-fp8_max, sycl::fmin(x, fp8_max));
+    dst = static_cast<fp8_type>(r);
+  }
 };
 
 // The vector width is fixed at 4 to avoid excessive branching in the kernel,
 // which could degrade performance.
-template <typename scalar_t, typename cache_t, typename ScaOp>
-void scaled_convert_vec(const scalar_t* src, cache_t* dst, int num_elems,
+template <typename scalar_t, typename dtype_t, typename ScaOp>
+void scaled_convert_vec(const scalar_t* src, dtype_t* dst, int num_elems,
                         int local_idx, int local_range, ScaOp&& scalar_op) {
   using srcx4_t = vec4_t<scalar_t>;
-  using distx4_t = cachex4_t<cache_t>;
+  using distx4_t = dtypex4_t<dtype_t>;
 
   int64_t const num_vec_elems = num_elems >> 2;
 
@@ -88,6 +116,7 @@ void scaled_convert_vec(const scalar_t* src, cache_t* dst, int num_elems,
     scalar_op(dst[i], src[i]);
   }
 }
+
 }  // namespace fp8
 
 }  // namespace vllm
diff --git a/csrc/xpu/torch_bindings.cpp b/csrc/xpu/torch_bindings.cpp
index 1c8fe57..5f76399 100644
--- a/csrc/xpu/torch_bindings.cpp
+++ b/csrc/xpu/torch_bindings.cpp
@@ -42,6 +42,27 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor!? key, int head_size,"
      " Tensor cos_sin_cache, bool is_neox) -> ()");
   ops.impl("rotary_embedding", torch::kXPU, &rotary_embedding);
+
+  // Compute FP8 quantized tensor for a given scaling factor.
+  ops.def(
+      "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
+      "()");
+  ops.impl("static_scaled_fp8_quant", torch::kXPU, &static_scaled_fp8_quant);
+
+  // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
+  ops.def(
+      "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
+      "-> "
+      "()");
+  ops.impl("dynamic_scaled_fp8_quant", torch::kXPU, &dynamic_scaled_fp8_quant);
+
+  // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
+  ops.def(
+      "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
+      "Tensor! scale, Tensor? scale_ub) -> "
+      "()");
+  ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kXPU,
+           &dynamic_per_token_scaled_fp8_quant);
 }
 
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
diff --git a/tests/ops/fp8_quant_op.py b/tests/ops/fp8_quant_op.py
new file mode 100644
index 0000000..54ec06a
--- /dev/null
+++ b/tests/ops/fp8_quant_op.py
@@ -0,0 +1,73 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional, Union
+
+import torch
+
+import tests.register_ops as ops
+
+# Add parent directory to Python path
+# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) #noqa: E501
+
+
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    num_token_padding: Optional[int] = None,
+    scale_ub: Optional[torch.Tensor] = None,
+    use_per_token_if_dynamic: bool = False,
+    output: Optional[torch.Tensor] = None,
+    fp8_dtype: torch.dtype = torch.float8_e5m2,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP8 and return quantized tensor and scale.
+
+    This function supports both static and dynamic quantization: if you
+    provide the scale, static scaling is used, and if you omit it, the
+    scale is determined dynamically. The function also allows optional
+    padding of the output tensors for downstream kernels that will
+    benefit from padding.
+
+    Args:
+        input: The input tensor to be quantized to FP8
+        scale: Optional scaling factor for the FP8 quantization
+        scale_ub: Optional upper bound for scaling factor in dynamic
+            per token case
+        num_token_padding: If specified, pad the first dimension
+            of the output to at least this value.
+        use_per_token_if_dynamic: Whether to do per_tensor or per_token
+            in the dynamic quantization case.
+        output: Optional pre-allocated output tensor; must already have
+            the FP8 dtype and cannot be combined with num_token_padding.
+        fp8_dtype: Target FP8 dtype (torch.float8_e5m2 or
+            torch.float8_e4m3fn).
+
+    Returns:
+        tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
+            scaling factor.
+ """ + # This code assumes batch_dim and num_tokens are flattened + assert input.ndim == 2 + shape: Union[tuple[int, int], torch.Size] = input.shape + out_dtype: torch.dtype = fp8_dtype + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + if output is None: + output = torch.empty(shape, device=input.device, dtype=out_dtype) + else: + assert num_token_padding is None, \ + "padding not supported if output passed in" + assert output.dtype == out_dtype + + if scale is None: + if use_per_token_if_dynamic: + scale = torch.empty((shape[0], 1), + device=input.device, + dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant(output, input.contiguous(), + scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + ops.dynamic_scaled_fp8_quant(output, input, scale) + else: + assert scale.numel() == 1, f"{scale.shape}" + ops.static_scaled_fp8_quant(output, input, scale) + + return output, scale diff --git a/tests/register_ops.py b/tests/register_ops.py index 47df662..01f5220 100644 --- a/tests/register_ops.py +++ b/tests/register_ops.py @@ -78,3 +78,23 @@ def reshape_and_cache_flash( k_scale, v_scale, ) + + +def static_scaled_fp8_quant(out: torch.Tensor, input: torch.Tensor, + scale: torch.Tensor) -> None: + torch.ops._C.static_scaled_fp8_quant(out, input, scale) + + +def dynamic_scaled_fp8_quant(out: torch.Tensor, input: torch.Tensor, + scale: torch.Tensor) -> None: + torch.ops._C.dynamic_scaled_fp8_quant(out, input, scale) + + +def dynamic_per_token_scaled_fp8_quant( + out: torch.Tensor, + input: torch.Tensor, + scales: torch.Tensor, + scale_ub: Optional[torch.Tensor] = None, +) -> None: + torch.ops._C.dynamic_per_token_scaled_fp8_quant(out, input, scales, + scale_ub) diff --git a/tests/test_fp8_quant.py b/tests/test_fp8_quant.py new file mode 100644 index 0000000..fac95ab --- /dev/null +++ b/tests/test_fp8_quant.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: Apache-2.0 + +import random +from typing import Optional, Union + +import numpy as np +import pytest +import torch + +from tests.ops.fp8_quant_op import scaled_fp8_quant + + +def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: + return torch.as_tensor(x, dtype=torch.float32, device="xpu") + + +def ref_dynamic_per_tensor_fp8_quant(x, fp8_dtype=torch.float8_e5m2): + + fp8_traits = torch.finfo(fp8_dtype) + fp8_traits_max = fp8_traits.max + fp8_traits_min = fp8_traits.min + fp8_max = as_float32_tensor(fp8_traits_max) + one = as_float32_tensor(1.0) + + # For fp8, in order to match the xpu kernel output, we have to do exactly + # the same operations as in the corresponding fp8 kernel to prevent + # rounding errors. 
+ + x_max = as_float32_tensor(x.abs().max()) + ref_scale = x_max / fp8_max + ref_iscale = one / ref_scale + ref_out = ((as_float32_tensor(x) * ref_iscale).clamp( + fp8_traits_min, fp8_traits_max).to(fp8_dtype)) + return ref_out, ref_scale.view((1, )) + + +def ref_dynamic_per_token_quant( + x: torch.tensor, + quant_dtype: torch.dtype, + scale_ub: Optional[torch.tensor] = None +) -> tuple[torch.tensor, torch.tensor]: + + assert quant_dtype in [torch.float8_e5m2, torch.float8_e4m3fn] + # if scale_ub is not None: + # assert quant_dtype == FP8_DTYPE + + qtype_traits = torch.finfo(quant_dtype) + qtype_traits_max = qtype_traits.max + qtype_traits_min = qtype_traits.min + qtype_max = as_float32_tensor(qtype_traits_max) + s_1 = as_float32_tensor(1.0) + s_512 = as_float32_tensor(512.0) + + # For fp8, in order to match the cuda kernel output, we have to do exactly + # the same operations as in the corresponding fp8 kernel to prevent + # rounding errors. + + # Compute scales + x_token_max, _ = x.abs().max(dim=-1) + x_token_max = as_float32_tensor(x_token_max) + if scale_ub is not None: + x_token_max = x_token_max.clamp(max=scale_ub) + scales = (x_token_max / qtype_max)[:, None] + + # Quant + min_scaling_factor = s_1 / (qtype_max * s_512) + scales = scales.clamp(min=min_scaling_factor) + torch_out = as_float32_tensor(x) / scales + torch_out = torch_out.clamp(qtype_traits_min, + qtype_traits_max).to(quant_dtype) + + return torch_out, scales + + +def assert_close_percentage(a: torch.Tensor, + b: torch.Tensor, + mismatch_threshold: float = 0.01): + """ + Assert that two tensors are close within a mismatch percentage. + + Args: + a (torch.Tensor): First tensor. + b (torch.Tensor): Second tensor. + mismatch_threshold (float): + Allowed mismatch ratio (0.01 = 1% mismatch allowed). + + Raises: + AssertionError: If mismatch percentage exceeds the threshold. 
+ """ + if a.shape != b.shape: + raise AssertionError(f"Shape mismatch: {a.shape} vs {b.shape}") + + mismatch_mask = a != b + mismatch_count = mismatch_mask.sum().item() + total_count = a.numel() + mismatch_ratio = mismatch_count / total_count + + if mismatch_ratio > mismatch_threshold: + raise AssertionError( + f"Tensors differ in {mismatch_ratio * 100:.2f}% of elements " + f"(allowed {mismatch_threshold * 100:.2f}%)") + + +def seed_everything(seed): + if seed is not None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +DTYPES = [torch.half, torch.bfloat16, torch.float] +HIDDEN_SIZES = [ + 1, + 2, + 3, + 4, + 16, + 67, + 768, + 2048, + 5120, + 5137, + 8192, + 8193, +] # Arbitrary values for testing +HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases +NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing +SCALE_UBS = [True, False] +SEEDS = [0] +FP8_DTYPES = [torch.float8_e5m2, torch.float8_e4m3fn] + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("fp8_dtype", FP8_DTYPES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_dynamic_per_tensor_fp8_quant( + num_tokens: int, + hidden_size: int, + fp8_dtype: torch.dtype, + dtype: torch.dtype, + seed: int, +) -> None: + seed_everything(seed) + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="xpu") + + ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x, fp8_dtype) + + ops_out, ops_scale = scaled_fp8_quant(x, fp8_dtype=fp8_dtype) + + torch.testing.assert_close(ref_scale, ops_scale) + torch.testing.assert_close(ref_out.to(dtype=torch.float32), + ops_out.to(dtype=torch.float32)) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("scale_ub", SCALE_UBS) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("fp8_dtype", FP8_DTYPES) +@torch.inference_mode() +def test_dynamic_per_token_fp8_quant( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + scale_ub: bool, + seed: int, + fp8_dtype: torch.dtype, +) -> None: + seed_everything(seed) + + x = (torch.rand(num_tokens, hidden_size, dtype=dtype, device="xpu") + 1e-6 + ) # avoid nans + + scale_ub = torch.mean(x).to(dtype=torch.float32, + device="xpu") if scale_ub else None + ref_out, ref_scales = ref_dynamic_per_token_quant(x, fp8_dtype, scale_ub) + + ops_out, ops_scales = scaled_fp8_quant(x, + scale_ub=scale_ub, + use_per_token_if_dynamic=True, + fp8_dtype=fp8_dtype) + + torch.testing.assert_close(ref_scales, ops_scales) + assert_close_percentage( + ref_out.to(dtype=torch.float32), + ops_out.to(dtype=torch.float32), + mismatch_threshold=0.005, + ) # 0.5% mismatch allowed + + +# Regression test for a case with large activations where an int32 index cannot +# represent the number of elements. 
+@torch.inference_mode() +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("fp8_dtype", FP8_DTYPES) +def test_fp8_quant_large(seed: int, fp8_dtype: torch.dtype) -> None: + seed_everything(seed) + + num_tokens = 1024000 # Mistral-Nemo's max_position_embeddings + hidden_size = 1152 # Smallest hidden_size to reproduce the error + dtype = torch.bfloat16 + + x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="xpu") + ref_out, scale = ref_dynamic_per_tensor_fp8_quant(x, fp8_dtype) + + ops_out, _ = scaled_fp8_quant(x, scale, fp8_dtype=fp8_dtype) + + # Minimize memory footprint in this test by freeing x and upconverting + # the outputs in place. (torch.allclose does not support fp8) + del x + ref_out = ref_out.to(dtype=dtype) + ops_out = ops_out.to(dtype=dtype) + + torch.testing.assert_close(ref_out, ops_out) + + +if __name__ == "__main__": + test_dynamic_per_tensor_fp8_quant(1024, 1024, torch.float8_e5m2, + torch.float16, 0)
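For reviewers, a minimal usage sketch of the new ops through the test wrapper. This is not part of the patch: it assumes the XPU extension above has been built and is loadable as in tests/register_ops.py, that an "xpu" device is available, and the round-trip check at the end is purely illustrative.

    import torch
    from tests.ops.fp8_quant_op import scaled_fp8_quant

    x = torch.rand(16, 4096, dtype=torch.bfloat16, device="xpu")

    # Dynamic per-tensor quantization: the kernel computes one scale for x.
    q, scale = scaled_fp8_quant(x, fp8_dtype=torch.float8_e4m3fn)

    # Dynamic per-token quantization: one scale per row, optionally capped
    # by scale_ub.
    q_tok, scales = scaled_fp8_quant(x,
                                     use_per_token_if_dynamic=True,
                                     fp8_dtype=torch.float8_e4m3fn)

    # Illustrative round-trip check: dequantize and compare to the input.
    err = (q_tok.to(torch.float32) * scales - x.to(torch.float32)).abs().max()
    print(f"max dequantization error: {err.item():.4f}")

The static path mirrors test_fp8_quant_large above: pass a precomputed single-element scale tensor as the second argument and the wrapper dispatches to static_scaled_fp8_quant instead of the dynamic kernels.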