🚀 The feature, motivation and pitch
When using kv_cache_dtype="fp8" on Hopper, we use the FA3 FP8 backend. This means we quantize the queries, keys, and values and use the complete FP8 forward pass of FA3.
While this saves KV cache space and speeds up TTFT, it currently causes quite considerable slowdowns in ITL (see below).
The biggest issue is that quantizing the queries and the K, V causes overhead. Furthermore, the FP8 path of FA3 is very bad for decoding with GPT-OSS (head_dim=64).
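For context, here is a minimal offline-inference sketch of the configuration discussed in this issue; the model name and prompt are just placeholders:

```python
from vllm import LLM, SamplingParams

# Enable the FP8 KV cache. On Hopper this selects the FA3 FP8 backend,
# which quantizes Q, K, and V and runs the full FP8 attention forward pass.
llm = LLM(
    model="openai/gpt-oss-20b",  # placeholder; any GPT-OSS checkpoint
    kv_cache_dtype="fp8",
)

out = llm.generate(["The quick brown fox"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
```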
To address this, we propose three optimizations:
1/ Only quantize the full attention layers and leave the sliding-window layers in BF16 --> #24912 (specific to the GPT-OSS architecture; see the first sketch below the list)
2/ [merged] Add an environment flag so that, for static scales, torch.compile can fuse the quantization of the queries into preceding operations --> #24914 (applies generally; see the second sketch below the list)
3/ [merged] Optimize the FA3 tiling configuration for head_dim 64 --> vllm-project/flash-attention#91
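As a purely illustrative sketch of the idea behind optimization 1/ (this is not the code from #24912): only the full-attention layers get an FP8 KV cache, while sliding-window layers stay in BF16. The layer layout and the `sliding_window` flag below are made up for the example.

```python
import torch

# Hypothetical per-layer description of a GPT-OSS-style model, which
# alternates sliding-window and full-attention layers (illustrative only).
layers = [
    {"idx": i, "sliding_window": (i % 2 == 0)}  # even layers: sliding window
    for i in range(8)
]

def kv_cache_dtype_for(layer: dict) -> torch.dtype:
    """Optimization 1: keep sliding-window layers in BF16 and only
    quantize the KV cache of full-attention layers to FP8."""
    if layer["sliding_window"]:
        return torch.bfloat16
    return torch.float8_e4m3fn

for layer in layers:
    print(layer["idx"], kv_cache_dtype_for(layer))
```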
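And a minimal sketch of the idea behind optimization 2/: when the quantization scale is a static constant, torch.compile can fold the divide-and-cast of the queries into the preceding projection instead of launching a separate quantization kernel. The shapes, weights, and scale value below are made up; this is not the #24914 implementation.

```python
import torch

@torch.compile
def project_and_quantize_q(x: torch.Tensor, w_q: torch.Tensor, scale: torch.Tensor):
    # Query projection (the "previous operation" the quantization fuses into).
    q = x @ w_q
    # Static-scale FP8 quantization of the queries; with a constant scale the
    # compiler can typically fuse the divide + cast into the preceding op
    # rather than launching a separate quantization kernel.
    return (q / scale).to(torch.float8_e4m3fn)

if __name__ == "__main__" and torch.cuda.is_available():
    x = torch.randn(16, 2880, device="cuda", dtype=torch.bfloat16)
    w_q = torch.randn(2880, 2880, device="cuda", dtype=torch.bfloat16)
    scale = torch.tensor(1.0, device="cuda")
    q_fp8 = project_and_quantize_q(x, w_q, scale)
    print(q_fp8.dtype, q_fp8.shape)
```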
We run bench serve with 25k and 1k inputs to illustrate the overheads during decoding (always running with "cudagraph_mode": "FULL_AND_PIECEWISE"). On mainline, turning on KV quantization increases ITL at 1k context from 4.83 ms to 5.06 ms.
With our proposed optimizations, this comes back down to 4.86 ms.
At 25k inputs, we then already start seeing actual speedups from FP8 quantization.
Note that by skipping quantization of the sliding-window layers, we quantize less than we normally would. We also checked accuracy via the gpt-oss evals and are within the usual variation.
| model | applied optimizations | input tokens | output tokens | median ITL [ms] | median TTFT [ms] |
|---|---|---|---|---|---|
| gpt-oss (full BF16) | mainline | 25000 | 200 | 5.08 | 1252.43 |
| gpt-oss (full BF16) | mainline | 1000 | 200 | 4.83 | 76.16 |
| kv fp8 | mainline | 25000 | 200 | 5.38 | 1211.33 |
| kv fp8 | mainline | 1000 | 200 | 5.06 | 76.36 |
| gpt-oss-KV8-skip-sw | 1 | 25000 | 200 | 5.23 | 1210.6 |
| gpt-oss-KV8-skip-sw | 1 | 1000 | 200 | 4.92 | 76.37 |
| gpt-oss-KV8-fuse-q-quant | 2 | 25000 | 200 | 5.3 | 1208.56 |
| gpt-oss-KV8-fuse-q-quant | 2 | 1000 | 200 | 4.98 | 78.89 |
| gpt-oss-KV8-skip-and-fuse-q-quant | 1,2 | 25000 | 200 | 5.23 | 1209.68 |
| gpt-oss-KV8-skip-and-fuse-q-quant | 1,2 | 1000 | 200 | 4.91 | 76.06 |
| gpt-oss-KV8-skip-and-fuse-q-quant-tuned-fa3tiles | 1,2,3 | 25000 | 200 | 5.06 | 1214.72 |
| gpt-oss-KV8-skip-and-fuse-q-quant-tuned-fa3tiles | 1,2,3 | 1000 | 200 | 4.86 | 79.82 |
Alternatives
No response
Additional context
No response
Before submitting a new issue...
- Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.