     AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 
-# These are the 2 tunable parameters of the paged attention Pallas kernel.
-NUM_QUERIES_PER_BLOCK = 32
-NUM_KV_PAGES_PER_BLOCK = 128
-
 
 class PallasAttentionBackend(AttentionBackend):
 
@@ -115,13 +111,6 @@ def __init__(
         tpu_version = torch_xla.tpu.version()
         if tpu_version < 4:
             raise NotImplementedError("TPU version must be 4 or higher.")
-        # NOTE(chengjiyao): the TPU v4's vmem capacity is 16MB
-        # TODO(chengjiyao): autotune NUM_QUERIES_PER_BLOCK,
-        # NUM_KV_PAGES_PER_BLOCK and vmem_limit_bytes
-        if tpu_version == 4:
-            self.vmem_limit_bytes = 16 * 1024 * 1024
-        else:
-            self.vmem_limit_bytes = 64 * 1024 * 1024
 
     def forward(
         self,
@@ -165,9 +154,12 @@ def forward(
             attn_metadata.block_tables,
             attn_metadata.query_start_loc,
             attn_metadata.num_seqs,
-            num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
-            num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-            vmem_limit_bytes=self.vmem_limit_bytes,
+            # By default, the system utilizes optimized block size and
+            # vmem_limit_bytes parameters from the kernel repository. However,
+            # these can be manually adjusted for debugging if necessary.
+            num_kv_pages_per_block=None,
+            num_queries_per_block=None,
+            vmem_limit_bytes=None,
             use_kernel=True,
             sm_scale=self.scale,
             sliding_window=self.sliding_window,
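
The new call passes None for the three tuning knobs so the ragged paged attention kernel falls back to the tuned block sizes and vmem_limit_bytes shipped with the kernel repository, while still allowing explicit values for debugging. Below is a minimal sketch of that "None means use the tuned default" pattern; it is not vLLM or kernel code, the names resolve_block_params and _TUNED_DEFAULTS are hypothetical, and the default numbers simply reuse the constants and vmem limits removed by this diff.

# Sketch only: illustrates the "None -> use tuned defaults" convention this
# diff adopts. resolve_block_params and _TUNED_DEFAULTS are hypothetical
# names; the numbers reuse the constants removed by this change.
from typing import Optional, Tuple

_TUNED_DEFAULTS = {
    # TPU v4: ~16MB of vmem per the removed NOTE; later generations get 64MB.
    4: dict(num_kv_pages_per_block=128, num_queries_per_block=32,
            vmem_limit_bytes=16 * 1024 * 1024),
    5: dict(num_kv_pages_per_block=128, num_queries_per_block=32,
            vmem_limit_bytes=64 * 1024 * 1024),
}


def resolve_block_params(
    tpu_version: int,
    num_kv_pages_per_block: Optional[int] = None,
    num_queries_per_block: Optional[int] = None,
    vmem_limit_bytes: Optional[int] = None,
) -> Tuple[int, int, int]:
    """Return explicit overrides when given, otherwise the tuned defaults."""
    defaults = _TUNED_DEFAULTS.get(tpu_version, _TUNED_DEFAULTS[5])
    return (
        num_kv_pages_per_block if num_kv_pages_per_block is not None
        else defaults["num_kv_pages_per_block"],
        num_queries_per_block if num_queries_per_block is not None
        else defaults["num_queries_per_block"],
        vmem_limit_bytes if vmem_limit_bytes is not None
        else defaults["vmem_limit_bytes"],
    )


if __name__ == "__main__":
    # Default path: all None, so the tuned values are used.
    print(resolve_block_params(4))            # (128, 32, 16777216)
    # Debugging path: pin a smaller query block, keep the other defaults.
    print(resolve_block_params(4, num_queries_per_block=16))

Keeping the tuned values in one place (the kernel repository) and passing None from the backend, as this diff does, avoids duplicating numbers that can drift out of date, while an explicit argument still wins when a developer needs to pin a block size during debugging.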