Commit 83156c7

[NVIDIA] Support Flashinfer TRT-LLM Prefill Attention Kernel (#22095)
Signed-off-by: elvischenv <[email protected]>
1 parent 4771df7 commit 83156c7

File tree

9 files changed (+701, -235 lines)


.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 1 deletion
@@ -664,7 +664,7 @@ steps:
  # Attention
  # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
  - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
- - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py
+ - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
  - pytest -v -s tests/kernels/test_cutlass_mla_decode.py
  # Quantization
  - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'

benchmarks/kernels/benchmark_trtllm_attention.py renamed to benchmarks/kernels/benchmark_trtllm_decode_attention.py

Lines changed: 0 additions & 1 deletion
@@ -41,7 +41,6 @@ def benchmark_decode(
     device = "cuda"
     torch.manual_seed(0)

-    # Currently only HEAD_GRP_SIZE == 8 is supported
     HEAD_GRP_SIZE = 8
     MAX_SEQ_LEN = max_seq_len
Lines changed: 250 additions & 0 deletions (new file)
@@ -0,0 +1,250 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import csv
import os
import random
from datetime import datetime

import flashinfer
import torch

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8

# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


@torch.no_grad()
def benchmark_prefill(
    num_seqs,
    max_seq_len,
    page_size=16,
    dtype=torch.bfloat16,
    kv_layout="HND",
    num_kv_heads=8,
    kv_cache_dtype="auto",
    head_dim=128,
    warmup=10,
    trials=20,
):
    torch.set_default_device("cuda")
    torch.manual_seed(0)

    HEAD_GRP_SIZE = 8
    MAX_SEQ_LEN = max_seq_len

    # large number to reduce kv_cache reuse
    NUM_BLOCKS = int(256000 / page_size)

    workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8)

    num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
    sm_scale = float(1.0 / (head_dim**0.5))

    q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    q_lens[-1] = MAX_SEQ_LEN
    max_q_len = max(q_lens)
    q_indptr = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32),
            torch.cumsum(
                torch.tensor(q_lens, dtype=torch.int32), dim=0, dtype=torch.int32
            ),
        ]
    )
    q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype)

    kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)]
    kv_lens[-1] = MAX_SEQ_LEN

    seq_lens = [q_len + kv_len for q_len, kv_len in zip(q_lens, kv_lens)]
    max_seq_len = max(seq_lens)
    seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size
    block_tables = torch.randint(
        0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
    kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype)
    k_scale = v_scale = 1.0

    if kv_cache_dtype.startswith("fp8"):
        kv_cache, _ = to_float8(kv_cache)

    output_trtllm = torch.empty(q.shape, dtype=dtype)

    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = seq_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + page_size - 1) // page_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % page_size
        if kv_last_page_len == 0:
            kv_last_page_len = page_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    output_baseline = torch.empty(q.shape, dtype=dtype)

    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout
    )
    wrapper.plan(
        q_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        causal=True,
        sm_scale=sm_scale,
        q_data_type=dtype,
        kv_data_type=kv_cache.dtype,
    )

    def time_fn(fn, warmup=10, trials=20):
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for i in range(warmup):
            fn()
        for i in range(trials):
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # ms
        return sum(times) / len(times), torch.std(torch.tensor(times))

    def baseline_prefill():
        return wrapper.run(
            q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline
        )

    def trt_prefill():
        return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
            query=q,
            kv_cache=kv_cache,
            workspace_buffer=workspace_buffer,
            block_tables=block_tables,
            seq_lens=seq_lens_tensor,
            max_q_len=max_q_len,
            max_kv_len=max_seq_len,
            bmm1_scale=k_scale * sm_scale,
            bmm2_scale=v_scale,
            batch_size=num_seqs,
            cum_seq_lens_q=q_indptr,
            cum_seq_lens_kv=kv_indptr,
            out=output_trtllm,
        )

    trt_mean, trt_std = time_fn(trt_prefill)
    baseline_mean, baseline_std = time_fn(baseline_prefill)

    # Calculate percentage speedup (positive means TRT is faster)
    speedup_percent = (baseline_mean - trt_mean) / baseline_mean

    print(
        f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}"
        f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}"
    )

    # Return results for CSV writing
    return {
        "num_seqs": num_seqs,
        "trt_mean": trt_mean,
        "trt_std": trt_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(dtype),
        "kv_cache_dtype": kv_cache_dtype,
        "page_size": page_size,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
        "max_seq_len": max_seq_len,
    }


def write_results_to_csv(results, filename=None):
    """Write benchmark results to CSV file."""
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

    fieldnames = [
        "num_seqs",
        "trt_mean",
        "trt_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "page_size",
        "num_kv_heads",
        "head_dim",
        "max_seq_len",
    ]

    file_exists = os.path.exists(filename)

    with open(filename, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        if not file_exists:
            writer.writeheader()

        for result in results:
            writer.writerow(result)

    print(f"Results written to {filename}")


if __name__ == "__main__":
    num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    print(
        "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
        "output_dtype: bfloat16"
    )
    print(
        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
        "baseline_std\tspeedup_percent"
    )
    for max_seq_len in max_seq_lens:
        for bs in num_seqs:
            result = benchmark_prefill(
                bs,
                max_seq_len,
                dtype=torch.bfloat16,
                kv_cache_dtype="auto",
            )
            all_results.append(result)

    # Write all results to CSV
    write_results_to_csv(all_results)
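Usage note: the __main__ sweep above only exercises the bfloat16 KV-cache configuration. A single configuration can also be run directly after importing the script; the snippet below is an illustrative sketch, not part of the commit, and the fp8 setting assumes the installed FlashInfer build supports an fp8 KV cache for both the baseline wrapper and the TRT-LLM kernel.

# Hypothetical one-off run; not part of this commit.
# Assumes the benchmark script has been imported (so benchmark_prefill and
# write_results_to_csv are in scope) and a CUDA device with FlashInfer is available.
result = benchmark_prefill(
    num_seqs=32,
    max_seq_len=8192,
    dtype=torch.bfloat16,
    kv_cache_dtype="fp8",  # routes the KV cache through to_float8() above
)
write_results_to_csv([result])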
