Skip to content

Commit a597a57

Browse files
[Attention] Flash Attention 3 - fp8 (#14570)
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
1 parent ae65f3e commit a597a57

File tree

15 files changed

+272
-76
lines changed

15 files changed

+272
-76
lines changed

cmake/external_projects/vllm_flash_attn.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ else()
3838
FetchContent_Declare(
3939
vllm-flash-attn
4040
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
41-
GIT_TAG 9bfa9869829d8c593527eb34c5271d0090f7ccc9
41+
GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22
4242
GIT_PROGRESS TRUE
4343
# Don't share the vllm-flash-attn build between build types
4444
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

tests/kernels/test_flash_attn.py

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
HEAD_SIZES = [128, 256]
1616
BLOCK_SIZES = [16, 32]
1717
DTYPES = [torch.float16, torch.bfloat16]
18+
QDTYPES = [None, torch.float8_e4m3fn]
1819
# one value large enough to test overflow in index calculation.
1920
# one value small enough to test the schema op check
2021
NUM_BLOCKS = [32768, 2048]
@@ -85,6 +86,7 @@ def ref_paged_attn(
8586
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
8687
@pytest.mark.parametrize("sliding_window", [None, 256])
8788
@pytest.mark.parametrize("fa_version", [2, 3])
89+
@pytest.mark.parametrize("q_dtype", QDTYPES)
8890
@torch.inference_mode()
8991
def test_flash_attn_with_paged_kv(
9092
use_out: bool,
@@ -97,11 +99,15 @@ def test_flash_attn_with_paged_kv(
9799
num_blocks: int,
98100
sliding_window: Optional[int],
99101
fa_version: int,
102+
q_dtype: Optional[torch.dtype],
100103
) -> None:
101104
torch.set_default_device("cuda")
102105
if not is_fa_version_supported(fa_version):
103106
pytest.skip(f"Flash attention version {fa_version} not supported due "
104107
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
108+
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
109+
pytest.skip("Flash attention with quantized inputs is only "
110+
"supported on version 3 with bfloat16 base type")
105111

106112
current_platform.seed_everything(0)
107113
num_seqs = len(kv_lens)
@@ -130,10 +136,28 @@ def test_flash_attn_with_paged_kv(
130136

131137
q = query.unsqueeze(1)
132138
out = torch.empty_like(q) if use_out else None
139+
140+
maybe_quantized_query = q
141+
maybe_quantized_key_cache = key_cache
142+
maybe_quantized_value_cache = value_cache
143+
q_descale = None
144+
k_descale = None
145+
v_descale = None
146+
if q_dtype is not None:
147+
# QKV are drawn from N(0, 1): no need for an fp8 scaling factor
148+
maybe_quantized_query = query.to(q_dtype)
149+
maybe_quantized_key_cache = key_cache.to(q_dtype)
150+
maybe_quantized_value_cache = value_cache.to(q_dtype)
151+
152+
scale_shape = (num_seqs, num_kv_heads)
153+
q_descale = torch.ones(scale_shape, dtype=torch.float32)
154+
k_descale = torch.ones(scale_shape, dtype=torch.float32)
155+
v_descale = torch.ones(scale_shape, dtype=torch.float32)
156+
133157
output = flash_attn_with_kvcache(
134-
q=q,
135-
k_cache=key_cache,
136-
v_cache=value_cache,
158+
q=maybe_quantized_query,
159+
k_cache=maybe_quantized_key_cache,
160+
v_cache=maybe_quantized_value_cache,
137161
out=out,
138162
softmax_scale=scale,
139163
causal=True,
@@ -142,10 +166,17 @@ def test_flash_attn_with_paged_kv(
142166
softcap=soft_cap if soft_cap is not None else 0,
143167
window_size=window_size,
144168
fa_version=fa_version,
169+
q_descale=q_descale,
170+
k_descale=k_descale,
171+
v_descale=v_descale,
145172
)
146173
output = output if not use_out else out
147174
output = output.squeeze(1)
148175

176+
atol, rtol = 1.5e-2, 1e-2
177+
if q_dtype is not None:
178+
atol, rtol = 1.5e-1, 1.5e-1
179+
149180
ref_output = ref_paged_attn(query=query,
150181
key_cache=key_cache,
151182
value_cache=value_cache,
@@ -155,7 +186,7 @@ def test_flash_attn_with_paged_kv(
155186
scale=scale,
156187
soft_cap=soft_cap,
157188
sliding_window=sliding_window)
158-
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
189+
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
159190
f"{torch.max(torch.abs(output - ref_output))}"
160191

161192

@@ -171,6 +202,7 @@ def test_flash_attn_with_paged_kv(
171202
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
172203
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
173204
@pytest.mark.parametrize("fa_version", [2, 3])
205+
@pytest.mark.parametrize("q_dtype", QDTYPES)
174206
@torch.inference_mode()
175207
def test_varlen_with_paged_kv(
176208
use_out: bool,
@@ -183,11 +215,15 @@ def test_varlen_with_paged_kv(
183215
soft_cap: Optional[float],
184216
num_blocks: int,
185217
fa_version: int,
218+
q_dtype: Optional[torch.dtype],
186219
) -> None:
187220
torch.set_default_device("cuda")
188221
if not is_fa_version_supported(fa_version):
189222
pytest.skip(f"Flash attention version {fa_version} not supported due "
190223
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
224+
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
225+
pytest.skip("Flash attention with quantized inputs is only "
226+
"supported on version 3 with bfloat16 base type")
191227
current_platform.seed_everything(0)
192228
num_seqs = len(seq_lens)
193229
query_lens = [x[0] for x in seq_lens]
@@ -223,10 +259,28 @@ def test_varlen_with_paged_kv(
223259
dtype=torch.int32)
224260

225261
out = torch.empty_like(query) if use_out else None
262+
263+
maybe_quantized_query = query
264+
maybe_quantized_key_cache = key_cache
265+
maybe_quantized_value_cache = value_cache
266+
q_descale = None
267+
k_descale = None
268+
v_descale = None
269+
if q_dtype is not None:
270+
# QKV are drawn from N(0, 1): no need for an fp8 scaling factor
271+
maybe_quantized_query = query.to(q_dtype)
272+
maybe_quantized_key_cache = key_cache.to(q_dtype)
273+
maybe_quantized_value_cache = value_cache.to(q_dtype)
274+
275+
scale_shape = (num_seqs, num_kv_heads)
276+
q_descale = torch.ones(scale_shape, dtype=torch.float32)
277+
k_descale = torch.ones(scale_shape, dtype=torch.float32)
278+
v_descale = torch.ones(scale_shape, dtype=torch.float32)
279+
226280
output = flash_attn_varlen_func(
227-
q=query,
228-
k=key_cache,
229-
v=value_cache,
281+
q=maybe_quantized_query,
282+
k=maybe_quantized_key_cache,
283+
v=maybe_quantized_value_cache,
230284
out=out,
231285
cu_seqlens_q=cu_query_lens,
232286
seqused_k=kv_lens,
@@ -238,6 +292,9 @@ def test_varlen_with_paged_kv(
238292
block_table=block_tables,
239293
softcap=soft_cap if soft_cap is not None else 0,
240294
fa_version=fa_version,
295+
q_descale=q_descale,
296+
k_descale=k_descale,
297+
v_descale=v_descale,
241298
)
242299
output = output if not use_out else out
243300

@@ -252,5 +309,8 @@ def test_varlen_with_paged_kv(
252309
sliding_window=sliding_window,
253310
soft_cap=soft_cap,
254311
)
255-
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
312+
atol, rtol = 1.5e-2, 1e-2
313+
if q_dtype is not None:
314+
atol, rtol = 1.5e-1, 1.5e-1
315+
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
256316
f"{torch.max(torch.abs(output - ref_output))}"

vllm/attention/__init__.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@
44
AttentionMetadata,
55
AttentionMetadataBuilder,
66
AttentionState, AttentionType)
7-
from vllm.attention.backends.utils import get_flash_attn_version
87
from vllm.attention.layer import Attention
98
from vllm.attention.selector import get_attn_backend
109

1110
__all__ = [
12-
"Attention", "AttentionBackend", "AttentionMetadata", "AttentionType",
13-
"AttentionMetadataBuilder", "Attention", "AttentionState",
14-
"get_attn_backend", "get_flash_attn_version"
11+
"Attention",
12+
"AttentionBackend",
13+
"AttentionMetadata",
14+
"AttentionType",
15+
"AttentionMetadataBuilder",
16+
"Attention",
17+
"AttentionState",
18+
"get_attn_backend",
1519
]

vllm/attention/backends/abstract.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
232232

233233
class AttentionLayer(Protocol):
234234

235+
_q_scale: torch.Tensor
235236
_k_scale: torch.Tensor
236237
_v_scale: torch.Tensor
237238
_k_scale_float: float

vllm/attention/backends/flash_attn.py

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
# yapf: enable
2020
from vllm.attention.backends.utils import (
2121
PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
22-
compute_slot_mapping_start_idx, get_flash_attn_version,
23-
get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
24-
is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set,
25-
is_block_tables_empty)
22+
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
23+
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
24+
is_all_encoder_attn_metadata_set, is_block_tables_empty)
25+
from vllm.fa_utils import get_flash_attn_version
2626
from vllm.logger import init_logger
2727
from vllm.multimodal import MultiModalPlaceholderMap
2828
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
@@ -630,9 +630,11 @@ def __init__(
630630
self.sliding_window = ((sliding_window - 1,
631631
0) if sliding_window is not None else (-1, -1))
632632
self.kv_cache_dtype = kv_cache_dtype
633-
if is_quantized_kv_cache(self.kv_cache_dtype):
633+
self.vllm_flash_attn_version = get_flash_attn_version()
634+
if (is_quantized_kv_cache(self.kv_cache_dtype)
635+
and self.vllm_flash_attn_version != 3):
634636
raise NotImplementedError(
635-
"FlashAttention with FP8 KV cache not yet supported")
637+
"Only FlashAttention3 supports FP8 KV cache")
636638
if logits_soft_cap is None:
637639
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
638640
logits_soft_cap = 0
@@ -647,7 +649,6 @@ def __init__(
647649
f"Head size {head_size} is not supported by FlashAttention. "
648650
f"Supported head sizes are: {support_head_sizes}.")
649651
self.attn_type = attn_type
650-
self.vllm_flash_attn_version = get_flash_attn_version()
651652

652653
def forward(
653654
self,
@@ -671,13 +672,19 @@ def forward(
671672
for profiling run.
672673
attn_metadata: Metadata for attention.
673674
NOTE: It in-place updates the output tensor.
675+
NOTE: FP8 quantization, flash-attn expect the size of
676+
{q,k,v}_descale to be (num_sequences, num_kv_heads).
677+
We use torch's .expand() to avoid duplicating values
674678
"""
675-
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
676-
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, (
677-
"key/v_scale is not supported in FlashAttention.")
678-
679679
assert output is not None, "Output tensor must be provided."
680680

681+
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
682+
if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16:
683+
assert (
684+
layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
685+
"key/v_scale is only supported in FlashAttention 3 with "
686+
"base dtype bfloat16")
687+
681688
attn_type = self.attn_type
682689
if (attn_type == AttentionType.ENCODER
683690
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
@@ -694,6 +701,7 @@ def forward(
694701
window_size = self.sliding_window
695702
alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
696703
logits_soft_cap: Optional[float] = self.logits_soft_cap
704+
fp8_attention = kv_cache_dtype.startswith("fp8")
697705

698706
if kv_cache.numel() > 0:
699707
key_cache = kv_cache[0]
@@ -729,6 +737,19 @@ def forward(
729737
layer._v_scale,
730738
)
731739

740+
if fp8_attention:
741+
kv_cache = kv_cache.view(torch.float8_e4m3fn)
742+
key_cache = key_cache.view(torch.float8_e4m3fn)
743+
value_cache = value_cache.view(torch.float8_e4m3fn)
744+
745+
if fp8_attention:
746+
num_tokens, num_heads, head_size = query.shape
747+
query, _ = ops.scaled_fp8_quant(
748+
query.reshape(
749+
(num_tokens, num_heads * head_size)).contiguous(),
750+
layer._q_scale)
751+
query = query.reshape((num_tokens, num_heads, head_size))
752+
732753
(num_prefill_query_tokens, num_prefill_kv_tokens,
733754
num_decode_query_tokens) = \
734755
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
@@ -753,6 +774,23 @@ def forward(
753774
key = key[:num_prefill_kv_tokens]
754775
value = value[:num_prefill_kv_tokens]
755776

777+
if fp8_attention:
778+
num_kv_tokens, num_kv_heads, head_size = key.shape
779+
780+
key, _ = ops.scaled_fp8_quant(
781+
key.reshape((num_kv_tokens,
782+
num_kv_heads * head_size)).contiguous(),
783+
layer._k_scale)
784+
key = key.reshape((num_kv_tokens, num_kv_heads, head_size))
785+
786+
value, _ = ops.scaled_fp8_quant(
787+
value.reshape((num_kv_tokens,
788+
num_kv_heads * head_size)).contiguous(),
789+
layer._v_scale)
790+
value = value.reshape(
791+
(num_kv_tokens, num_kv_heads, head_size))
792+
793+
descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1])
756794
flash_attn_varlen_func(
757795
q=query,
758796
k=key,
@@ -768,13 +806,19 @@ def forward(
768806
softcap=logits_soft_cap,
769807
out=prefill_output,
770808
fa_version=self.vllm_flash_attn_version,
809+
q_descale=layer._q_scale.expand(descale_shape),
810+
k_descale=layer._k_scale.expand(descale_shape),
811+
v_descale=layer._v_scale.expand(descale_shape),
771812
)
772813
else:
773814
# prefix-enabled attention
774815
assert attn_type == AttentionType.DECODER, (
775816
"Only decoder-only models support prefix caching")
776817
assert prefill_meta.seq_lens is not None
818+
assert prefill_meta.query_start_loc is not None
777819
max_seq_len = max(prefill_meta.seq_lens)
820+
descale_shape = (prefill_meta.query_start_loc.shape[0] - 1,
821+
key.shape[1])
778822
flash_attn_varlen_func( # noqa
779823
q=query,
780824
k=key_cache,
@@ -791,6 +835,9 @@ def forward(
791835
softcap=logits_soft_cap,
792836
out=prefill_output,
793837
fa_version=self.vllm_flash_attn_version,
838+
q_descale=layer._q_scale.expand(descale_shape),
839+
k_descale=layer._k_scale.expand(descale_shape),
840+
v_descale=layer._v_scale.expand(descale_shape),
794841
)
795842

796843
if decode_meta := attn_metadata.decode_metadata:
@@ -804,6 +851,9 @@ def forward(
804851
assert attn_type == AttentionType.DECODER, (
805852
"Only decoder-only models support max_decode_query_len > 1"
806853
)
854+
assert decode_meta.query_start_loc is not None
855+
descale_shape = (decode_meta.query_start_loc.shape[0] - 1,
856+
key.shape[1])
807857
flash_attn_varlen_func(
808858
q=decode_query,
809859
k=key_cache,
@@ -820,6 +870,9 @@ def forward(
820870
block_table=decode_meta.block_tables,
821871
out=decode_output,
822872
fa_version=self.vllm_flash_attn_version,
873+
q_descale=layer._q_scale.expand(descale_shape),
874+
k_descale=layer._k_scale.expand(descale_shape),
875+
v_descale=layer._v_scale.expand(descale_shape),
823876
)
824877
else:
825878
# Use flash_attn_with_kvcache for normal decoding.
@@ -828,6 +881,7 @@ def forward(
828881
_,
829882
block_tables_arg,
830883
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
884+
descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2])
831885
flash_attn_with_kvcache(
832886
q=decode_query.unsqueeze(1),
833887
k_cache=key_cache,
@@ -841,6 +895,9 @@ def forward(
841895
softcap=logits_soft_cap,
842896
out=decode_output.unsqueeze(1),
843897
fa_version=self.vllm_flash_attn_version,
898+
q_descale=layer._q_scale.expand(descale_shape),
899+
k_descale=layer._k_scale.expand(descale_shape),
900+
v_descale=layer._v_scale.expand(descale_shape),
844901
)
845902
return output
846903

vllm/attention/backends/mla/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,9 @@
203203
AttentionState, MLAAttentionImpl)
204204
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
205205
compute_slot_mapping_start_idx,
206-
get_flash_attn_version,
207206
is_block_tables_empty)
208207
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
208+
from vllm.fa_utils import get_flash_attn_version
209209
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
210210
LinearBase, RowParallelLinear,
211211
UnquantizedLinearMethod)

0 commit comments

Comments
 (0)