Commit 69ec3ca

[Kernel][Model] logits_soft_cap for Gemma2 with flashinfer (#6051)
Co-authored-by: Simon Mo <[email protected]>
1 parent 81d7a50 commit 69ec3ca
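
For context, attention logit soft-capping (used by Gemma 2) squashes the raw attention scores through a tanh before the softmax, so no logit can exceed the cap. A minimal standalone sketch of the operation, mirroring the reference implementation in the test added by this commit (the tensor shapes and cap value are illustrative only):

import torch

soft_cap = 30.0             # example value; Gemma 2 reads attn_logit_softcapping from its config
logits = torch.randn(4, 8)  # hypothetical [query, key] attention scores

# Bound the scores to (-soft_cap, soft_cap) before masking and softmax,
# exactly as ref_paged_attn does below: soft_cap * tanh(logits / soft_cap).
capped = soft_cap * torch.tanh(logits / soft_cap)
probs = torch.softmax(capped, dim=-1)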

File tree

6 files changed: +279, -20 lines

.buildkite/test-pipeline.yaml

Lines changed: 5 additions & 2 deletions

@@ -118,12 +118,15 @@ steps:
 
 - label: Kernels Test %N
   #mirror_hardwares: [amd]
-  command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
+  commands:
+  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl
+  - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
   parallelism: 4
 
 - label: Models Test
   #mirror_hardwares: [amd]
   commands:
+  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl
   - pytest -v -s models -m \"not vlm\"
 
 - label: Vision Language Models Test

@@ -234,7 +237,7 @@ steps:
   - pytest -v -s distributed/test_custom_all_reduce.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
-  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl
+  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl
  - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
  - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
  - pytest -v -s -x lora/test_mixtral.py

tests/kernels/test_flashinfer.py

Lines changed: 248 additions & 0 deletions

@@ -0,0 +1,248 @@ (new file; all 248 lines added)

from typing import List, Optional, Tuple

import flashinfer
import pytest
import torch

NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.


def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: List[int],
    kv_lens: List[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    num_seqs = len(query_lens)
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs: List[torch.Tensor] = []
    start_idx = 0
    for i in range(num_seqs):
        query_len = query_lens[i]
        kv_len = kv_lens[i]
        q = query[start_idx:start_idx + query_len]
        q *= scale

        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()
        empty_mask = torch.ones(query_len, kv_len)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            sliding_window_mask = torch.triu(empty_mask,
                                             diagonal=kv_len -
                                             (query_len + sliding_window) +
                                             1).bool().logical_not()
            mask |= sliding_window_mask
        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)


@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode
def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
                                         num_heads: Tuple[int, int],
                                         head_size: int, dtype: torch.dtype,
                                         block_size: int,
                                         soft_cap: Optional[float]) -> None:
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    num_seqs = len(kv_lens)
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
    key_value_cache = torch.randn(NUM_BLOCKS,
                                  2,
                                  block_size,
                                  num_kv_heads,
                                  head_size,
                                  dtype=dtype)
    key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
    value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = kv_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
    wrapper = flashinfer.\
        BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
    wrapper.begin_forward(kv_indptr,
                          kv_indices,
                          kv_last_page_lens,
                          num_query_heads,
                          num_kv_heads,
                          head_size,
                          block_size,
                          "NONE",
                          data_type=dtype)

    output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap)

    ref_output = ref_paged_attn(query=query,
                                key_cache=key_cache,
                                value_cache=value_cache,
                                query_lens=[1] * num_seqs,
                                kv_lens=kv_lens,
                                block_tables=block_tables,
                                scale=scale,
                                soft_cap=soft_cap)
    assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - ref_output))}"


@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode
def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
                                          num_heads: Tuple[int, int],
                                          head_size: int, dtype: torch.dtype,
                                          block_size: int,
                                          soft_cap: Optional[float]) -> None:
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5

    query = torch.randn(sum(query_lens),
                        num_query_heads,
                        head_size,
                        dtype=dtype)
    key_value_cache = torch.randn(NUM_BLOCKS,
                                  2,
                                  block_size,
                                  num_kv_heads,
                                  head_size,
                                  dtype=dtype)
    key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
    value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)

    # Normalize the scale of the key and value caches to mitigate
    # numerical instability.
    key_cache /= head_size**0.5
    value_cache /= head_size**0.5

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    qo_indptr = [0]
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = kv_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)
        qo_indptr.append(qo_indptr[-1] + query_lens[i])

    qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32)
    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, "NHD")
    wrapper.begin_forward(
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_query_heads,
        num_kv_heads,
        head_size,
        block_size,
    )

    output = wrapper.forward(
        query,
        key_value_cache,
        logits_soft_cap=soft_cap,
    )

    ref_output = ref_paged_attn(query=query,
                                key_cache=key_cache,
                                value_cache=value_cache,
                                query_lens=query_lens,
                                kv_lens=kv_lens,
                                block_tables=block_tables,
                                scale=scale,
                                soft_cap=soft_cap)
    assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - ref_output))}"

vllm/attention/backends/flashinfer.py

Lines changed: 8 additions & 4 deletions

@@ -102,6 +102,8 @@ class FlashInferMetadata(AttentionMetadata):
     # The data type of the paged kv cache
     data_type: torch.dtype = None
     device: torch.device = torch.device("cuda")
+    # Only used by gemma2 model
+    logits_soft_cap: Optional[float] = None
 
     def __post_init__(self):
         # Refer to

@@ -271,15 +273,17 @@ def forward(
            else:
                assert prefill_meta is not None
                assert prefill_meta.prefill_wrapper is not None
-                output = prefill_meta.prefill_wrapper.forward(query,
-                                                              kv_cache,
-                                                              causal=True)
+                output = prefill_meta.prefill_wrapper.forward(
+                    query,
+                    kv_cache,
+                    logits_soft_cap=attn_metadata.logits_soft_cap,
+                    causal=True)
        else:
            assert attn_metadata.decode_metadata is not None
            assert attn_metadata.decode_metadata.decode_wrapper is not None
            output = attn_metadata.decode_metadata.decode_wrapper.forward(
                query,
                kv_cache,
                sm_scale=self.scale,
-            )
+                logits_soft_cap=attn_metadata.logits_soft_cap)
        return output.view(num_tokens, hidden_size)

vllm/attention/selector.py

Lines changed: 3 additions & 3 deletions

@@ -77,9 +77,9 @@ def get_attn_backend(
        return IpexAttnBackend
    elif backend == _Backend.FLASHINFER:
        logger.info("Using Flashinfer backend.")
-        logger.warning(("Flashinfer will be stuck on llma-2-7b,"
-                        " please avoid using Flashinfer as the"
-                        "backend when running on llma-2-7b."))
+        logger.warning(("Flashinfer will be stuck on llama-2-7b,"
+                        " please avoid using Flashinfer as the "
+                        "backend when running on llama-2-7b."))
        from vllm.attention.backends.flashinfer import FlashInferBackend
        return FlashInferBackend
    elif backend == _Backend.PALLAS:

vllm/model_executor/models/gemma2.py

Lines changed: 0 additions & 7 deletions

@@ -38,7 +38,6 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
-from vllm.utils import print_warning_once
 
 from .interfaces import SupportsLoRA
 

@@ -137,12 +136,6 @@ def __init__(self,
            dtype=torch.get_default_dtype(),
        )
 
-        if self.config.attn_logit_softcapping is not None:
-            print_warning_once(
-                "Gemma 2 normally uses attention logit soft-capping; "
-                "soft-capping is currently incompatible with the flash "
-                "attention kernels, so vLLM removes it to enable speed and "
-                "efficiency gains of flash attention.")
        # FIXME(woosuk): While Gemma 2 uses sliding window attention for every
        # odd layer, vLLM currently ignores it and uses global attention for
        # all layers.

vllm/worker/model_runner.py

Lines changed: 15 additions & 4 deletions

@@ -15,7 +15,7 @@
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
-    FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
    BatchDecodeWithPagedKVCacheWrapper = None
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None

@@ -683,6 +683,16 @@ def _prepare_model_input_tensors(
                                               dtype=torch.long,
                                               device=self.device)
 
+            logits_soft_cap = getattr(self.model_config.hf_config,
+                                      'attn_logit_softcapping', None)
+            if logits_soft_cap is not None and self.attn_backend.get_name(
+            ) != "flashinfer":
+                raise ValueError("Please use Flashinfer backend for models with"
+                                 "logits_soft_cap (i.e., Gemma-2)."
+                                 " Otherwise, the output might be wrong."
+                                 " Set Flashinfer backend by "
+                                 "export VLLM_ATTENTION_BACKEND=FLASHINFER.")
+
            if self.attn_backend.get_name() == "flashinfer":
                if len(paged_kv_indptr) > 0:
                    paged_kv_indices_tensor = torch.tensor(paged_kv_indices,

@@ -700,7 +710,6 @@ def _prepare_model_input_tensors(
 
                kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
                                                          self.model_config.dtype)
-
                attn_metadata = self.attn_backend.make_metadata(
                    num_prefills=num_prefills,
                    slot_mapping=slot_mapping_tensor,

@@ -721,7 +730,8 @@ def _prepare_model_input_tensors(
                    query_start_loc=query_start_loc,
                    device=self.device,
                    data_type=kv_cache_dtype,
-                    use_cuda_graph=use_captured_graph)
+                    use_cuda_graph=use_captured_graph,
+                    logits_soft_cap=logits_soft_cap)
 
            else:
                attn_metadata = self.attn_backend.make_metadata(

@@ -1196,7 +1206,8 @@ def execute_model(
            if model_input.attn_metadata.use_cuda_graph:
                batch_size = model_input.input_tokens.shape[0]
                model_input.attn_metadata.decode_wrapper = self.graph_runners[
-                    batch_size].flashinfer_decode_wrapper
+                    model_input.
+                    virtual_engine][batch_size].flashinfer_decode_wrapper
            else:
                model_input.attn_metadata.decode_wrapper = \
                    self.flashinfer_decode_wrapper
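
For anyone hitting the new ValueError above, a minimal usage sketch of the workflow it asks for (the checkpoint name is only an example, not something this commit pins down): select the FlashInfer backend via the environment variable before constructing the engine.

import os

# Must be set before the attention backend is selected.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(model="google/gemma-2-9b")  # example Gemma 2 checkpoint
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)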
