
Commit 24d0c9e

[NVIDIA][torch.compile] Support Flashinfer TRTLLM FP8-q/kv NVFP4-out Attention Kernel (#22703)

Signed-off-by: elvischenv <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>

1 parent cc7ae5e · commit 24d0c9e

27 files changed: +598 / -202 lines

benchmarks/kernels/benchmark_trtllm_decode_attention.py
Lines changed: 37 additions & 13 deletions

@@ -9,8 +9,11 @@
 import flashinfer
 import torch
 
+from vllm.utils import round_up
+
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 FP8_DTYPE = torch.float8_e4m3fn
+FP4_DTYPE = torch.uint8
 
 
 def to_float8(x, dtype=torch.float8_e4m3fn):
@@ -61,28 +64,27 @@ def benchmark_decode(
     else:
         raise ValueError(f"Invalid kv_layout: {kv_layout}")
 
-    query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
+    # Always using 1.0 scale to reflect the real perf in benchmarking
+    q_scale = 1.0
+    ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
     if q_quant_dtype == FP8_DTYPE:
-        query, q_scale = to_float8(query)
-        ref_query = query.to(dtype) * q_scale
+        query, _ = to_float8(ref_query)
     else:
-        q_scale = 1.0
-        ref_query = query
+        query = ref_query
 
     kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
     kv_lens[-1] = max_seq_len
 
     seq_lens = kv_lens
     max_seq_len = torch.max(seq_lens).item()
 
-    kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
+    # Always using 1.0 scale to reflect the real perf in benchmarking
+    k_scale = v_scale = 1.0
+    ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
     if kv_quant_dtype == FP8_DTYPE:
-        kv_cache, kv_scale = to_float8(kv_cache)
-        ref_kv_cache = kv_cache.to(dtype) * kv_scale
+        kv_cache, _ = to_float8(ref_kv_cache)
     else:
-        kv_scale = 1.0
-        ref_kv_cache = kv_cache
-    k_scale = v_scale = kv_scale
+        kv_cache = ref_kv_cache
 
     max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
     block_tables = torch.randint(
@@ -142,11 +144,31 @@ def time_fn(fn, warmup=10, trials=20):
         return sum(times) / len(times), torch.std(torch.tensor(times))
 
     o_scale = 1.0
+    o_sf_scale = None
     output_baseline = torch.empty(ref_query.shape, dtype=dtype)
-    output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
+    if o_quant_dtype == FP4_DTYPE:
+        o_sf_scale = 500.0
+        output_trtllm = flashinfer.utils.FP4Tensor(
+            torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
+            torch.empty(
+                (
+                    round_up(query.shape[0], 128),
+                    round_up(query.shape[1] * query.shape[2] // 16, 4),
+                ),
+                dtype=torch.float8_e4m3fn,
+            ),
+        )
+    else:
+        output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
 
     def baseline_decode():
-        return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
+        return wrapper.run(
+            ref_query,
+            ref_kv_cache,
+            k_scale=k_scale,
+            v_scale=v_scale,
+            out=output_baseline,
+        )
 
     def trtllm_decode():
         return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
@@ -158,6 +180,7 @@ def trtllm_decode():
             max_seq_len=max_seq_len,
             bmm1_scale=q_scale * k_scale * sm_scale,
             bmm2_scale=v_scale / o_scale,
+            o_sf_scale=o_sf_scale,
             out=output_trtllm,
         )
 
@@ -237,6 +260,7 @@ def write_results_to_csv(results, filename=None):
         (None, None, None),
         (None, FP8_DTYPE, None),
         (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
+        (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
     ]
 
     for quant_dtype in quant_dtypes:
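
For the new NVFP4 output path, the benchmark preallocates a flashinfer.utils.FP4Tensor instead of a plain tensor: the data buffer packs two 4-bit values per uint8 byte (hence head_size // 2 on the last dimension), and the block scale factors are one FP8 value per group of 16 output elements, padded to multiples of 128 tokens and 4 scale columns. A minimal, self-contained sketch of that shape arithmetic (the example sizes are made up; only the packing and padding rules come from the diff above):

```python
# Shape arithmetic behind the FP4Tensor buffers above.
# Local stand-in for vllm.utils.round_up so the sketch runs standalone.
def round_up(x: int, multiple: int) -> int:
    return (x + multiple - 1) // multiple * multiple

# Hypothetical benchmark sizes (not from the diff).
num_tokens, num_qo_heads, head_size = 256, 64, 128

# Packed NVFP4 data: two 4-bit values share one uint8 byte along the last dim.
data_shape = (num_tokens, num_qo_heads, head_size // 2)

# FP8 block scale factors: one per 16 output elements, padded for the
# kernel's expected (128, 4)-aligned layout.
scale_shape = (
    round_up(num_tokens, 128),
    round_up(num_qo_heads * head_size // 16, 4),
)

print(data_shape)   # (256, 64, 64)
print(scale_shape)  # (256, 512)
```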

benchmarks/kernels/benchmark_trtllm_prefill_attention.py
Lines changed: 39 additions & 13 deletions

@@ -9,8 +9,11 @@
 import flashinfer
 import torch
 
+from vllm.utils import round_up
+
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 FP8_DTYPE = torch.float8_e4m3fn
+FP4_DTYPE = torch.uint8
 
 
 def to_float8(x, dtype=torch.float8_e4m3fn):
@@ -72,28 +75,29 @@ def benchmark_prefill(
         ]
     )
 
-    query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
+    # Always using 1.0 scale to reflect the real perf in benchmarking
+    q_scale = 1.0
+    ref_query = torch.randn(
+        torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype
+    )
     if q_quant_dtype == FP8_DTYPE:
-        query, q_scale = to_float8(query)
-        ref_query = query.to(dtype) * q_scale
+        query, _ = to_float8(ref_query)
     else:
-        q_scale = 1.0
-        ref_query = query
+        query = ref_query
 
     kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
     kv_lens[-1] = max_kv_len
 
     seq_lens = kv_lens + q_lens
     max_seq_len = torch.max(seq_lens).item()
 
-    kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
+    # Always using 1.0 scale to reflect the real perf in benchmarking
+    k_scale = v_scale = 1.0
+    ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
     if kv_quant_dtype == FP8_DTYPE:
-        kv_cache, kv_scale = to_float8(kv_cache)
-        ref_kv_cache = kv_cache.to(dtype) * kv_scale
+        kv_cache, _ = to_float8(ref_kv_cache)
     else:
-        kv_scale = 1.0
-        ref_kv_cache = kv_cache
-    k_scale = v_scale = kv_scale
+        kv_cache = ref_kv_cache
 
     max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
     block_tables = torch.randint(
@@ -152,11 +156,31 @@ def time_fn(fn, warmup=10, trials=20):
         return sum(times) / len(times), torch.std(torch.tensor(times))
 
     o_scale = 1.0
+    o_sf_scale = None
     output_baseline = torch.empty(ref_query.shape, dtype=dtype)
-    output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
+    if o_quant_dtype == FP4_DTYPE:
+        o_sf_scale = 500.0
+        output_trtllm = flashinfer.utils.FP4Tensor(
+            torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
+            torch.empty(
+                (
+                    round_up(query.shape[0], 128),
+                    round_up(query.shape[1] * query.shape[2] // 16, 4),
+                ),
+                dtype=torch.float8_e4m3fn,
+            ),
+        )
+    else:
+        output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
 
     def baseline_prefill():
-        return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
+        return wrapper.run(
+            ref_query,
+            ref_kv_cache,
+            k_scale=k_scale,
+            v_scale=v_scale,
+            out=output_baseline,
+        )
 
     def trtllm_prefill():
         return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
@@ -172,6 +196,7 @@ def trtllm_prefill():
             batch_size=batch_size,
             cum_seq_lens_q=q_indptr,
             cum_seq_lens_kv=kv_indptr,
+            o_sf_scale=o_sf_scale,
            out=output_trtllm,
         )
 
@@ -250,6 +275,7 @@ def write_results_to_csv(results, filename=None):
         # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
         (None, None, None),
         (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
+        (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
     ]
 
     for quant_dtype in quant_dtypes:
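
Both benchmarks still quantize the reference query/KV tensors with the file's to_float8 helper, but they now discard the returned scale and pin q_scale and k_scale/v_scale to 1.0, so the timed kernels exercise realistic scale handling without skewing the comparison. Only the helper's signature appears in these hunks; a typical per-tensor FP8 quantization helper consistent with its usage (tensor in, (quantized tensor, scale) out) looks roughly like the sketch below, though the file's exact implementation may differ:

```python
import torch


def to_float8(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    """Per-tensor symmetric FP8 quantization (sketch, not the file's exact code).

    Returns the quantized tensor and a dequantization scale such that
    x is approximately x_fp8.to(x.dtype) * scale.
    """
    finfo = torch.finfo(dtype)
    amax = x.abs().max().clamp(min=1e-12)
    scale = amax / finfo.max                          # dequant scale
    x_fp8 = (x / scale).clamp(finfo.min, finfo.max).to(dtype)
    return x_fp8, scale.float()


# Usage mirrors the benchmark: quantize, then benchmark with the scale forced to 1.0.
ref_query = torch.randn(4, 8, 128, dtype=torch.bfloat16)
query, _ = to_float8(ref_query)
```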

tests/compile/test_functionalization.py
Lines changed: 3 additions & 2 deletions

@@ -8,11 +8,12 @@
 from vllm import LLM, SamplingParams
 from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
-                                     kFp8DynamicTokenSym, kFp8StaticTensorSym)
+from vllm.compilation.fusion import FUSED_OPS, FusionPass
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.config import CompilationConfig, PassConfig, VllmConfig
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
 
 from .backend import TestBackend
 
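The only change to this test is where the quantization keys come from: QuantKey and the predefined kFp8StaticTensorSym / kFp8DynamicTokenSym keys are now imported from quant_utils instead of vllm.compilation.fusion. A minimal sketch of a test using the relocated symbols (the parametrization is illustrative, not taken from this file):

```python
import pytest

# New home of the quantization keys after this commit.
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)


@pytest.mark.parametrize("quant_key", [kFp8StaticTensorSym, kFp8DynamicTokenSym])
def test_relocated_quant_keys(quant_key):
    # Assumption: both predefined keys are QuantKey instances, as their names suggest.
    assert isinstance(quant_key, QuantKey)
```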

tests/compile/test_fusion.py
Lines changed: 5 additions & 5 deletions

@@ -7,11 +7,13 @@
 import vllm.envs as envs
 import vllm.plugins
 from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
-                                     FusionPass, GroupShape, QuantKey)
+                                     FusionPass)
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
                          VllmConfig)
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape, QuantKey, ScaleDesc)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
 from vllm.platforms import current_platform
@@ -30,10 +32,8 @@ def __init__(self, hidden_size: int, eps: float, static: bool,
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
-        self.key = QuantKey(dtype=FP8_DTYPE,
-                            static=static,
-                            group_shape=group_shape,
-                            symmetric=True)
+        quant_scale = ScaleDesc(torch.float32, static, group_shape)
+        self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
         if static:
             self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         else: