
Commit fadc59c

[TPU][V1] Remove ragged attention kernel parameter hard coding (#16041)
Signed-off-by: Chengji Yao <[email protected]>
1 parent 86cbd2e commit fadc59c

2 files changed: +8 additions, -20 deletions

vllm/v1/attention/backends/pallas.py

Lines changed: 6 additions & 14 deletions
@@ -11,10 +11,6 @@
                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 
-# These are the 2 tunable parameters of the paged attention Pallas kernel.
-NUM_QUERIES_PER_BLOCK = 32
-NUM_KV_PAGES_PER_BLOCK = 128
-
 
 class PallasAttentionBackend(AttentionBackend):
 
@@ -115,13 +111,6 @@ def __init__(
         tpu_version = torch_xla.tpu.version()
         if tpu_version < 4:
             raise NotImplementedError("TPU version must be 4 or higher.")
-        # NOTE(chengjiyao): the TPU v4's vmem capacity is 16MB
-        # TODO(chengjiyao): autotune NUM_QUERIES_PER_BLOCK,
-        # NUM_KV_PAGES_PER_BLOCK and vmem_limit_bytes
-        if tpu_version == 4:
-            self.vmem_limit_bytes = 16 * 1024 * 1024
-        else:
-            self.vmem_limit_bytes = 64 * 1024 * 1024
 
     def forward(
         self,
@@ -165,9 +154,12 @@ def forward(
             attn_metadata.block_tables,
             attn_metadata.query_start_loc,
             attn_metadata.num_seqs,
-            num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
-            num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-            vmem_limit_bytes=self.vmem_limit_bytes,
+            # By default, the system utilizes optimized block size and
+            # vmem_limit_bytes parameters from the kernel repository. However,
+            # these can be manually adjusted for debugging if necessary.
+            num_kv_pages_per_block=None,
+            num_queries_per_block=None,
+            vmem_limit_bytes=None,
             use_kernel=True,
             sm_scale=self.scale,
             sliding_window=self.sliding_window,
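To make the new behavior concrete, here is a minimal, purely illustrative sketch of the fallback the call site now relies on: when num_kv_pages_per_block, num_queries_per_block, or vmem_limit_bytes is None, the kernel side is expected to substitute tuned values for the current TPU generation. The helper name resolve_kernel_params and its concrete defaults are hypothetical; they simply mirror the constants this commit deletes and are not part of vLLM or the kernel repository.

```python
# Illustrative sketch only: `resolve_kernel_params` and its defaults are
# hypothetical; the real tuned values live in the kernel repository.
from typing import Optional, Tuple


def resolve_kernel_params(
    num_kv_pages_per_block: Optional[int],
    num_queries_per_block: Optional[int],
    vmem_limit_bytes: Optional[int],
    tpu_version: int,
) -> Tuple[int, int, int]:
    """Replace None with per-TPU-generation tuned values, as a kernel might."""
    if num_kv_pages_per_block is None:
        num_kv_pages_per_block = 128      # former vLLM hard-coded value
    if num_queries_per_block is None:
        num_queries_per_block = 32        # former vLLM hard-coded value
    if vmem_limit_bytes is None:
        # TPU v4 has 16MB of vmem; later generations get a larger budget.
        vmem_limit_bytes = (16 if tpu_version == 4 else 64) * 1024 * 1024
    return num_kv_pages_per_block, num_queries_per_block, vmem_limit_bytes


# Passing None everywhere yields the tuned defaults, e.g. for TPU v4:
print(resolve_kernel_params(None, None, None, tpu_version=4))
# (128, 32, 16777216)
```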

vllm/v1/worker/tpu_model_runner.py

Lines changed: 2 additions & 6 deletions
@@ -24,8 +24,7 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
-from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
-                                               PallasAttentionBackend,
+from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -155,11 +154,8 @@ def __init__(
             dtype=torch.int64,
             device="cpu")
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
-
-        padded_max_num_blocks_per_req = _get_padded_number(
-            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
         self.block_table_cpu = torch.zeros(
-            (self.max_num_tokens, padded_max_num_blocks_per_req),
+            (self.max_num_tokens, self.max_num_blocks_per_req),
             dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
             device="cpu")
 
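For illustration, the sketch below shows what dropping the padding means for the block table's shape: its width now tracks max_num_blocks_per_req directly instead of being rounded up to a multiple of NUM_KV_PAGES_PER_BLOCK (128). All concrete sizes and the int32 dtype are made-up example values, not taken from vLLM.

```python
# Illustrative sketch only: hypothetical example sizes, not a vLLM config.
import torch

max_model_len = 2064
block_size = 16
max_num_tokens = 512

# Blocks needed to cover the longest possible request: ceil(2064 / 16) = 129.
max_num_blocks_per_req = -(-max_model_len // block_size)

# Old behaviour: pad the width up to a multiple of NUM_KV_PAGES_PER_BLOCK (128),
# so 129 blocks forced a 256-wide table.
old_width = -(-max_num_blocks_per_req // 128) * 128

# New behaviour: allocate exactly the number of blocks a request can need.
block_table_cpu = torch.zeros((max_num_tokens, max_num_blocks_per_req),
                              dtype=torch.int32, device="cpu")

print(old_width)               # 256
print(block_table_cpu.shape)   # torch.Size([512, 129])
```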
