Skip to content
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
5ea3788
sync with amd, support v1
LucasWilkinson Feb 28, 2025
d048631
fix IMA
LucasWilkinson Mar 4, 2025
e09841c
bugfix
LucasWilkinson Mar 4, 2025
52e7234
working
LucasWilkinson Mar 4, 2025
4604201
cleanup
LucasWilkinson Mar 5, 2025
f9f3e3e
fa MLA
LucasWilkinson Mar 5, 2025
b3b060b
commit wip
LucasWilkinson Mar 6, 2025
e286de8
cleanup
LucasWilkinson Apr 18, 2025
27a2cd2
fix
LucasWilkinson Apr 20, 2025
9165af3
move files
LucasWilkinson Apr 20, 2025
d056efd
fix up
LucasWilkinson Apr 20, 2025
ac4c624
v0 support + decode threshold
LucasWilkinson Apr 22, 2025
1f6bb3d
v0 fix
LucasWilkinson Apr 22, 2025
73c8736
fix
LucasWilkinson Apr 22, 2025
0f9ed95
fix logs
LucasWilkinson Apr 22, 2025
f2dc4a3
don't schedule prefills
LucasWilkinson Apr 22, 2025
d695fdc
still default to FlashMLA
LucasWilkinson Apr 24, 2025
82c9393
Remove V0 FlashAttention MLA
MatthewBonanni Aug 21, 2025
dc16bb5
Move back to original location
MatthewBonanni Aug 21, 2025
046af0b
Undo change
MatthewBonanni Aug 21, 2025
8a0fe94
Match main
MatthewBonanni Aug 21, 2025
9c5445d
Use reorder_batch_threshold throughout
MatthewBonanni Aug 21, 2025
6b90fd7
Remove input_positions
MatthewBonanni Aug 21, 2025
790bde6
Match main, remove unused arguments
MatthewBonanni Aug 21, 2025
161f50e
Align _build_decode signature
MatthewBonanni Aug 21, 2025
d87b921
Fix more arguments
MatthewBonanni Aug 21, 2025
5a32eeb
More compatibility fixes
MatthewBonanni Aug 21, 2025
63ec527
Fix backend enum
MatthewBonanni Aug 21, 2025
da96e28
Remove unused helpers
MatthewBonanni Aug 21, 2025
fb09124
Rename
MatthewBonanni Aug 25, 2025
70343e7
Loosen tolerances for FA MLA backend
MatthewBonanni Aug 25, 2025
5daadfe
Fix _forward_decode signature
MatthewBonanni Aug 25, 2025
b8e6e0a
Respect each backend's decode threshold
MatthewBonanni Aug 26, 2025
91f01d4
Fix backend selection logic
MatthewBonanni Aug 26, 2025
4201218
Address pre-commit
MatthewBonanni Aug 26, 2025
fe5ba41
Update GIT_TAG
MatthewBonanni Aug 27, 2025
6455578
Decode threshold tuning
MatthewBonanni Aug 27, 2025
513fdeb
Undo V0 change
MatthewBonanni Aug 27, 2025
fd25615
Pass qkv_dtype
MatthewBonanni Aug 27, 2025
398e55b
increase wheel size
LucasWilkinson Aug 28, 2025
4f29ce1
missing line
LucasWilkinson Aug 28, 2025
8672a7f
Fix backend selector logic and test
MatthewBonanni Aug 28, 2025
ea0f9c4
Merge remote-tracking branch 'origin/main' into lwilkinson/fa-mla
LucasWilkinson Aug 29, 2025
8298b9e
Merge remote-tracking branch 'origin/main' into lwilkinson/fa-mla
LucasWilkinson Aug 29, 2025
15c0fed
Merge branch 'main' into lwilkinson/fa-mla
MatthewBonanni Aug 29, 2025
98f3592
Merge branch 'main' into lwilkinson/fa-mla
MatthewBonanni Aug 29, 2025
6ef55b0
Merge branch 'main' into lwilkinson/fa-mla
MatthewBonanni Sep 2, 2025
84737e7
Merge branch 'main' into lwilkinson/fa-mla
MatthewBonanni Sep 3, 2025
dd2516a
Merge branch 'main' into lwilkinson/fa-mla
MatthewBonanni Sep 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/external_projects/vllm_flash_attn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 57b4e68b9f9d94750b46de8f8dbd2bfcc86edd4f
GIT_TAG ee4d25bd84e0cbc7e0b9b9685085fd5db2dcb62a
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
Expand Down
16 changes: 0 additions & 16 deletions tests/v1/attention/test_attention_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,6 @@ def _convert_dtype_to_torch(dtype):
}


