Skip to content
5 changes: 3 additions & 2 deletions cmake/external_projects/vllm_flash_attn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ if(VLLM_FLASH_ATTN_SRC_DIR)
else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 9bfa9869829d8c593527eb34c5271d0090f7ccc9
# FIXME(mseznec): replace with vllm-project once PR is merged
GIT_REPOSITORY https://github.com/mickaelseznec/flash-attention.git
GIT_TAG 38843737dc9b9f27a054cc73bee224a7f8e928bf
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
Expand Down
80 changes: 72 additions & 8 deletions tests/kernels/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
Expand Down Expand Up @@ -85,6 +86,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("fa_version", [2, 3])
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
use_out: bool,
Expand All @@ -97,11 +99,15 @@ def test_flash_attn_with_paged_kv(
num_blocks: int,
sliding_window: Optional[int],
fa_version: int,
q_dtype: Optional[torch.dtype],
) -> None:
torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
pytest.skip("Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type")

current_platform.seed_everything(0)
num_seqs = len(kv_lens)
Expand Down Expand Up @@ -130,10 +136,30 @@ def test_flash_attn_with_paged_kv(

q = query.unsqueeze(1)
out = torch.empty_like(q) if use_out else None

maybe_quantized_query = q
maybe_quantized_key_cache = key_cache
maybe_quantized_value_cache = value_cache
q_descale = None
k_descale = None
v_descale = None
if q_dtype is not None:
q_scale = q.amax().to(torch.float32) / 448.0
k_scale = key_cache.amax().to(torch.float32) / 448.0
v_scale = value_cache.amax().to(torch.float32) / 448.0

maybe_quantized_query = (q / q_scale).to(q_dtype)
maybe_quantized_key_cache = (key_cache / k_scale).to(q_dtype)
            maybe_quantized_value_cache = (value_cache / v_scale).to(q_dtype)

q_descale = q_scale.expand((num_seqs, num_kv_heads))
k_descale = k_scale.expand((num_seqs, num_kv_heads))
v_descale = v_scale.expand((num_seqs, num_kv_heads))

output = flash_attn_with_kvcache(
q=q,
k_cache=key_cache,
v_cache=value_cache,
q=maybe_quantized_query,
k_cache=maybe_quantized_key_cache,
v_cache=maybe_quantized_value_cache,
out=out,
softmax_scale=scale,
causal=True,
Expand All @@ -142,10 +168,17 @@ def test_flash_attn_with_paged_kv(
softcap=soft_cap if soft_cap is not None else 0,
window_size=window_size,
fa_version=fa_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
output = output if not use_out else out
output = output.squeeze(1)

atol, rtol = 1e-2, 1e-2
if q_dtype is not None:
atol, rtol = 1e-1, 1.5e-1

ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
Expand All @@ -155,7 +188,7 @@ def test_flash_attn_with_paged_kv(
scale=scale,
soft_cap=soft_cap,
sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - ref_output))}"


Expand All @@ -171,6 +204,7 @@ def test_flash_attn_with_paged_kv(
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("fa_version", [2, 3])
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()
def test_varlen_with_paged_kv(
use_out: bool,
Expand All @@ -183,11 +217,15 @@ def test_varlen_with_paged_kv(
soft_cap: Optional[float],
num_blocks: int,
fa_version: int,
q_dtype: Optional[torch.dtype],
) -> None:
torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
pytest.skip("Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type")
current_platform.seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
Expand Down Expand Up @@ -223,10 +261,30 @@ def test_varlen_with_paged_kv(
dtype=torch.int32)

out = torch.empty_like(query) if use_out else None

maybe_quantized_query = query
maybe_quantized_key_cache = key_cache
maybe_quantized_value_cache = value_cache
q_descale = None
k_descale = None
v_descale = None
if q_dtype is not None:
q_scale = query.amax().to(torch.float32) / 448.0
k_scale = key_cache.amax().to(torch.float32) / 448.0
v_scale = value_cache.amax().to(torch.float32) / 448.0

maybe_quantized_query = (query / q_scale).to(q_dtype)
maybe_quantized_key_cache = (key_cache / k_scale).to(q_dtype)
        maybe_quantized_value_cache = (value_cache / v_scale).to(q_dtype)

q_descale = q_scale.expand((num_seqs, num_kv_heads))
k_descale = k_scale.expand((num_seqs, num_kv_heads))
v_descale = v_scale.expand((num_seqs, num_kv_heads))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could we maybe test per-head scales here too?, i.e. also test with non-zero strides

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can add tests here, but this type of scaling isn't supported by vLLM at the moment. I believe that whenever we add support for it, we can add tests as well.
Besides, there's already a combination of 9k tests in here, I don't want to make the duration explode if it's not 100% needed :D


output = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
q=maybe_quantized_query,
k=maybe_quantized_key_cache,
v=maybe_quantized_value_cache,
out=out,
cu_seqlens_q=cu_query_lens,
seqused_k=kv_lens,
Expand All @@ -238,6 +296,9 @@ def test_varlen_with_paged_kv(
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
fa_version=fa_version,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
output = output if not use_out else out

Expand All @@ -252,5 +313,8 @@ def test_varlen_with_paged_kv(
sliding_window=sliding_window,
soft_cap=soft_cap,
)
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
atol, rtol = 1e-2, 1e-2
if q_dtype is not None:
atol, rtol = 1e-1, 1.5e-1
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - ref_output))}"
1 change: 1 addition & 0 deletions vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],

class AttentionLayer(Protocol):

