
Commit b4ad56c

[V1][TPU] Apply the ragged paged attention kernel fix and remove the padding. (#14846)
Signed-off-by: Xiongfei Wei <[email protected]>
Parent: 69698f2

File tree

2 files changed: +8, -11 lines

requirements/tpu.txt

Lines changed: 6 additions & 6 deletions
@@ -17,9 +17,9 @@ ray[data]
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250314%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250314%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250314%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250314%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250314%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250314%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
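
Each requirement line here pins a nightly wheel via a direct URL plus an environment marker, so pip installs only the wheel matching the running interpreter. After `pip install -r requirements/tpu.txt`, a quick sanity check can confirm the bump took effect; a minimal sketch, with the expected version prefixes taken from the URLs above:

# Hypothetical post-install check for the pinned torch / torch_xla nightlies.
import torch
import torch_xla

# Both wheels come from the 2.8.0.dev20250314 nightly, per the pins above.
assert torch.__version__.startswith("2.8.0.dev20250314"), torch.__version__
assert torch_xla.__version__.startswith("2.8.0"), torch_xla.__version__
print("torch:", torch.__version__)
print("torch_xla:", torch_xla.__version__)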

vllm/v1/worker/tpu_model_runner.py

Lines changed: 2 additions & 5 deletions
@@ -23,8 +23,7 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
-from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
-                                               PallasAttentionBackend,
+from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -139,10 +138,8 @@ def __init__(
             device="cpu")
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
 
-        padded_max_num_blocks_per_req = _get_padded_number(
-            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
         self.block_table_cpu = torch.zeros(
-            (self.max_num_tokens, padded_max_num_blocks_per_req),
+            (self.max_num_tokens, self.max_num_blocks_per_req),
             dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
             device="cpu")
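
For context, the removed `_get_padded_number` call rounded the per-request block count up to a multiple of NUM_KV_PAGES_PER_BLOCK; with the ragged paged attention kernel fix, that padding is unnecessary and the unpadded max_num_blocks_per_req is used directly. A minimal sketch of the round-up the helper presumably performed (the name and call site come from the removed lines; the body is an assumption):

# Hypothetical reconstruction of the removed padding helper.
def _get_padded_number(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple of `multiple`.
    return ((x + multiple - 1) // multiple) * multiple

# Before this commit the block table's second dimension was padded, e.g.
# _get_padded_number(17, 16) == 32; afterwards the raw count (17) is used.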