def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec,
                          device: torch.device,
                          num_blocks: int = 100) -> torch.Tensor:
    """Build a randomly-initialized KV cache tensor for testing.

    Layout is (2, num_blocks, block_size, num_kv_heads, head_size),
    where the leading dimension holds the K and V planes.
    """
    cache_shape = (
        2,  # K and V planes
        num_blocks,
        kv_cache_spec.block_size,
        kv_cache_spec.num_kv_heads,
        kv_cache_spec.head_size,
    )
    return torch.randn(
        cache_shape,
        dtype=_convert_dtype_to_torch(kv_cache_spec.dtype),
        device=device,
    )


def create_and_prepopulate_kv_cache(
k_contexts: list[torch.Tensor],
v_contexts: list[torch.Tensor],
Expand Down
202 changes: 101 additions & 101 deletions tests/v1/attention/test_mla_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

BACKENDS_TO_TEST = [
_Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1,
_Backend.TRITON_MLA_VLLM_V1
_Backend.FLASH_ATTN_MLA_VLLM_V1, _Backend.TRITON_MLA_VLLM_V1
]

# Remove CUTLASS_MLA from the list if not using sm100
Expand Down Expand Up @@ -69,20 +69,6 @@ def _convert_dtype_to_torch(dtype):
}


def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec,
                          device: torch.device,
                          num_blocks: int = 100) -> torch.Tensor:
    """Build a randomly-initialized latent (MLA) KV cache for testing.

    Layout is (num_blocks, block_size, head_size); here head_size is the
    latent dimension, so there are no separate K/V planes.
    """
    torch_dtype = _convert_dtype_to_torch(kv_cache_spec.dtype)
    return torch.randn(
        num_blocks,
        kv_cache_spec.block_size,
        kv_cache_spec.head_size,  # latent dimension
        dtype=torch_dtype,
        device=device,
    )


def create_and_prepopulate_kv_cache(
kv_c_contexts: list[torch.Tensor],
k_pe_contexts: list[torch.Tensor],
Expand Down Expand Up @@ -315,7 +301,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):

# 2. Generate data and compute SDPA reference output for MLA
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
all_sdpa_outputs = []
all_sdpa_outputs: list[list[torch.Tensor]] = []
kv_c_contexts, k_pe_contexts = [], []

# Create shared MLA weight matrices for consistency across all sequences
Expand All @@ -331,6 +317,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
device=device)
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)

for i, backend in enumerate(BACKENDS_TO_TEST):
all_sdpa_outputs.append([])

for i in range(batch_size):
s_len = seq_lens[i]
q_len = query_lens[i]
Expand Down Expand Up @@ -358,85 +347,93 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
dtype=dtype,
device=device)

# Determine if this is decode (single token)
# or prefill (multiple tokens)
is_decode = q_len == 1
# Determine if this is decode or prefill
is_decode = []
for i, backend in enumerate(BACKENDS_TO_TEST):
builder_cls, _ = get_attention_backend(backend)
is_decode.append(q_len <= builder_cls.reorder_batch_threshold)

# Split q into nope and rope components
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)

if is_decode:
# Decode path: MQA-style attention in latent space
# Transform q_nope to latent space: q_nope @ W_UK
# q_nope: [1, num_heads, qk_nope_head_dim]
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope,
W_UK) # [1, num_heads, kv_lora_rank]

