Skip to content

Commit eace8bf

Browse files
authored
[Kernel] FP8 support for MoE kernel / Mixtral (#4244)
This PR is the first step towards fixing #3208 It implements dynamic per-tensor scaling (see #4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this: ```python from vllm import LLM, SamplingParams prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8") outputs = llm.generate(prompts, sampling_params) # Print the outputs. for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` **Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in #3954). With this PR, the results are as follows: <img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03"> **Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows: ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7018|± |0.0036| | - humanities |N/A |none | 5|acc |0.6472|± |0.0065| | - other |N/A |none | 5|acc |0.7673|± |0.0072| | - social_sciences|N/A |none | 5|acc |0.8099|± |0.0070| | - stem |N/A |none | 5|acc |0.6131|± |0.0083| ``` this compares favorably with the fp16 results which are ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7020|± |0.1313| | - humanities |N/A |none | 5|acc |0.6425|± |0.1349| | - other |N/A |none | 5|acc |0.7744|± |0.1038| | - social_sciences|N/A |none | 5|acc |0.8131|± |0.0695| | - stem |N/A |none | 5|acc |0.6108|± |0.1383| ``` Happy hacking!
1 parent 1e8f425 commit eace8bf

File tree

10 files changed

+385
-21
lines changed

10 files changed

+385
-21
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ set(VLLM_EXT_SRC
167167
"csrc/layernorm_kernels.cu"
168168
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
169169
"csrc/quantization/gptq/q_gemm.cu"
170+
"csrc/quantization/fp8/fp8_cuda_kernels.cu"
170171
"csrc/cuda_utils_kernels.cu"
171172
"csrc/moe_align_block_size_kernels.cu"
172173
"csrc/pybind.cpp")

csrc/ops.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ void gptq_shuffle(
146146
torch::Tensor q_perm,
147147
int bit);
148148

149+
void scaled_fp8_quant(
150+
torch::Tensor& out,
151+
torch::Tensor& input,
152+
torch::Tensor& scale);
153+
149154
void moe_align_block_size(
150155
torch::Tensor topk_ids,
151156
int num_experts,

csrc/pybind.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
7373
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
7474
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
7575
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
76+
ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
7677
ops.def(
7778
"moe_align_block_size",
7879
&moe_align_block_size,
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
#include <ATen/cuda/CUDAContext.h>
2+
#include <torch/extension.h>
3+
#include <c10/cuda/CUDAGuard.h>
4+
5+
#include <cmath>
6+
7+
#include "cuda_compat.h"
8+
#include "dispatch_utils.h"
9+
10+
namespace vllm {
11+
12+
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
13+
float old;
14+
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
15+
__uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
16+
17+
return old;
18+
}
19+
20+
// Compute the absolute maximum m of the input tensor and store
21+
// m / float8_e4m3::max() in *scale. Each thread block performs a
22+
// reduction tree and the memory in scale is atomically updated.
23+
// So to get the right answer, *scale needs to be initialized to
24+
// a value <= 0.0 and we need to wait for all thread blocks to
25+
// finish before consuming *scale.
26+
template<typename scalar_t>
27+
__global__ void segmented_max_reduction(
28+
float* __restrict__ scale,
29+
const scalar_t* __restrict__ input,
30+
int64_t num_elems) {
31+
__shared__ float cache[1024];
32+
int i = blockDim.x * blockIdx.x + threadIdx.x;
33+
34+
// First store maximum for all values processes by
35+
// the current thread in cache[threadIdx.x]
36+
scalar_t tmp = 0.0;
37+
while (i < num_elems) {
38+
float x = static_cast<float>(input[i]);
39+
tmp = max(tmp, fabs(x));
40+
i += blockDim.x * gridDim.x;
41+
}
42+
cache[threadIdx.x] = tmp;
43+
44+
__syncthreads();
45+
46+
// Now perform parallel reduction within the thread block
47+
int ib = blockDim.x / 2;
48+
while (ib != 0) {
49+
if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
50+
cache[threadIdx.x] = cache[threadIdx.x + ib];
51+
}
52+
__syncthreads();
53+
ib /= 2;
54+
}
55+
// Finally, since cache[0] contains the maximum for this thread block,
56+
// atomically write the max to the target location
57+
if (threadIdx.x == 0) {
58+
atomicMaxFloat(scale, cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
59+
}
60+
}
61+
62+
template<typename scalar_t>
63+
__global__ void scaled_fp8_quant_kernel(
64+
c10::Float8_e4m3fn* __restrict__ out,
65+
const scalar_t* __restrict__ input,
66+
const float* __restrict__ scale,
67+
int64_t num_elems) {
68+
int i = blockDim.x * blockIdx.x + threadIdx.x;
69+
while (i < num_elems) {
70+
out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / *scale);
71+
i += blockDim.x * gridDim.x;
72+
}
73+
}
74+
75+
} // namespace vllm
76+
77+
void scaled_fp8_quant(
78+
torch::Tensor& out, // [..., d]
79+
torch::Tensor& input, // [..., d]
80+
torch::Tensor& scale) // [1]
81+
{
82+
int64_t num_tokens = input.numel() / input.size(-1);
83+
int64_t num_elems = input.numel();
84+
dim3 grid(num_tokens);
85+
dim3 block(1024);
86+
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
87+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
88+
VLLM_DISPATCH_FLOATING_TYPES(
89+
input.scalar_type(),
90+
"scaled_fp8_quant_kernel",
91+
[&] {
92+
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
93+
scale.data_ptr<float>(),
94+
input.data_ptr<scalar_t>(),
95+
num_elems);
96+
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
97+
out.data_ptr<c10::Float8_e4m3fn>(),
98+
input.data_ptr<scalar_t>(),
99+
scale.data_ptr<float>(),
100+
num_elems);
101+
});
102+
}
103+

