[Feature]: Make FP8 Attention fast for GPT-OSS w/ FA3 on Hopper #24916

@jmkuebler

Description

🚀 The feature, motivation and pitch

When using kv_cache_dtype="fp8" on Hopper, we use the FA3 FP8 backend. This means we quantize queries, keys, and values and run the complete FP8 forward pass of FA3.
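
For context, this path is enabled purely through the KV cache dtype setting. A minimal sketch with the offline API (the model name and prompt are illustrative):

```python
# Minimal sketch: enabling the FP8 KV cache / FA3 FP8 path via vLLM's offline API.
# The model name and prompt are illustrative; on Hopper this setting selects the
# FA3 FP8 backend described above.
from vllm import LLM, SamplingParams

llm = LLM(
    model="openai/gpt-oss-20b",
    kv_cache_dtype="fp8",  # quantize the KV cache (and, on this path, the queries) to FP8
)
outputs = llm.generate(["Hello, world"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```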

While this saves KV cache space and speeds up TTFT, it currently causes considerable ITL slowdowns (see the table below).

The biggest issues are that quantizing the queries, keys, and values adds overhead, and that the FP8 path of FA3 performs poorly during decoding for GPT-OSS (head_dim=64).

To address this we propose three optimizations:
1/ Only quantize the full-attention layers and leave the sliding-window layers in BF16 --> #24912 (specific to the GPT-OSS architecture)
2/ [merged] Add an environment flag so that, for static scales, torch.compile can fuse the quantization of the queries into previous operations --> #24914 (applies generally; see the sketch below)
3/ [merged] Optimize the FA3 tiling configuration for head_dim 64 --> vllm-project/flash-attention#91
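
To make optimization 2 concrete, here is a minimal, hedged sketch of the idea in plain PyTorch (not vLLM's actual attention code; the scale value and function names are illustrative): with a static scale, the query quantization is just an elementwise op, so torch.compile can fuse it into the preceding projection instead of launching a separate quantization kernel before every attention call.

```python
# Sketch only: static-scale FP8 query quantization that torch.compile can fuse
# into the preceding projection (illustrative names/values, not vLLM internals).
import torch

Q_SCALE = 0.5  # static scale known at compile time (illustrative value)

def project_and_quantize_q(x: torch.Tensor, w_q: torch.Tensor) -> torch.Tensor:
    q = x @ w_q  # query projection
    # With a constant scale, this cast is a fusable elementwise epilogue rather
    # than a separate dynamic-quantization kernel.
    return (q * (1.0 / Q_SCALE)).to(torch.float8_e4m3fn)

project_and_quantize_q = torch.compile(project_and_quantize_q)
```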

We run bench serve with 25k and 1k input lengths to illustrate the decoding overheads (always running with "cudagraph_mode": "FULL_AND_PIECEWISE"). On mainline, turning on KV cache quantization increases median ITL at 1k context from 4.83 ms to 5.06 ms. With our proposed optimizations, this drops back to 4.86 ms.
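
The engine configuration behind these numbers can be reproduced roughly as follows (a sketch: only kv_cache_dtype="fp8" and the cudagraph mode are taken from the description above; the model name and everything else are illustrative).

```python
# Rough sketch of the engine configuration used for the measurements below.
# Only kv_cache_dtype="fp8" and cudagraph_mode="FULL_AND_PIECEWISE" come from
# the description above; the model name is illustrative.
from vllm import LLM

llm = LLM(
    model="openai/gpt-oss-120b",
    kv_cache_dtype="fp8",
    compilation_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
)
```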

At 25k inputs we then already see actual speedups from FP8 quantization.

With the sliding-window layers skipped, we quantize less of the KV cache than usual. We also checked accuracy via the gpt-oss evals and the results are within the usual run-to-run variation.

| model | applied optimization | input | output | median ITL [ms] | median TTFT [ms] |
|---|---|---|---|---|---|
| gpt-oss (full bf16) | mainline | 25000 | 200 | 5.08 | 1252.43 |
| gpt-oss (full bf16) | mainline | 1000 | 200 | 4.83 | 76.16 |
| kv fp8 | mainline | 25000 | 200 | 5.38 | 1211.33 |
| kv fp8 | mainline | 1000 | 200 | 5.06 | 76.36 |
| gpt-oss-KV8-skip-sw | 1 | 25000 | 200 | 5.23 | 1210.6 |
| gpt-oss-KV8-skip-sw | 1 | 1000 | 200 | 4.92 | 76.37 |
| gpt-oss-KV8-fuse-q-quant | 2 | 25000 | 200 | 5.3 | 1208.56 |
| gpt-oss-KV8-fuse-q-quant | 2 | 1000 | 200 | 4.98 | 78.89 |
| gpt-oss-KV8-skip-and-fuse-q-quant | 1,2 | 25000 | 200 | 5.23 | 1209.68 |
| gpt-oss-KV8-skip-and-fuse-q-quant | 1,2 | 1000 | 200 | 4.91 | 76.06 |
| gpt-oss-KV8-skip-and-fuse-q-quant-tuned-fa3tiles | 1,2,3 | 25000 | 200 | 5.06 | 1214.72 |
| gpt-oss-KV8-skip-and-fuse-q-quant-tuned-fa3tiles | 1,2,3 | 1000 | 200 | 4.86 | 79.82 |

Alternatives

No response

Additional context

No response