# Build MQA attention inputs
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
# K: [s_len, kv_lora_rank + qk_rope_head_dim]
# (broadcasted to all heads)
k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1)
# V: [s_len, kv_lora_rank] (broadcasted to all heads)
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1)

# SDPA expects (N, H, L, D)
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)

sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, is_causal=False, scale=scale)
sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(
0) # [1, num_heads, kv_lora_rank]

# Project back to output space: sdpa_out @ W_UV
sdpa_out_i = torch.einsum("qnl,lnv->qnv", sdpa_out_i, W_UV)
sdpa_out_i = sdpa_out_i.flatten(start_dim=-2)
else:
# Prefill path: MHA-style attention with full sequence
# Apply kv_b_proj to the full kv_c tensor
kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full,
kv_b_proj_weight)
k_nope_full, v_full = kv_nope_full.split(
[qk_nope_head_dim, v_head_dim], dim=-1)

# Build attention inputs for full sequence
q_mha = torch.cat([q_nope, q_pe],
dim=-1) # [q_len, num_heads, total_dim]
k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)

# Create custom attention mask:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their pos
attn_mask = torch.ones(q_len,
s_len,
dtype=torch.bool,
device=device)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask

# SDPA expects (N, H, L, D)
q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)

# Single attention call with custom mask
sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in,
k_sdpa_in,
v_sdpa_in,
attn_mask=attn_mask,
scale=scale)
sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(0)
sdpa_out_i = sdpa_out_i.flatten(start_dim=-2)

all_sdpa_outputs.append(sdpa_out_i)
#######################################################
# Decode path: MQA-style attention in latent space
# Transform q_nope to latent space: q_nope @ W_UK
# q_nope: [1, num_heads, qk_nope_head_dim]
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope,
W_UK) # [1, num_heads, kv_lora_rank]

# Build MQA attention inputs
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
# K: [s_len, kv_lora_rank + qk_rope_head_dim]
# (broadcasted to all heads)
k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1)
# V: [s_len, kv_lora_rank] (broadcasted to all heads)
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1)

# Create custom attention mask for decode path:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their position
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask

# SDPA expects (N, H, L, D)
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)

sdpa_out_i_decode = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
sdpa_out_i_decode = sdpa_out_i_decode.transpose(1, 2).squeeze(
0) # [1, num_heads, kv_lora_rank]

# Project back to output space: sdpa_out @ W_UV
sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode,
W_UV)
sdpa_out_i_decode = sdpa_out_i_decode.flatten(start_dim=-2)

#######################################################
# Prefill path: MHA-style attention with full sequence
# Apply kv_b_proj to the full kv_c tensor
kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, kv_b_proj_weight)
k_nope_full, v_full = kv_nope_full.split(
[qk_nope_head_dim, v_head_dim], dim=-1)

# Build attention inputs for full sequence
q_mha = torch.cat([q_nope, q_pe],
dim=-1) # [q_len, num_heads, total_dim]
k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)

# Create custom attention mask:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their pos
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask

# SDPA expects (N, H, L, D)
q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)

# Single attention call with custom mask
sdpa_out_i_prefill = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)

for i, backend in enumerate(BACKENDS_TO_TEST):
if is_decode[i]:
all_sdpa_outputs[i].append(sdpa_out_i_decode)
else:
all_sdpa_outputs[i].append(sdpa_out_i_prefill)

# Inputs for vLLM MLA backends are just the new tokens
all_q_vllm.append(q_c)
Expand All @@ -451,7 +448,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
query_vllm = torch.cat(all_q_vllm, dim=0)
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
sdpa_output = torch.cat(all_sdpa_outputs, dim=0)
sdpa_outputs = []
for i, backend in enumerate(BACKENDS_TO_TEST):
sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0))

# Create mock kv_b_proj using the same weights as reference implementation
from vllm.model_executor.layers.linear import ColumnParallelLinear
Expand Down Expand Up @@ -486,20 +485,20 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
randomize_blocks=True)

