
Commit 66c6e09

gshtras authored and southfreebird committed
[ROCm] Split AITER unified attention into its own backend (vllm-project#25507)
Signed-off-by: Gregory Shtrasberg <[email protected]>
1 parent 26b5ec8 commit 66c6e09

8 files changed: +325 -301 lines changed


tests/compile/test_fusion_attn.py
Lines changed: 54 additions & 164 deletions

@@ -7,9 +7,7 @@
 import torch._dynamo
 
 from tests.compile.backend import LazyInitPass, TestBackend
-from tests.models.utils import check_outputs_equal
 from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
-from vllm import LLM, SamplingParams
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.attention import Attention, AttentionMetadata
 from vllm.attention.backends.registry import _Backend

@@ -31,7 +29,6 @@
 )
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    QuantKey,
     kFp8StaticTensorSym,
     kNvfp4Quant,
 )

@@ -48,132 +45,6 @@
 backend_unfused: Optional[TestBackend] = None
 
 
-@pytest.mark.parametrize(
-    "model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]
-)
-@pytest.mark.parametrize("use_triton_fa", [True, False])
-@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
-@pytest.mark.skipif(
-    not current_platform.is_rocm(), reason="V0 attn quant fusion only on ROCm"
-)
-def test_attention_fusion_v0(
-    example_prompts, monkeypatch, model: str, quant_key: QuantKey, use_triton_fa: bool
-):
-    # Clean Dynamo cache to avoid reusing other test cases
-    # (for some reason the reset at the end is not enough)
-    torch._dynamo.reset()
-
-    # Use global backends
-    global backend, backend_unfused
-
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))
-
-    # Prompt 4 seems too open-ended, differs between fused and unfused
-    # (both outputs look reasonable though)
-    prompts = example_prompts[:4] + example_prompts[5:]
-
-    compile_config = CompilationConfig(
-        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
-        # DYNAMO_ONCE does not properly propagate shapes.
-        level=CompilationLevel.DYNAMO_AS_IS,
-        backend="tests.compile.test_fusion_attn.backend_unfused",
-        custom_ops=["+quant_fp8"],
-    )
-    vllm_config = VllmConfig(
-        compilation_config=compile_config,
-        model_config=ModelConfig(
-            model=model,
-            dtype=torch.bfloat16,
-        ),
-    )
-    backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
-
-    llm = LLM(
-        model,
-        enforce_eager=True,
-        compilation_config=compile_config,
-        gpu_memory_utilization=0.5,
-        max_model_len=2048,
-    )
-
-    sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_p=0.95)
-
-    unfused_output = llm.generate(prompts, sampling_params)
-    backend_unfused = None  # Reset backend to make sure llm gets released
-    del llm
-
-    compile_config = CompilationConfig(
-        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
-        # DYNAMO_ONCE does not properly propagate shapes.
-        level=CompilationLevel.DYNAMO_AS_IS,
-        backend="tests.compile.test_fusion_attn.backend",
-        custom_ops=["+quant_fp8"],
-    )
-    vllm_config = VllmConfig(
-        compilation_config=compile_config,
-        model_config=ModelConfig(
-            model=model,
-            dtype=torch.bfloat16,
-        ),
-    )
-
-    # AttnFusionPass needs attention layers to be registered in config upon init
-    # so we initialize it during compilation.
-    attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
-    backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
-    llm2 = LLM(
-        model,
-        enforce_eager=True,
-        compilation_config=compile_config,
-        gpu_memory_utilization=0.5,
-        max_model_len=2048,
-    )
-
-    # check support
-    attn_fusion_supported = [
-        layer.impl.fused_output_quant_supported(quant_key)
-        for key, layer in compile_config.static_forward_context.items()
-    ]
-
-    print(f"{attn_fusion_supported=}")
-    if any(attn_fusion_supported):
-        # Check quant ops
-        backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
-
-    # attention ops present in both, just output_scale param changes
-    attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
-    attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
-    assert len(attn_nodes_pre) == len(attn_nodes_post)
-
-    for i in range(len(attn_nodes_pre)):
-        assert attn_nodes_pre[i].kwargs["output_scale"] is None
-        fused = attn_nodes_post[i].kwargs["output_scale"] is not None
-        assert fused == attn_fusion_supported[i], (
-            f"Node {i} {'' if fused else 'not '} expected to have fused output quant"
-        )
-
-    # check outputs
-    fused_output = llm2.generate(prompts, sampling_params)
-
-    # transform outputs to format expected by check_outputs_equal
-    sample_outs = lambda s: (list(s.token_ids), s.text)
-    outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]
-
-    check_outputs_equal(
-        outputs_0_lst=outs_lst(unfused_output),
-        outputs_1_lst=outs_lst(fused_output),
-        name_0="unfused",
-        name_1="fused",
-    )
-
-    # Clean Dynamo cache to avoid polluting other case(s)
-    torch._dynamo.reset()
-
-    # Reset backend to make sure llm2 gets released
-    backend = None
-
-
 class AttentionQuantPatternModel(torch.nn.Module):
     """Base model for AttentionQuantPattern fusion."""
 

@@ -221,7 +92,7 @@ def __init__(
             device=self.device,
         )
 
-    def build_attn_metadata(self, batch_size: int, use_hnd: bool) -> AttentionMetadata:
+    def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
        """Initialize attention metadata."""
 
         # Create common attn metadata

@@ -232,30 +103,57 @@ def build_attn_metadata(self, batch_size: int, use_hnd: bool) -> AttentionMetadata
 
         max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
         num_blocks = batch_size * max_blocks
+        backend = self.attn.backend
 
-        # Create dummy KV cache for FlashInfer TRTLLM
-        # - NHD: [num_blocks, block_size, num_kv_heads, head_size]
-        # - HND: [num_blocks, num_kv_heads, block_size, head_size]
-        kv_cache = torch.zeros(
-            num_blocks,
-            2,
-            self.num_kv_heads,
-            self.block_size,
-            self.head_size,
-            dtype=self.kv_cache_dtype,
-            device=self.device,
-        )
-        if current_platform.is_rocm():
+        # Create dummy KV cache for the selected backend
+        if backend == _Backend.ROCM_ATTN:
             # k/v as 1st dimention
-            if use_hnd:
-                kv_cache = kv_cache.permute(1, 0, 2, 3, 4)
-            else:
-                kv_cache = kv_cache.permute(1, 0, 3, 2, 4)
-        else:
+            # HND: [num_blocks, num_kv_heads, block_size, head_size]
+            kv_cache = torch.zeros(
+                2,
+                num_blocks,
+                self.num_kv_heads,
+                self.block_size,
+                self.head_size,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+        elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
+            # k/v as 1st dimention
+            # NHD: [num_blocks, block_size, num_kv_heads, head_size]
+            kv_cache = torch.zeros(
+                2,
+                num_blocks,
+                self.block_size,
+                self.num_kv_heads,
+                self.head_size,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+        elif backend == _Backend.TRITON_ATTN:
             # k/v as 2nd dimention
-            # Create kv_cache in HND layout and permute to NHD layout
-            # (later will be permuted back to HND layout in forward pass)
-            kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
+            # NHD: [num_blocks, block_size, num_kv_heads, head_size]
+            kv_cache = torch.zeros(
+                num_blocks,
+                2,
+                self.num_kv_heads,
+                self.block_size,
+                self.head_size,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+        elif backend == _Backend.FLASHINFER:
+            kv_cache = torch.zeros(
+                num_blocks,
+                2,
+                self.num_kv_heads,
+                self.block_size,
+                self.head_size,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            ).permute(0, 1, 3, 2, 4)
+        else:
+            raise ValueError(f"Unsupported backend: {backend}")
         self.attn.kv_cache = [kv_cache]
 
         # Build attn metadata

@@ -375,10 +273,9 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
 @pytest.mark.parametrize("model_name, model_class", MODELS)
 @pytest.mark.parametrize(
     "backend",
-    [_Backend.FLASHINFER] if current_platform.is_cuda() else [_Backend.TRITON_ATTN],
-)
-@pytest.mark.parametrize(
-    "split_attention", [False, True] if current_platform.is_rocm() else [False]
+    [_Backend.FLASHINFER]
+    if current_platform.is_cuda()
+    else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN],
 )
 # TODO(boyuan): test inductor graph partition on rocm
 @pytest.mark.parametrize(

@@ -405,7 +302,6 @@ def test_attention_quant_pattern(
     model_name: str,
     model_class: type[AttentionQuantPatternModel],
     backend: _Backend,
-    split_attention: bool,
     use_inductor_graph_partition: bool,
     monkeypatch,
     dist_init,

@@ -417,8 +313,6 @@ def test_attention_quant_pattern(
         pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
 
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    if split_attention:
-        monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")
 
     device = torch.device("cuda:0")
     torch.manual_seed(42)

@@ -466,9 +360,7 @@ def test_attention_quant_pattern(
     model_unfused = model_unfused.to(device)
 
     forward_ctx = get_forward_context()
-    forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
-        batch_size, use_hnd=split_attention
-    )
+    forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)
 
     # Run model directly without compilation and fusion
     result_unfused = model_unfused(q, k, v)

@@ -494,9 +386,7 @@ def test_attention_quant_pattern(
     model_fused = model_fused.to(device)
 
     forward_ctx = get_forward_context()
-    forward_ctx.attn_metadata = model_fused.build_attn_metadata(
-        batch_size, use_hnd=split_attention
-    )
+    forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)
 
     # Create test backend with fusion passes enabled
     noop_pass = NoOpEliminationPass(vllm_config)
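
For reference, the per-backend KV-cache layouts that the updated build_attn_metadata helper now allocates can be condensed into a standalone sketch. This only restates the shapes from the hunk above; the function name dummy_kv_cache and the plain string backend names are illustrative, not vllm code.

# Condensed sketch of the KV-cache shapes above; names are illustrative only.
import torch

def dummy_kv_cache(backend: str, num_blocks: int, block_size: int,
                   num_kv_heads: int, head_size: int,
                   dtype=torch.float16, device="cpu") -> torch.Tensor:
    if backend == "ROCM_ATTN":
        # k/v as 1st dimension, HND layout per block
        return torch.zeros(2, num_blocks, num_kv_heads, block_size, head_size,
                           dtype=dtype, device=device)
    if backend == "ROCM_AITER_UNIFIED_ATTN":
        # k/v as 1st dimension, NHD layout per block
        return torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size,
                           dtype=dtype, device=device)
    if backend == "TRITON_ATTN":
        # k/v as 2nd dimension
        return torch.zeros(num_blocks, 2, num_kv_heads, block_size, head_size,
                           dtype=dtype, device=device)
    if backend == "FLASHINFER":
        # allocated as above, then permuted so block_size precedes num_kv_heads
        return torch.zeros(num_blocks, 2, num_kv_heads, block_size, head_size,
                           dtype=dtype, device=device).permute(0, 1, 3, 2, 4)
    raise ValueError(f"Unsupported backend: {backend}")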

vllm/attention/backends/registry.py
Lines changed: 1 addition & 0 deletions

@@ -25,3 +25,4 @@ class _Backend(enum.Enum):
     FLEX_ATTENTION = enum.auto()
     TREE_ATTN = enum.auto()
     ROCM_ATTN = enum.auto()
+    ROCM_AITER_UNIFIED_ATTN = enum.auto()
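
A quick sanity check of the new member, as a minimal sketch; it assumes nothing beyond stdlib enum semantics and the import path already used in the test diff above.

from vllm.attention.backends.registry import _Backend

# enum.auto() just assigns the next value; the member is addressable by name,
# which is how a backend string such as "ROCM_AITER_UNIFIED_ATTN" maps to it.
backend = _Backend["ROCM_AITER_UNIFIED_ATTN"]
assert backend is _Backend.ROCM_AITER_UNIFIED_ATTN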

vllm/attention/selector.py
Lines changed: 1 addition & 0 deletions

@@ -254,3 +254,4 @@ def global_force_attn_backend_context_manager(
     finally:
         # Revert the original global backend override, if any
         global_force_attn_backend(original_value)
+        _cached_get_attn_backend.cache_clear()
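
The added cache_clear() matters because the backend selection is memoized; once the forced override is reverted, a stale cache entry would keep returning the backend chosen under the override. A generic sketch of that pattern under those assumptions (the names below are illustrative, not vllm's):

import functools

_forced_backend: list = [None]  # stand-in for a global override

@functools.lru_cache(maxsize=None)
def _cached_select(head_size: int) -> str:
    # Memoized selection: without invalidation, the first answer sticks.
    return _forced_backend[0] or "TRITON_ATTN"

def with_forced_backend(name: str, head_size: int) -> str:
    _forced_backend[0] = name
    try:
        return _cached_select(head_size)
    finally:
        _forced_backend[0] = None
        _cached_select.cache_clear()  # mirrors _cached_get_attn_backend.cache_clear()

assert with_forced_backend("ROCM_ATTN", 128) == "ROCM_ATTN"
assert _cached_select(128) == "TRITON_ATTN"  # not stale after cache_clear()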

vllm/engine/arg_utils.py
Lines changed: 1 addition & 0 deletions

@@ -1623,6 +1623,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             "TREE_ATTN",
             "XFORMERS",
             "ROCM_ATTN",
+            "ROCM_AITER_UNIFIED_ATTN",
         ]
         if (
             envs.is_set("VLLM_ATTENTION_BACKEND")
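
With the name added to this V1 allow-list, the backend should be requestable by name like the other entries. A hedged usage sketch (the model name is only the one that appears in the test above, not a recommendation):

import os

# Request the new backend by name; _is_v1_supported_oracle now accepts it.
os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_AITER_UNIFIED_ATTN"

# from vllm import LLM
# llm = LLM("amd/Llama-3.1-8B-Instruct-FP8-KV")  # requires a ROCm device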

vllm/envs.py
Lines changed: 7 additions & 6 deletions

@@ -18,7 +18,6 @@
     LD_LIBRARY_PATH: Optional[str] = None
     VLLM_USE_TRITON_FLASH_ATTN: bool = True
     VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
-    VLLM_USE_AITER_UNIFIED_ATTENTION: bool = False
     VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None

@@ -109,6 +108,7 @@
     VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
     VLLM_ROCM_USE_TRITON_ROPE: bool = False
     VLLM_ROCM_USE_AITER_FP8BMM: bool = True
+    VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
     VLLM_ROCM_USE_SKINNY_GEMM: bool = True
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True

@@ -475,10 +475,6 @@ def get_vllm_port() -> Optional[int]:
         os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower()
         in ("true", "1")
     ),
-    # Use AITER triton unified attention for V1 attention
-    "VLLM_USE_AITER_UNIFIED_ATTENTION": lambda: (
-        os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in ("true", "1")
-    ),
     # Force vllm to use a specific flash-attention version (2 or 3), only valid
     # when using the flash-attention backend.
     "VLLM_FLASH_ATTN_VERSION": lambda: maybe_convert_int(

@@ -896,6 +892,11 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_ROCM_USE_AITER_FP8BMM": lambda: (
         os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1")
     ),
+    # Use AITER triton unified attention for V1 attention
+    "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: (
+        os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower()
+        in ("true", "1")
+    ),
     # use rocm skinny gemms
     "VLLM_ROCM_USE_SKINNY_GEMM": lambda: (
         os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1")

@@ -1434,7 +1435,6 @@ def compute_hash() -> str:
         "VLLM_FUSED_MOE_CHUNK_SIZE",
         "VLLM_FLASHINFER_MOE_BACKEND",
         "VLLM_V1_USE_PREFILL_DECODE_ATTENTION",
-        "VLLM_USE_AITER_UNIFIED_ATTENTION",
         "VLLM_ATTENTION_BACKEND",
         "VLLM_USE_FLASHINFER_SAMPLER",
         "VLLM_DISABLED_KERNELS",

@@ -1462,6 +1462,7 @@ def compute_hash() -> str:
         "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
         "VLLM_ROCM_USE_TRITON_ROPE",
         "VLLM_ROCM_USE_AITER_FP8BMM",
+        "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
         "VLLM_ROCM_USE_SKINNY_GEMM",
         "VLLM_ROCM_FP8_PADDING",
         "VLLM_ROCM_MOE_PADDING",
