1919# yapf: enable
2020from vllm .attention .backends .utils import (
2121 PAD_SLOT_ID , CommonAttentionState , compute_slot_mapping ,
22- compute_slot_mapping_start_idx , get_flash_attn_version ,
23- get_num_prefill_decode_query_kv_tokens , get_seq_len_block_table_args ,
24- is_all_cross_attn_metadata_set , is_all_encoder_attn_metadata_set ,
25- is_block_tables_empty )
22+ compute_slot_mapping_start_idx , get_num_prefill_decode_query_kv_tokens ,
23+ get_seq_len_block_table_args , is_all_cross_attn_metadata_set ,
24+ is_all_encoder_attn_metadata_set , is_block_tables_empty )
25+ from vllm . fa_utils import get_flash_attn_version
2626from vllm .logger import init_logger
2727from vllm .multimodal import MultiModalPlaceholderMap
2828from vllm .utils import async_tensor_h2d , make_tensor_with_pad
@@ -630,9 +630,11 @@ def __init__(
630630 self .sliding_window = ((sliding_window - 1 ,
631631 0 ) if sliding_window is not None else (- 1 , - 1 ))
632632 self .kv_cache_dtype = kv_cache_dtype
633- if is_quantized_kv_cache (self .kv_cache_dtype ):
633+ self .vllm_flash_attn_version = get_flash_attn_version ()
634+ if (is_quantized_kv_cache (self .kv_cache_dtype )
635+ and self .vllm_flash_attn_version != 3 ):
634636 raise NotImplementedError (
635- "FlashAttention with FP8 KV cache not yet supported " )
637+ "Only FlashAttention3 supports FP8 KV cache" )
636638 if logits_soft_cap is None :
637639 # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
638640 logits_soft_cap = 0
@@ -647,7 +649,6 @@ def __init__(
647649 f"Head size { head_size } is not supported by FlashAttention. "
648650 f"Supported head sizes are: { support_head_sizes } ." )
649651 self .attn_type = attn_type
650- self .vllm_flash_attn_version = get_flash_attn_version ()
651652
652653 def forward (
653654 self ,
@@ -671,13 +672,19 @@ def forward(
671672 for profiling run.
672673 attn_metadata: Metadata for attention.
673674 NOTE: It in-place updates the output tensor.
675+ NOTE: For FP8 quantization, flash-attn expects the size of
676+ {q,k,v}_descale to be (num_sequences, num_kv_heads).
677+ We use torch's .expand() to avoid duplicating values.
674678 """
675- # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
676- assert layer ._k_scale_float == 1.0 and layer ._v_scale_float == 1.0 , (
677- "key/v_scale is not supported in FlashAttention." )
678-
679679 assert output is not None , "Output tensor must be provided."
680680
681+ # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
682+ if self .vllm_flash_attn_version < 3 or output .dtype != torch .bfloat16 :
683+ assert (
684+ layer ._k_scale_float == 1.0 and layer ._v_scale_float == 1.0 ), (
685+ "key/v_scale is only supported in FlashAttention 3 with "
686+ "base dtype bfloat16" )
687+
681688 attn_type = self .attn_type
682689 if (attn_type == AttentionType .ENCODER
683690 and (not attn_metadata .is_all_encoder_attn_metadata_set )):
@@ -694,6 +701,7 @@ def forward(
694701 window_size = self .sliding_window
695702 alibi_slopes : Optional [torch .Tensor ] = self .alibi_slopes
696703 logits_soft_cap : Optional [float ] = self .logits_soft_cap
704+ fp8_attention = kv_cache_dtype .startswith ("fp8" )
697705
698706 if kv_cache .numel () > 0 :
699707 key_cache = kv_cache [0 ]
@@ -729,6 +737,19 @@ def forward(
729737 layer ._v_scale ,
730738 )
731739
740+ if fp8_attention :
741+ kv_cache = kv_cache .view (torch .float8_e4m3fn )
742+ key_cache = key_cache .view (torch .float8_e4m3fn )
743+ value_cache = value_cache .view (torch .float8_e4m3fn )
744+
745+ if fp8_attention :
746+ num_tokens , num_heads , head_size = query .shape
747+ query , _ = ops .scaled_fp8_quant (
748+ query .reshape (
749+ (num_tokens , num_heads * head_size )).contiguous (),
750+ layer ._q_scale )
751+ query = query .reshape ((num_tokens , num_heads , head_size ))
752+
732753 (num_prefill_query_tokens , num_prefill_kv_tokens ,
733754 num_decode_query_tokens ) = \
734755 get_num_prefill_decode_query_kv_tokens (attn_metadata , attn_type )
@@ -753,6 +774,23 @@ def forward(
753774 key = key [:num_prefill_kv_tokens ]
754775 value = value [:num_prefill_kv_tokens ]
755776
777+ if fp8_attention :
778+ num_kv_tokens , num_kv_heads , head_size = key .shape
779+
780+ key , _ = ops .scaled_fp8_quant (
781+ key .reshape ((num_kv_tokens ,
782+ num_kv_heads * head_size )).contiguous (),
783+ layer ._k_scale )
784+ key = key .reshape ((num_kv_tokens , num_kv_heads , head_size ))
785+
786+ value , _ = ops .scaled_fp8_quant (
787+ value .reshape ((num_kv_tokens ,
788+ num_kv_heads * head_size )).contiguous (),
789+ layer ._v_scale )
790+ value = value .reshape (
791+ (num_kv_tokens , num_kv_heads , head_size ))
792+
793+ descale_shape = (q_seq_start_loc .shape [0 ] - 1 , key .shape [1 ])
756794 flash_attn_varlen_func (
757795 q = query ,
758796 k = key ,
@@ -768,13 +806,19 @@ def forward(
768806 softcap = logits_soft_cap ,
769807 out = prefill_output ,
770808 fa_version = self .vllm_flash_attn_version ,
809+ q_descale = layer ._q_scale .expand (descale_shape ),
810+ k_descale = layer ._k_scale .expand (descale_shape ),
811+ v_descale = layer ._v_scale .expand (descale_shape ),
771812 )
772813 else :
773814 # prefix-enabled attention
774815 assert attn_type == AttentionType .DECODER , (
775816 "Only decoder-only models support prefix caching" )
776817 assert prefill_meta .seq_lens is not None
818+ assert prefill_meta .query_start_loc is not None
777819 max_seq_len = max (prefill_meta .seq_lens )
820+ descale_shape = (prefill_meta .query_start_loc .shape [0 ] - 1 ,
821+ key .shape [1 ])
778822 flash_attn_varlen_func ( # noqa
779823 q = query ,
780824 k = key_cache ,
@@ -791,6 +835,9 @@ def forward(
791835 softcap = logits_soft_cap ,
792836 out = prefill_output ,
793837 fa_version = self .vllm_flash_attn_version ,
838+ q_descale = layer ._q_scale .expand (descale_shape ),
839+ k_descale = layer ._k_scale .expand (descale_shape ),
840+ v_descale = layer ._v_scale .expand (descale_shape ),
794841 )
795842
796843 if decode_meta := attn_metadata .decode_metadata :
@@ -804,6 +851,9 @@ def forward(
804851 assert attn_type == AttentionType .DECODER , (
805852 "Only decoder-only models support max_decode_query_len > 1"
806853 )
854+ assert decode_meta .query_start_loc is not None
855+ descale_shape = (decode_meta .query_start_loc .shape [0 ] - 1 ,
856+ key .shape [1 ])
807857 flash_attn_varlen_func (
808858 q = decode_query ,
809859 k = key_cache ,
@@ -820,6 +870,9 @@ def forward(
820870 block_table = decode_meta .block_tables ,
821871 out = decode_output ,
822872 fa_version = self .vllm_flash_attn_version ,
873+ q_descale = layer ._q_scale .expand (descale_shape ),
874+ k_descale = layer ._k_scale .expand (descale_shape ),
875+ v_descale = layer ._v_scale .expand (descale_shape ),
823876 )
824877 else :
825878 # Use flash_attn_with_kvcache for normal decoding.
@@ -828,6 +881,7 @@ def forward(
828881 _ ,
829882 block_tables_arg ,
830883 ) = get_seq_len_block_table_args (decode_meta , False , attn_type )
884+ descale_shape = (seq_lens_arg .shape [0 ], key_cache .shape [- 2 ])
831885 flash_attn_with_kvcache (
832886 q = decode_query .unsqueeze (1 ),
833887 k_cache = key_cache ,
@@ -841,6 +895,9 @@ def forward(
841895 softcap = logits_soft_cap ,
842896 out = decode_output .unsqueeze (1 ),
843897 fa_version = self .vllm_flash_attn_version ,
898+ q_descale = layer ._q_scale .expand (descale_shape ),
899+ k_descale = layer ._k_scale .expand (descale_shape ),
900+ v_descale = layer ._v_scale .expand (descale_shape ),
844901 )
845902 return output
846903
0 commit comments