vllm/_custom_ops.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional
1+
from typing import Dict, Optional, Tuple
22

33
import torch
44

@@ -153,6 +153,14 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
153153
size_n, size_k)
154154

155155

156+
# fp8
157+
def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
158+
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
159+
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
160+
vllm_ops.scaled_fp8_quant(output, input, scale)
161+
return output, scale
162+
163+
156164
# moe
157165
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
158166
block_size: int, sorted_token_ids: torch.Tensor,
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
{
2+
"1": {
3+
"BLOCK_SIZE_M": 16,
4+
"BLOCK_SIZE_N": 32,
5+
"BLOCK_SIZE_K": 128,
6+
"GROUP_SIZE_M": 1,
7+
"num_warps": 4,
8+
"num_stages": 4
9+
},
10+
"2": {
11+
"BLOCK_SIZE_M": 128,
12+
"BLOCK_SIZE_N": 64,
13+
"BLOCK_SIZE_K": 128,
14+
"GROUP_SIZE_M": 1,
15+
"num_warps": 4,
16+
"num_stages": 4
17+
},
18+
"4": {
19+
"BLOCK_SIZE_M": 64,
20+
"BLOCK_SIZE_N": 64,
21+
"BLOCK_SIZE_K": 64,
22+
"GROUP_SIZE_M": 64,
23+
"num_warps": 4,
24+
"num_stages": 4
25+
},
26+
"8": {
27+
"BLOCK_SIZE_M": 64,
28+
"BLOCK_SIZE_N": 128,
29+
"BLOCK_SIZE_K": 256,
30+
"GROUP_SIZE_M": 64,
31+
"num_warps": 8,
32+
"num_stages": 4
33+
},
34+
"16": {
35+
"BLOCK_SIZE_M": 64,
36+
"BLOCK_SIZE_N": 256,
37+
"BLOCK_SIZE_K": 128,
38+
"GROUP_SIZE_M": 1,
39+
"num_warps": 8,
40+
"num_stages": 4
41+
},
42+
"24": {
43+
"BLOCK_SIZE_M": 64,
44+
"BLOCK_SIZE_N": 256,
45+
"BLOCK_SIZE_K": 128,
46+
"GROUP_SIZE_M": 1,
47+
"num_warps": 8,
48+
"num_stages": 4
49+
},
50+
"32": {
51+
"BLOCK_SIZE_M": 64,
52+
"BLOCK_SIZE_N": 128,
53+
"BLOCK_SIZE_K": 128,
54+
"GROUP_SIZE_M": 16,
55+
"num_warps": 8,
56+
"num_stages": 4
57+
},
58+
"48": {
59+
"BLOCK_SIZE_M": 64,
60+
"BLOCK_SIZE_N": 128,
61+
"BLOCK_SIZE_K": 128,
62+
"GROUP_SIZE_M": 32,
63+
"num_warps": 4,
64+
"num_stages": 4
65+
},
66+
"64": {
67+
"BLOCK_SIZE_M": 64,
68+
"BLOCK_SIZE_N": 128,
69+
"BLOCK_SIZE_K": 128,
70+
"GROUP_SIZE_M": 16,
71+
"num_warps": 8,
72+
"num_stages": 4
73+
},
74+
"96": {
75+
"BLOCK_SIZE_M": 64,
76+
"BLOCK_SIZE_N": 32,
77+
"BLOCK_SIZE_K": 256,
78+
"GROUP_SIZE_M": 32,
79+
"num_warps": 4,
80+
"num_stages": 4
81+
},
82+
"128": {
83+
"BLOCK_SIZE_M": 64,
84+
"BLOCK_SIZE_N": 128,
85+
"BLOCK_SIZE_K": 128,
86+
"GROUP_SIZE_M": 16,
87+
"num_warps": 8,
88+
"num_stages": 4
89+
},
90+
"256": {
91+
"BLOCK_SIZE_M": 128,
92+
"BLOCK_SIZE_N": 256,
93+
"BLOCK_SIZE_K": 128,
94+
"GROUP_SIZE_M": 1,
95+
"num_warps": 8,
96+
"num_stages": 4
97+
},
98+
"512": {
99+
"BLOCK_SIZE_M": 128,
100+
"BLOCK_SIZE_N": 256,
101+
"BLOCK_SIZE_K": 128,
102+
"GROUP_SIZE_M": 16,
103+
"num_warps": 8,
104+
"num_stages": 4
105+
},
106+
"1024": {
107+
"BLOCK_SIZE_M": 128,
108+
"BLOCK_SIZE_N": 256,
109+
"BLOCK_SIZE_K": 128,
110+
"GROUP_SIZE_M": 64,
111+
"num_warps": 8,
112+
"num_stages": 4
113+
},
114+
"1536": {
115+
"BLOCK_SIZE_M": 128,
116+
"BLOCK_SIZE_N": 256,
117+
"BLOCK_SIZE_K": 128,
118+
"GROUP_SIZE_M": 32,
119+
"num_warps": 8,
120+
"num_stages": 4
121+
},
122+
"2048": {
123+
"BLOCK_SIZE_M": 128,
124+
"BLOCK_SIZE_N": 256,
125+
"BLOCK_SIZE_K": 128,
126+
"GROUP_SIZE_M": 64,
127+
"num_warps": 8,
128+
"num_stages": 4
129+
},
130+
"3072": {
131+
"BLOCK_SIZE_M": 128,
132+
"BLOCK_SIZE_N": 256,
133+
"BLOCK_SIZE_K": 128,
134+
"GROUP_SIZE_M": 32,
135+
"num_warps": 8,
136+
"num_stages": 4
137+
},
138+
"4096": {
139+
"BLOCK_SIZE_M": 128,
140+
"BLOCK_SIZE_N": 256,
141+
"BLOCK_SIZE_K": 128,
142+
"GROUP_SIZE_M": 64,
143+
"num_warps": 8,
144+
"num_stages": 4
145+
}
146+
}

0 commit comments

Comments
 (0)