Commit 5bbb0cc

merge redundant files
Signed-off-by: Zhu, Zufang <[email protected]>
1 parent 56ebb81 commit 5bbb0cc

10 files changed: 230 additions & 310 deletions
csrc/xpu/cache.cpp

Lines changed: 22 additions & 3 deletions
@@ -5,7 +5,7 @@
 #include <string>

 #include "dispatch_utils.h"
-#include "quantization/fp8/quant_utils.hpp"
+#include "quantization/fp8/quant_utils.h"
 #include "utils.h"

 namespace vllm {
@@ -90,6 +90,25 @@ void call_reshape_and_cache(
   });
 }

+// Used by vectorization_utils to copy/convert one element
+template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
+struct CopyWithScaleOp {
+  float scale;
+
+  inline void operator()(OutT& dst, const InT src) const {
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      dst = static_cast<OutT>(src);
+    } else {
+      float x = (float)src / scale;
+      if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
+        dst = static_cast<at::Float8_e4m3fn>(x);
+      } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
+        dst = static_cast<at::Float8_e5m2>(x);
+      }
+    }
+  }
+};
+
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 void reshape_and_cache_flash_kernel(
     const scalar_t* __restrict__ key,  // [num_tokens, num_heads, head_size]
@@ -127,8 +146,8 @@ void reshape_and_cache_flash_kernel(
   float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
   float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;

-  fp8::CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
-  fp8::CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
   fp8::scaled_convert_vec(key_src, key_dst, n, local_idx, local_range, k_op);
   fp8::scaled_convert_vec(value_src, value_dst, n, local_idx, local_range,
                           v_op);
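For reference, a minimal sketch of the functor-plus-strided-loop pattern that reshape_and_cache_flash_kernel now relies on: a per-element CopyWithScaleOp is handed to a vectorized copy helper that walks the array one work item at a time. The simplified functor and the scaled_convert_vec_sketch helper below are illustrative stand-ins, not vLLM's fp8::scaled_convert_vec or the at::Float8 cast paths; only the argument order (src, dst, n, tid, stride, op) mirrors the call sites above.

// Minimal, self-contained sketch of the functor + strided loop pattern.
// scaled_convert_vec_sketch is NOT the real vLLM helper; it is an assumption
// that mirrors the (src, dst, n, tid, stride, op) call shape seen in the diff.
#include <cstdint>
#include <cstdio>
#include <vector>

template <typename OutT, typename InT>
struct CopyWithScaleOpSketch {
  float scale;  // divisor applied before the narrowing cast
  inline void operator()(OutT& dst, const InT src) const {
    dst = static_cast<OutT>(static_cast<float>(src) / scale);
  }
};

// Each "work item" (tid) walks the array with a fixed stride and applies op
// to one element at a time, the same division of work the SYCL kernel uses.
template <typename InT, typename OutT, typename Op>
void scaled_convert_vec_sketch(const InT* src, OutT* dst, int64_t n,
                               int64_t tid, int64_t stride, Op op) {
  for (int64_t i = tid; i < n; i += stride) {
    op(dst[i], src[i]);
  }
}

int main() {
  std::vector<float> key = {1.0f, 2.0f, 4.0f, 8.0f};
  std::vector<float> cache(key.size(), 0.0f);
  CopyWithScaleOpSketch<float, float> k_op{2.0f};  // stand-in for k_scale_val
  // Single "work item" covering the whole range (tid = 0, stride = 1).
  scaled_convert_vec_sketch(key.data(), cache.data(),
                            static_cast<int64_t>(key.size()), 0, 1, k_op);
  for (float v : cache) std::printf("%.1f ", v);  // prints: 0.5 1.0 2.0 4.0
  std::printf("\n");
  return 0;
}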

csrc/xpu/dispatch_utils.h

Lines changed: 2 additions & 2 deletions
@@ -21,13 +21,13 @@
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

-#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
+#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
   AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
   AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)

 #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
   AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
-  AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
   AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)

 // When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
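As a side note on the dispatch comment above, here is a hand-rolled sketch (standard C++ only) of what an fp8 dispatch switch conceptually expands to: each case binds one concrete type and invokes the lambda with it, which is why the kernel body is written against a single alias such as fp8_t rather than scalar_t. The DType enum, the tag structs, and dispatch_fp8 are illustrative stand-ins, not the ATen AT_DISPATCH_* macros or vLLM's real types.

// Conceptual sketch of type dispatch; all names here are hypothetical.
#include <cstdio>
#include <stdexcept>

enum class DType { Float8_e4m3fn, Float8_e5m2 };

struct fp8_e4m3_tag { static constexpr const char* name = "e4m3"; };
struct fp8_e5m2_tag { static constexpr const char* name = "e5m2"; };

// Each case binds a concrete type and invokes the functor with it.
template <typename F>
void dispatch_fp8(DType t, F&& f) {
  switch (t) {
    case DType::Float8_e4m3fn:
      f(fp8_e4m3_tag{});
      break;
    case DType::Float8_e5m2:
      f(fp8_e5m2_tag{});
      break;
    default:
      throw std::runtime_error("unsupported fp8 dtype");
  }
}

int main() {
  dispatch_fp8(DType::Float8_e5m2, [](auto tag) {
    using fp8_t = decltype(tag);  // the alias the kernel body sees
    std::printf("dispatched to %s\n", fp8_t::name);
  });
  return 0;
}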

csrc/xpu/quantization/fp8/fp8_quant.cpp

Lines changed: 54 additions & 85 deletions
@@ -8,8 +8,7 @@
 #include "xpu/ops.h"

 #include "fp8_quant.h"
-#include "utils.h"
-
+#include "quant_utils.h"

 namespace vllm {

@@ -22,25 +21,19 @@ class scaled_fp8_quant_kernel {
   int64_t num_elems;

  public:
-  scaled_fp8_quant_kernel(
-      fp8_type* out_,
-      const scalar_t* input_,
-      const float* scale_,
-      int64_t num_elems_)
+  scaled_fp8_quant_kernel(fp8_type* out_, const scalar_t* input_,
+                          const float* scale_, int64_t num_elems_)
       : out(out_), input(input_), scale(scale_), num_elems(num_elems_) {}
   void operator()(sycl::nd_item<1> item) const {
     int tid = item.get_global_linear_id();

     // Invert the scale so that we can use multiplications to avoid expensive
     // division.
     const float inverted_scale = 1.0f / (*scale);
-    scaled_fp8_conversion_vec<scalar_t, true>(
-        out,
-        input,
-        inverted_scale,
-        num_elems,
-        tid,
-        item.get_local_range(0) * item.get_group_range(0));
+    fp8::ConvertWithScaleOp<true, fp8_type> op{inverted_scale};
+    fp8::scaled_convert_vec(input, out, num_elems, tid,
+                            item.get_local_range(0) * item.get_group_range(0),
+                            op);
   }
 };

@@ -54,12 +47,10 @@ class dynamic_per_token_scaled_fp8_quant_kernel {
   const int hidden_size;

  public:
-  dynamic_per_token_scaled_fp8_quant_kernel(
-      fp8_type* out_,
-      float* scale_,
-      scalar_t const* input_,
-      float const* scale_ub_,
-      const int hidden_size_)
+  dynamic_per_token_scaled_fp8_quant_kernel(fp8_type* out_, float* scale_,
+                                            scalar_t const* input_,
+                                            float const* scale_ub_,
+                                            const int hidden_size_)
       : out(out_),
         scale(scale_),
         input(input_),
@@ -70,13 +61,6 @@ class dynamic_per_token_scaled_fp8_quant_kernel {
     int const tid = item.get_local_id(0);
     int const token_idx = item.get_group(0);

-    // sycl::ext::oneapi::experimental::printf(
-    //     "token_idx: %d, tid: %d, hidden_size: %d, group_range: %d\n",
-    //     token_idx,
-    //     tid,
-    //     hidden_size,
-    //     item.get_local_range(0));
-
     // Use int64 to avoid overflowing an int32 when calculating this offset
     int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
     scalar_t const* token_input = &input[offset];
@@ -88,8 +72,8 @@ class dynamic_per_token_scaled_fp8_quant_kernel {

     float absmax_val = 0.0f;
     if (can_vectorize) {
-      absmax_val = thread_max_vec(
-          token_input, hidden_size, tid, item.get_local_range(0));
+      absmax_val = thread_max_vec(token_input, hidden_size, tid,
+                                  item.get_local_range(0));
     } else {
       for (int i = tid; i < hidden_size; i += item.get_local_range(0)) {
         float const x = static_cast<float>(token_input[i]);
@@ -110,38 +94,33 @@ class dynamic_per_token_scaled_fp8_quant_kernel {
         token_scale[0] = block_absmax_val_maybe;
       }
       // token scale computation
-      token_scale[0] = sycl::max(
-          token_scale[0] / quant_type_max_v<fp8_type>,
-          min_scaling_factor<fp8_type>::val());
+      token_scale[0] =
+          sycl::max(token_scale[0] / fp8::quant_type_max_v<fp8_type>,
+                    fp8::min_scaling_factor<fp8_type>::val());
       scale[token_idx] = token_scale[0];
     }
     group_barrier(item.get_group());

     // Note that we don't use inverted scales so we can match FBGemm impl.
     const float inverted_scale = 1.0f / (token_scale[0]);
     if (can_vectorize) {
-      scaled_fp8_conversion_vec<scalar_t, true>(
-          token_output,
-          token_input,
-          inverted_scale,
-          hidden_size,
-          tid,
-          item.get_local_range(0));
+      fp8::ConvertWithScaleOp<true, fp8_type> op{inverted_scale};
+      fp8::scaled_convert_vec(token_input, token_output, hidden_size, tid,
+                              item.get_local_range(0), op);
     } else {
       for (int i = tid; i < hidden_size; i += item.get_local_range(0)) {
-        token_output[i] = scaled_fp8_conversion<true, fp8_type>(
-            static_cast<float>(token_input[i]), inverted_scale, tid, token_idx);
+        fp8::ConvertWithScaleOp<true, fp8_type> op{inverted_scale};
+        op(token_output[i], token_input[i]);
       }
     }
   }
 };

-} // namespace vllm
+}  // namespace vllm

-void static_scaled_fp8_quant(
-    torch::Tensor& out, // [..., d]
-    torch::Tensor const& input, // [..., d]
-    torch::Tensor const& scale) // [1]
+void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
+                             torch::Tensor const& input,  // [..., d]
+                             torch::Tensor const& scale)  // [1]
 {
   int64_t num_tokens = input.numel() / input.size(-1);
   int64_t num_elems = input.numel();
@@ -158,21 +137,18 @@ void static_scaled_fp8_quant(
       // Launch the kernel
       stream.submit([&](sycl::handler& cgh) {
         auto kernel = vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>(
-            out.data_ptr<fp8_t>(),
-            input.data_ptr<scalar_t>(),
-            scale.data_ptr<float>(),
-            num_elems);
-        cgh.parallel_for(
-            sycl::nd_range<1>(grid * block, block), kernel);
+            out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+            scale.data_ptr<float>(), num_elems);
+        cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
+                         kernel);
       });
     });
   });
 }

-void dynamic_scaled_fp8_quant(
-    torch::Tensor& out, // [..., d]
-    torch::Tensor const& input, // [..., d]
-    torch::Tensor& scale) // [1]
+void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
+                              torch::Tensor const& input,  // [..., d]
+                              torch::Tensor& scale)        // [1]
 {
   int64_t num_tokens = input.numel() / input.size(-1);
   int64_t num_elems = input.numel();
@@ -190,30 +166,26 @@ void dynamic_scaled_fp8_quant(
       stream.submit([&](sycl::handler& cgh) {
         auto max_reduce_kernel =
             vllm::segmented_max_reduction<scalar_t, fp8_t>(
-                scale.data_ptr<float>(),
-                input.data_ptr<scalar_t>(),
+                scale.data_ptr<float>(), input.data_ptr<scalar_t>(),
                 num_elems);
-        cgh.parallel_for(
-            sycl::nd_range<1>(grid * block, block), max_reduce_kernel);
+        cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
+                         max_reduce_kernel);
      });
      stream.submit([&](sycl::handler& cgh) {
        auto kernel = vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>(
-            out.data_ptr<fp8_t>(),
-            input.data_ptr<scalar_t>(),
-            scale.data_ptr<float>(),
-            num_elems);
-        cgh.parallel_for(
-            sycl::nd_range<1>(grid * block, block), kernel);
+            out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+            scale.data_ptr<float>(), num_elems);
+        cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
+                         kernel);
      });
    });
  });
 }

 void dynamic_per_token_scaled_fp8_quant(
-    torch::Tensor& out, // [..., d]
-    torch::Tensor const& input, // [..., d]
-    torch::Tensor& scales,
-    std::optional<at::Tensor> const& scale_ub) {
+    torch::Tensor& out,          // [..., d]
+    torch::Tensor const& input,  // [..., d]
+    torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());

@@ -228,26 +200,23 @@ void dynamic_per_token_scaled_fp8_quant(
   auto stream = at::xpu::getCurrentXPUStream().queue();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(),
-      "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type",
-      [&] {
+      "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(),
-            "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type",
-            [&] {
+            "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
              // Launch the kernel
              stream
                  .submit([&](sycl::handler& cgh) {
-                  auto kernel = vllm::dynamic_per_token_scaled_fp8_quant_kernel<
-                      scalar_t,
-                      fp8_t>(
-                      out.data_ptr<fp8_t>(),
-                      scales.data_ptr<float>(),
-                      input.data_ptr<scalar_t>(),
-                      scale_ub.has_value() ? scale_ub->data_ptr<float>()
-                                           : nullptr,
-                      hidden_size);
-                  cgh.parallel_for(
-                      sycl::nd_range<1>(grid * block, block), kernel);
+                  auto kernel =
+                      vllm::dynamic_per_token_scaled_fp8_quant_kernel<
+                          scalar_t, fp8_t>(
+                          out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
+                          input.data_ptr<scalar_t>(),
+                          scale_ub.has_value() ? scale_ub->data_ptr<float>()
+                                               : nullptr,
+                          hidden_size);
+                  cgh.parallel_for(sycl::nd_range<1>(grid * block, block),
+                                   kernel);
                  })
                  .wait();
            });
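A minimal sketch of the per-token dynamic-scaling arithmetic that dynamic_per_token_scaled_fp8_quant_kernel performs: an absmax reduction over the token's hidden dimension, scale = max(absmax / fp8_max, floor), then multiplication by the inverted scale before the narrowing cast. kQuantTypeMax below uses the FP8 E4M3FN maximum of 448; kMinScalingFactor is an arbitrary illustrative floor, not vLLM's fp8::min_scaling_factor value, and the cast to an fp8 type is omitted.

// Host-side sketch of the scaling math only; names are illustrative.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const float kQuantTypeMax = 448.0f;     // largest finite fp8 e4m3fn value
  const float kMinScalingFactor = 1e-6f;  // illustrative lower bound only

  std::vector<float> token_input = {0.1f, -3.5f, 2.25f, 0.0f};

  // 1) absmax reduction over the token's hidden dimension.
  float absmax_val = 0.0f;
  for (float x : token_input) absmax_val = std::max(absmax_val, std::fabs(x));

  // 2) token scale = max(absmax / fp8_max, floor), as in the kernel.
  float token_scale = std::max(absmax_val / kQuantTypeMax, kMinScalingFactor);

  // 3) quantize with the inverted scale (multiply instead of divide).
  const float inverted_scale = 1.0f / token_scale;
  for (float x : token_input) {
    float q = x * inverted_scale;  // the value that would be cast to fp8_t
    std::printf("%.6f -> %.3f\n", x, q);
  }
  std::printf("token_scale = %.6f\n", token_scale);
  return 0;
}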