# 4. Run vLLM backends and compare
for backend_name in BACKENDS_TO_TEST:
for i, backend_name in enumerate(BACKENDS_TO_TEST):
backend_output = run_attention_backend(
backend_name, kv_cache_spec, ["placeholder"], vllm_config, device,
common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache,
kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim,
mock_kv_b_proj)

# Check shape and dtype consistency
assert backend_output.shape == sdpa_output.shape, (
assert backend_output.shape == sdpa_outputs[i].shape, (
f"[{backend_name}] shape {backend_output.shape} != "
f"SDPA shape {sdpa_output.shape}")
assert backend_output.dtype == sdpa_output.dtype, (
f"SDPA shape {sdpa_outputs[i].shape}")
assert backend_output.dtype == sdpa_outputs[i].dtype, (
f"[{backend_name}] dtype {backend_output.dtype} != "
f"SDPA dtype {sdpa_output.dtype}")
f"SDPA dtype {sdpa_outputs[i].dtype}")

assert torch.isfinite(backend_output).all(), (
f"[{backend_name}] produced non-finite values")
Expand All @@ -508,12 +507,13 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
rtol = 1e-2
atol = 5e-1

max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
max_diff = torch.max(torch.abs(backend_output -
sdpa_outputs[i])).item()
max_rel_diff = torch.max(
torch.abs(backend_output - sdpa_output) /
torch.abs(sdpa_output)).item()
torch.abs(backend_output - sdpa_outputs[i]) /
torch.abs(sdpa_outputs[i])).item()
all_close = torch.allclose(backend_output,
sdpa_output,
sdpa_outputs[i],
rtol=rtol,
atol=atol)

Expand Down
2 changes: 2 additions & 0 deletions tests/v1/attention/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",
_Backend.FLASHMLA_VLLM_V1:
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
_Backend.FLASH_ATTN_MLA_VLLM_V1:
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
_Backend.TRITON_MLA_VLLM_V1:
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
}
Expand Down
13 changes: 13 additions & 0 deletions vllm/attention/utils/fa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,18 @@ def flash_attn_supports_fp8() -> bool:
current_platform.get_device_capability().major == 9


def flash_attn_supports_mla() -> bool:
    """Return True if the vLLM flash-attention build can run MLA.

    MLA support requires FA version 3 and a CUDA device with compute
    capability major == 9 (SM90); on any other platform, or when the
    vllm_flash_attn interface cannot be imported, this returns False.
    """
    from vllm.platforms import current_platform
    if current_platform.is_cuda():
        try:
            from vllm.vllm_flash_attn.flash_attn_interface import (
                is_fa_version_supported)
            # is_fa_version_supported may raise AssertionError on
            # unsupported builds, hence it is caught below.
            return (is_fa_version_supported(3)
                    and current_platform.get_device_capability()[0] == 9)
        except (ImportError, AssertionError):
            pass
    return False


def is_flash_attn_varlen_func_available() -> bool:
    """Return True if the varlen flash-attention entry point is usable.

    Availability is keyed purely on platform: CUDA or XPU.
    """
    return current_platform.is_cuda() or current_platform.is_xpu()
3 changes: 3 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1501,6 +1501,9 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
"TRITON_MLA",
"CUTLASS_MLA",
"FLASHMLA",
"FLASHMLA_VLLM_V1",
"FLASH_ATTN_MLA",
"FLASH_ATTN_MLA_VLLM_V1",
"FLASHINFER",
"FLASHINFER_VLLM_V1",
"ROCM_AITER_MLA",
Expand Down
1 change: 1 addition & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ def get_vllm_port() -> Optional[int]:
# - "ROCM_FLASH": use ROCmFlashAttention
# - "FLASHINFER": use flashinfer
# - "FLASHMLA": use FlashMLA
# - "FLASH_ATTN_MLA": use FlashAttention for MLA
"VLLM_ATTENTION_BACKEND":
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),

Expand Down
Loading