_q_scale: torch.Tensor
_k_scale: torch.Tensor
_v_scale: torch.Tensor
_k_scale_float: float
Expand Down
64 changes: 60 additions & 4 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,13 +671,19 @@ def forward(
for profiling run.
attn_metadata: Metadata for attention.
NOTE: It in-place updates the output tensor.
            NOTE: For FP8 quantization, flash-attn expects the shape of
                {q,k,v}_descale to be (num_sequences, num_kv_heads).
                We use torch's .expand() to avoid duplicating values.
"""
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, (
"key/v_scale is not supported in FlashAttention.")

assert output is not None, "Output tensor must be provided."

# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16:
assert (
layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
"key/v_scale is only supported in FlashAttention 3 with "
"base dtype bfloat16")

attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
Expand All @@ -694,6 +700,7 @@ def forward(
window_size = self.sliding_window
alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
logits_soft_cap: Optional[float] = self.logits_soft_cap
fp8_attention = kv_cache_dtype.startswith("fp8")

if kv_cache.numel() > 0:
key_cache = kv_cache[0]
Expand Down Expand Up @@ -729,6 +736,19 @@ def forward(
layer._v_scale,
)

if fp8_attention:
kv_cache = kv_cache.view(torch.float8_e4m3fn)
key_cache = key_cache.view(torch.float8_e4m3fn)
value_cache = value_cache.view(torch.float8_e4m3fn)

if fp8_attention:
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))

(num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens) = \
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
Expand All @@ -753,6 +773,23 @@ def forward(
key = key[:num_prefill_kv_tokens]
value = value[:num_prefill_kv_tokens]

if fp8_attention:
num_kv_tokens, num_kv_heads, head_size = key.shape

key, _ = ops.scaled_fp8_quant(
key.reshape((num_kv_tokens,
num_kv_heads * head_size)).contiguous(),
layer._k_scale)
key = key.reshape((num_kv_tokens, num_kv_heads, head_size))

value, _ = ops.scaled_fp8_quant(
value.reshape((num_kv_tokens,
num_kv_heads * head_size)).contiguous(),
layer._v_scale)
value = value.reshape(
(num_kv_tokens, num_kv_heads, head_size))

descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1])
flash_attn_varlen_func(
q=query,
k=key,
Expand All @@ -768,13 +805,19 @@ def forward(
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
else:
# prefix-enabled attention
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support prefix caching")
assert prefill_meta.seq_lens is not None
assert prefill_meta.query_start_loc is not None
max_seq_len = max(prefill_meta.seq_lens)
descale_shape = (prefill_meta.query_start_loc.shape[0] - 1,
key.shape[1])
flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
Expand All @@ -791,6 +834,9 @@ def forward(
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)

if decode_meta := attn_metadata.decode_metadata:
Expand All @@ -804,6 +850,9 @@ def forward(
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support max_decode_query_len > 1"
)
assert decode_meta.query_start_loc is not None
descale_shape = (decode_meta.query_start_loc.shape[0] - 1,
key.shape[1])
flash_attn_varlen_func(
q=decode_query,
k=key_cache,
Expand All @@ -820,6 +869,9 @@ def forward(
block_table=decode_meta.block_tables,
out=decode_output,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
else:
# Use flash_attn_with_kvcache for normal decoding.
Expand All @@ -828,6 +880,7 @@ def forward(
_,
block_tables_arg,
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
descale_shape = (seq_lens_arg.shape[0], key.shape[1])
flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
Expand All @@ -841,6 +894,9 @@ def forward(
softcap=logits_soft_cap,
out=decode_output.unsqueeze(1),
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
return output

Expand Down
9 changes: 7 additions & 2 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def __init__(
self.calculate_kv_scales = calculate_kv_scales
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
# FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well.
self._q_scale = torch.tensor(1.0, dtype=torch.float32)

# We also keep the float32 versions of k/v_scale for attention
# backends that don't support tensors (Flashinfer)
Expand Down Expand Up @@ -153,6 +156,7 @@ def __init__(
).parallel_config.pipeline_parallel_size)
]

self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

Expand All @@ -178,7 +182,7 @@ def forward(
if self.calculate_kv_scales:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(key, value)
self.calc_kv_scales(query, key, value)
if self.use_output:
output_shape = (output_shape
if output_shape is not None else query.shape)
Expand Down Expand Up @@ -225,7 +229,8 @@ def forward(
return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name)

def calc_kv_scales(self, key, value):
def calc_kv_scales(self, query, key, value):
self._q_scale.copy_(torch.abs(query).max() / self.q_range)
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
self._k_scale_float = self._k_scale.item()
Expand Down
7 changes: 6 additions & 1 deletion vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False
Q_SCALE_CONSTANT: int = 200
K_SCALE_CONSTANT: int = 200
V_SCALE_CONSTANT: int = 100
VLLM_SERVER_DEV_MODE: bool = False
Expand Down Expand Up @@ -521,13 +522,17 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
# Pad the fp8 weights to 256 bytes for ROCm
"VLLM_ROCM_FP8_PADDING":
lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),

# Divisor for dynamic query scale factor calculation for FP8 KV Cache
"Q_SCALE_CONSTANT":
lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),
# Divisor for dynamic key scale factor calculation for FP8 KV Cache
"K_SCALE_CONSTANT":
lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),

# Divisor for dynamic value scale factor calculation for FP8 KV Cache
"V_SCALE_CONSTANT":
lambda: int(os.getenv("V_SCALE_CONSTANT", "100")),

# If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING":
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))),
Expand Down
Loading