 import torch._dynamo
 
 from tests.compile.backend import LazyInitPass, TestBackend
-from tests.models.utils import check_outputs_equal
 from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
-from vllm import LLM, SamplingParams
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.attention import Attention, AttentionMetadata
 from vllm.attention.backends.registry import _Backend
@@ -31,7 +29,6 @@
 )
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    QuantKey,
     kFp8StaticTensorSym,
     kNvfp4Quant,
 )
@@ -48,132 +45,6 @@
 backend_unfused: Optional[TestBackend] = None
 
 
-@pytest.mark.parametrize(
-    "model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]
-)
-@pytest.mark.parametrize("use_triton_fa", [True, False])
-@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
-@pytest.mark.skipif(
-    not current_platform.is_rocm(), reason="V0 attn quant fusion only on ROCm"
-)
-def test_attention_fusion_v0(
-    example_prompts, monkeypatch, model: str, quant_key: QuantKey, use_triton_fa: bool
-):
-    # Clean Dynamo cache to avoid reusing other test cases
-    # (for some reason the reset at the end is not enough)
-    torch._dynamo.reset()
-
-    # Use global backends
-    global backend, backend_unfused
-
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))
-
-    # Prompt 4 seems too open-ended, differs between fused and unfused
-    # (both outputs look reasonable though)
-    prompts = example_prompts[:4] + example_prompts[5:]
-
-    compile_config = CompilationConfig(
-        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
-        # DYNAMO_ONCE does not properly propagate shapes.
-        level=CompilationLevel.DYNAMO_AS_IS,
-        backend="tests.compile.test_fusion_attn.backend_unfused",
-        custom_ops=["+quant_fp8"],
-    )
-    vllm_config = VllmConfig(
-        compilation_config=compile_config,
-        model_config=ModelConfig(
-            model=model,
-            dtype=torch.bfloat16,
-        ),
-    )
-    backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
-
-    llm = LLM(
-        model,
-        enforce_eager=True,
-        compilation_config=compile_config,
-        gpu_memory_utilization=0.5,
-        max_model_len=2048,
-    )
-
-    sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_p=0.95)
-
-    unfused_output = llm.generate(prompts, sampling_params)
-    backend_unfused = None  # Reset backend to make sure llm gets released
-    del llm
-
-    compile_config = CompilationConfig(
-        # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
-        # DYNAMO_ONCE does not properly propagate shapes.
-        level=CompilationLevel.DYNAMO_AS_IS,
-        backend="tests.compile.test_fusion_attn.backend",
-        custom_ops=["+quant_fp8"],
-    )
-    vllm_config = VllmConfig(
-        compilation_config=compile_config,
-        model_config=ModelConfig(
-            model=model,
-            dtype=torch.bfloat16,
-        ),
-    )
-
-    # AttnFusionPass needs attention layers to be registered in config upon init
-    # so we initialize it during compilation.
-    attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
-    backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
-    llm2 = LLM(
-        model,
-        enforce_eager=True,
-        compilation_config=compile_config,
-        gpu_memory_utilization=0.5,
-        max_model_len=2048,
-    )
-
-    # check support
-    attn_fusion_supported = [
-        layer.impl.fused_output_quant_supported(quant_key)
-        for key, layer in compile_config.static_forward_context.items()
-    ]
-
-    print(f"{attn_fusion_supported=}")
-    if any(attn_fusion_supported):
-        # Check quant ops
-        backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
-
-    # attention ops present in both, just output_scale param changes
-    attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
-    attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
-    assert len(attn_nodes_pre) == len(attn_nodes_post)
-
-    for i in range(len(attn_nodes_pre)):
-        assert attn_nodes_pre[i].kwargs["output_scale"] is None
-        fused = attn_nodes_post[i].kwargs["output_scale"] is not None
-        assert fused == attn_fusion_supported[i], (
-            f"Node {i} {'' if fused else 'not '}expected to have fused output quant"
-        )
-
-    # check outputs
-    fused_output = llm2.generate(prompts, sampling_params)
-
-    # transform outputs to format expected by check_outputs_equal
-    sample_outs = lambda s: (list(s.token_ids), s.text)
-    outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]
-
-    check_outputs_equal(
-        outputs_0_lst=outs_lst(unfused_output),
-        outputs_1_lst=outs_lst(fused_output),
-        name_0="unfused",
-        name_1="fused",
-    )
-
-    # Clean Dynamo cache to avoid polluting other case(s)
-    torch._dynamo.reset()
-
-    # Reset backend to make sure llm2 gets released
-    backend = None
-
-
 class AttentionQuantPatternModel(torch.nn.Module):
     """Base model for AttentionQuantPattern fusion."""
 
@@ -221,7 +92,7 @@ def __init__(
             device=self.device,
         )
 
-    def build_attn_metadata(self, batch_size: int, use_hnd: bool) -> AttentionMetadata:
+    def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
         """Initialize attention metadata."""
 
         # Create common attn metadata
@@ -232,30 +103,57 @@ def build_attn_metadata(self, batch_size: int, use_hnd: bool) -> AttentionMetada
 
         max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
         num_blocks = batch_size * max_blocks
+        backend = self.attn.backend
 
-        # Create dummy KV cache for FlashInfer TRTLLM
-        # - NHD: [num_blocks, block_size, num_kv_heads, head_size]
-        # - HND: [num_blocks, num_kv_heads, block_size, head_size]
-        kv_cache = torch.zeros(
-            num_blocks,
-            2,
-            self.num_kv_heads,
-            self.block_size,
-            self.head_size,
-            dtype=self.kv_cache_dtype,
-            device=self.device,
-        )
-        if current_platform.is_rocm():
+        # Create dummy KV cache for the selected backend
+        if backend == _Backend.ROCM_ATTN:
             # k/v as 1st dimension
-            if use_hnd:
-                kv_cache = kv_cache.permute(1, 0, 2, 3, 4)
-            else:
-                kv_cache = kv_cache.permute(1, 0, 3, 2, 4)
-        else:
+            # HND: [num_blocks, num_kv_heads, block_size, head_size]
+            kv_cache = torch.zeros(
+                2,
+                num_blocks,
+                self.num_kv_heads,
+                self.block_size,
+                self.head_size,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+        elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
+            # k/v as 1st dimension
+            # NHD: [num_blocks, block_size, num_kv_heads, head_size]
+            kv_cache = torch.zeros(
+                2,
+                num_blocks,
+                self.block_size,
+                self.num_kv_heads,
+                self.head_size,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+        elif backend == _Backend.TRITON_ATTN:
             # k/v as 2nd dimension
-            # Create kv_cache in HND layout and permute to NHD layout
-            # (later will be permuted back to HND layout in forward pass)
-            kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
+            # NHD: [num_blocks, block_size, num_kv_heads, head_size]
+            kv_cache = torch.zeros(
+                num_blocks,
+                2,
+                self.num_kv_heads,
+                self.block_size,
+                self.head_size,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+        elif backend == _Backend.FLASHINFER:
+            kv_cache = torch.zeros(
+                num_blocks,
+                2,
+                self.num_kv_heads,
+                self.block_size,
+                self.head_size,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            ).permute(0, 1, 3, 2, 4)
+        else:
+            raise ValueError(f"Unsupported backend: {backend}")
         self.attn.kv_cache = [kv_cache]
 
         # Build attn metadata
@@ -375,10 +273,9 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
 @pytest.mark.parametrize("model_name, model_class", MODELS)
 @pytest.mark.parametrize(
     "backend",
-    [_Backend.FLASHINFER] if current_platform.is_cuda() else [_Backend.TRITON_ATTN],
-)
-@pytest.mark.parametrize(
-    "split_attention", [False, True] if current_platform.is_rocm() else [False]
+    [_Backend.FLASHINFER]
+    if current_platform.is_cuda()
+    else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN],
 )
 # TODO(boyuan): test inductor graph partition on rocm
 @pytest.mark.parametrize(
@@ -405,7 +302,6 @@ def test_attention_quant_pattern(
     model_name: str,
     model_class: type[AttentionQuantPatternModel],
     backend: _Backend,
-    split_attention: bool,
     use_inductor_graph_partition: bool,
     monkeypatch,
     dist_init,
@@ -417,8 +313,6 @@ def test_attention_quant_pattern(
         pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
 
     monkeypatch.setenv("VLLM_USE_V1", "1")
-    if split_attention:
-        monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")
 
     device = torch.device("cuda:0")
     torch.manual_seed(42)
@@ -466,9 +360,7 @@ def test_attention_quant_pattern(
     model_unfused = model_unfused.to(device)
 
     forward_ctx = get_forward_context()
-    forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
-        batch_size, use_hnd=split_attention
-    )
+    forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)
 
     # Run model directly without compilation and fusion
     result_unfused = model_unfused(q, k, v)
@@ -494,9 +386,7 @@ def test_attention_quant_pattern(
     model_fused = model_fused.to(device)
 
     forward_ctx = get_forward_context()
-    forward_ctx.attn_metadata = model_fused.build_attn_metadata(
-        batch_size, use_hnd=split_attention
-    )
+    forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)
 
     # Create test backend with fusion passes enabled
     noop_pass = NoOpEliminationPass(vllm_config)
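For reference, here is a minimal standalone sketch of the per-backend KV-cache layout selection that build_attn_metadata() now performs in the hunks above. It uses made-up sizes and plain strings in place of vllm's _Backend enum, and the make_kv_cache helper is illustrative only (not part of the PR), so it runs without vllm installed.

import torch

num_blocks, num_kv_heads, block_size, head_size = 4, 8, 16, 64


def make_kv_cache(backend: str) -> torch.Tensor:
    # Mirrors the branch structure added in build_attn_metadata() above.
    if backend == "ROCM_ATTN":
        # k/v as 1st dimension, heads before block_size
        return torch.zeros(2, num_blocks, num_kv_heads, block_size, head_size)
    if backend == "ROCM_AITER_UNIFIED_ATTN":
        # k/v as 1st dimension, block_size before heads
        return torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)
    if backend == "TRITON_ATTN":
        # k/v as 2nd dimension
        return torch.zeros(num_blocks, 2, num_kv_heads, block_size, head_size)
    if backend == "FLASHINFER":
        # Same allocation as TRITON_ATTN, but exposed through a permuted
        # (non-contiguous) view that swaps the heads and block_size dims
        return torch.zeros(
            num_blocks, 2, num_kv_heads, block_size, head_size
        ).permute(0, 1, 3, 2, 4)
    raise ValueError(f"Unsupported backend: {backend}")


for name in ["ROCM_ATTN", "ROCM_AITER_UNIFIED_ATTN", "TRITON_ATTN", "FLASHINFER"]:
    cache = make_kv_cache(name)
    print(name, tuple(cache.shape), cache.is_contiguous())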