|
| 1 | +#include <ATen/ATen.h> |
| 2 | +#include <ATen/DeviceGuard.h> |
| 3 | +#include <ATen/xpu/XPUContext.h> |
| 4 | + |
| 5 | +#include <sycl/sycl.hpp> |
| 6 | + |
| 7 | +#include "xpu/dispatch_utils.h" |
| 8 | + |
| 9 | +#include "fp8_quant.h" |
| 10 | +#include "utils.h" |
| 11 | + |
| 12 | + |
| 13 | +namespace vllm { |
| 14 | + |
| 15 | +template <typename scalar_t, typename fp8_type> |
| 16 | +class scaled_fp8_quant_kernel { |
| 17 | + private: |
| 18 | + fp8_type* out; |
| 19 | + const scalar_t* input; |
| 20 | + const float* scale; |
| 21 | + int64_t num_elems; |
| 22 | + |
| 23 | + public: |
| 24 | + scaled_fp8_quant_kernel( |
| 25 | + fp8_type* out_, |
| 26 | + const scalar_t* input_, |
| 27 | + const float* scale_, |
| 28 | + int64_t num_elems_) |
| 29 | + : out(out_), input(input_), scale(scale_), num_elems(num_elems_) {} |
| 30 | + void operator()(sycl::nd_item<1> item) const { |
| 31 | + int tid = item.get_global_linear_id(); |
| 32 | + |
| 33 | + // Invert the scale so that we can use multiplications to avoid expensive |
| 34 | + // division. |
| 35 | + const float inverted_scale = 1.0f / (*scale); |
| 36 | + scaled_fp8_conversion_vec<scalar_t, true>( |
| 37 | + out, |
| 38 | + input, |
| 39 | + inverted_scale, |
| 40 | + num_elems, |
| 41 | + tid, |
| 42 | + item.get_local_range(0) * item.get_group_range(0)); |
| 43 | + } |
| 44 | +}; |
| 45 | + |
| 46 | +template <typename scalar_t, typename fp8_type> |
| 47 | +class dynamic_per_token_scaled_fp8_quant_kernel { |
| 48 | + private: |
| 49 | + fp8_type* out; |
| 50 | + float* scale; |
| 51 | + scalar_t const* input; |
| 52 | + float const* scale_ub; |
| 53 | + const int hidden_size; |
| 54 | + |
| 55 | + public: |
| 56 | + dynamic_per_token_scaled_fp8_quant_kernel( |
| 57 | + fp8_type* out_, |
| 58 | + float* scale_, |
| 59 | + scalar_t const* input_, |
| 60 | + float const* scale_ub_, |
| 61 | + const int hidden_size_) |
| 62 | + : out(out_), |
| 63 | + scale(scale_), |
| 64 | + input(input_), |
| 65 | + scale_ub(scale_ub_), |
| 66 | + hidden_size(hidden_size_) {} |
| 67 | + |
| 68 | + void operator()(sycl::nd_item<1> item) const { |
| 69 | + int const tid = item.get_local_id(0); |
| 70 | + int const token_idx = item.get_group(0); |
| 71 | + |
| 72 | + // sycl::ext::oneapi::experimental::printf( |
| 73 | + // "token_idx: %d, tid: %d, hidden_size: %d, group_range: %d\n", |
| 74 | + // token_idx, |
| 75 | + // tid, |
| 76 | + // hidden_size, |
| 77 | + // item.get_local_range(0)); |
| 78 | + |
| 79 | + // Use int64 to avoid overflowing an int32 when calculating this offset |
| 80 | + int64_t offset = static_cast<int64_t>(token_idx) * hidden_size; |
| 81 | + scalar_t const* token_input = &input[offset]; |
| 82 | + fp8_type* token_output = &out[offset]; |
| 83 | + |
| 84 | + // For vectorization, token_input and token_output pointers need to be |
| 85 | + // aligned at 8-byte and 4-byte addresses respectively. |
| 86 | + bool const can_vectorize = hidden_size % 4 == 0; |
| 87 | + |
| 88 | + float absmax_val = 0.0f; |
| 89 | + if (can_vectorize) { |
| 90 | + absmax_val = thread_max_vec( |
| 91 | + token_input, hidden_size, tid, item.get_local_range(0)); |
| 92 | + } else { |
| 93 | + for (int i = tid; i < hidden_size; i += item.get_local_range(0)) { |
| 94 | + float const x = static_cast<float>(token_input[i]); |
| 95 | + absmax_val = sycl::max(absmax_val, sycl::fabs(x)); |
| 96 | + } |
| 97 | + } |
| 98 | + |
| 99 | + float const block_absmax_val_maybe = sycl::reduce_over_group( |
| 100 | + item.get_group(), absmax_val, sycl::maximum<float>()); |
| 101 | + // __shared__ float token_scale; |
| 102 | + auto& token_scale = |
| 103 | + *sycl::ext::oneapi::group_local_memory_for_overwrite<float[1]>( |
| 104 | + item.get_group()); |
| 105 | + if (tid == 0) { |
| 106 | + if (scale_ub) { |
| 107 | + token_scale[0] = sycl::min(block_absmax_val_maybe, *scale_ub); |
| 108 | + } else { |
| 109 | + token_scale[0] = block_absmax_val_maybe; |
| 110 | + } |
| 111 | + // token scale computation |
| 112 | + token_scale[0] = sycl::max( |
| 113 | + token_scale[0] / quant_type_max_v<fp8_type>, |
| 114 | + min_scaling_factor<fp8_type>::val()); |
| 115 | + scale[token_idx] = token_scale[0]; |
| 116 | + } |
| 117 | + group_barrier(item.get_group()); |
| 118 | + |
| 119 | + // Note that we don't use inverted scales so we can match FBGemm impl. |
| 120 | + const float inverted_scale = 1.0f / (token_scale[0]); |
| 121 | + if (can_vectorize) { |
| 122 | + scaled_fp8_conversion_vec<scalar_t, true>( |
| 123 | + token_output, |
| 124 | + token_input, |
| 125 | + inverted_scale, |
| 126 | + hidden_size, |
| 127 | + tid, |
| 128 | + item.get_local_range(0)); |
| 129 | + } else { |
| 130 | + for (int i = tid; i < hidden_size; i += item.get_local_range(0)) { |
| 131 | + token_output[i] = scaled_fp8_conversion<true, fp8_type>( |
| 132 | + static_cast<float>(token_input[i]), inverted_scale, tid, token_idx); |
| 133 | + } |
| 134 | + } |
| 135 | + } |
| 136 | +}; |
| 137 | + |
| 138 | +} // namespace vllm |
| 139 | + |
| 140 | +void static_scaled_fp8_quant( |
| 141 | + torch::Tensor& out, // [..., d] |
| 142 | + torch::Tensor const& input, // [..., d] |
| 143 | + torch::Tensor const& scale) // [1] |
| 144 | +{ |
| 145 | + int64_t num_tokens = input.numel() / input.size(-1); |
| 146 | + int64_t num_elems = input.numel(); |
| 147 | + sycl::range<1> grid(num_tokens); |
| 148 | + sycl::range<1> block(1024); |
| 149 | + at::Device curDevice = at::Device(at::kXPU, at::xpu::current_device()); |
| 150 | + at::DeviceGuard device_guard(curDevice); |
| 151 | + |
| 152 | + auto stream = at::xpu::getCurrentXPUStream().queue(); |
| 153 | + // TODO: change name? |
| 154 | + VLLM_DISPATCH_FLOATING_TYPES( |
| 155 | + input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] { |
| 156 | + VLLM_DISPATCH_FP8_TYPES( |
| 157 | + out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] { |
| 158 | + // Launch the kernel |
| 159 | + stream.submit([&](sycl::handler& cgh) { |
| 160 | + auto kernel = vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>( |
| 161 | + out.data_ptr<fp8_t>(), |
| 162 | + input.data_ptr<scalar_t>(), |
| 163 | + scale.data_ptr<float>(), |
| 164 | + num_elems); |
| 165 | + cgh.parallel_for( |
| 166 | + sycl::nd_range<1>(grid * block, block), kernel); |
| 167 | + }); |
| 168 | + }); |
| 169 | + }); |
| 170 | +} |
| 171 | + |
| 172 | +void dynamic_scaled_fp8_quant( |
| 173 | + torch::Tensor& out, // [..., d] |
| 174 | + torch::Tensor const& input, // [..., d] |
| 175 | + torch::Tensor& scale) // [1] |
| 176 | +{ |
| 177 | + int64_t num_tokens = input.numel() / input.size(-1); |
| 178 | + int64_t num_elems = input.numel(); |
| 179 | + sycl::range<1> grid(num_tokens); |
| 180 | + sycl::range<1> block(1024); |
| 181 | + at::Device curDevice = at::Device(at::kXPU, at::xpu::current_device()); |
| 182 | + at::DeviceGuard device_guard(curDevice); |
| 183 | + |
| 184 | + auto stream = at::xpu::getCurrentXPUStream().queue(); |
| 185 | + VLLM_DISPATCH_FLOATING_TYPES( |
| 186 | + input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] { |
| 187 | + VLLM_DISPATCH_FP8_TYPES( |
| 188 | + out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] { |
| 189 | + // Launch the kernel |
| 190 | + stream.submit([&](sycl::handler& cgh) { |
| 191 | + auto max_reduce_kernel = |
| 192 | + vllm::segmented_max_reduction<scalar_t, fp8_t>( |
| 193 | + scale.data_ptr<float>(), |
| 194 | + input.data_ptr<scalar_t>(), |
| 195 | + num_elems); |
| 196 | + cgh.parallel_for( |
| 197 | + sycl::nd_range<1>(grid * block, block), max_reduce_kernel); |
| 198 | + }); |
| 199 | + stream.submit([&](sycl::handler& cgh) { |
| 200 | + auto kernel = vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>( |
| 201 | + out.data_ptr<fp8_t>(), |
| 202 | + input.data_ptr<scalar_t>(), |
| 203 | + scale.data_ptr<float>(), |
| 204 | + num_elems); |
| 205 | + cgh.parallel_for( |
| 206 | + sycl::nd_range<1>(grid * block, block), kernel); |
| 207 | + }); |
| 208 | + }); |
| 209 | + }); |
| 210 | +} |
| 211 | + |
| 212 | +void dynamic_per_token_scaled_fp8_quant( |
| 213 | + torch::Tensor& out, // [..., d] |
| 214 | + torch::Tensor const& input, // [..., d] |
| 215 | + torch::Tensor& scales, |
| 216 | + std::optional<at::Tensor> const& scale_ub) { |
| 217 | + TORCH_CHECK(input.is_contiguous()); |
| 218 | + TORCH_CHECK(out.is_contiguous()); |
| 219 | + |
| 220 | + int const hidden_size = input.size(-1); |
| 221 | + int const num_tokens = input.numel() / hidden_size; |
| 222 | + sycl::range<1> grid(num_tokens); |
| 223 | + sycl::range<1> block(std::min(hidden_size, 1024)); |
| 224 | + |
| 225 | + at::Device curDevice = at::Device(at::kXPU, at::xpu::current_device()); |
| 226 | + at::DeviceGuard device_guard(curDevice); |
| 227 | + |
| 228 | + auto stream = at::xpu::getCurrentXPUStream().queue(); |
| 229 | + VLLM_DISPATCH_FLOATING_TYPES( |
| 230 | + input.scalar_type(), |
| 231 | + "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", |
| 232 | + [&] { |
| 233 | + VLLM_DISPATCH_FP8_TYPES( |
| 234 | + out.scalar_type(), |
| 235 | + "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", |
| 236 | + [&] { |
| 237 | + // Launch the kernel |
| 238 | + stream |
| 239 | + .submit([&](sycl::handler& cgh) { |
| 240 | + auto kernel = vllm::dynamic_per_token_scaled_fp8_quant_kernel< |
| 241 | + scalar_t, |
| 242 | + fp8_t>( |
| 243 | + out.data_ptr<fp8_t>(), |
| 244 | + scales.data_ptr<float>(), |
| 245 | + input.data_ptr<scalar_t>(), |
| 246 | + scale_ub.has_value() ? scale_ub->data_ptr<float>() |
| 247 | + : nullptr, |
| 248 | + hidden_size); |
| 249 | + cgh.parallel_for( |
| 250 | + sycl::nd_range<1>(grid * block, block), kernel); |
| 251 | + }) |
| 252 | + .wait(); |
| 253 | + }); |
| 254 | + }); |
| 255 | +} |
0 commit comments