Commit b1eb4ca

[TPU] Update PyTorch/XLA (#16288)
Signed-off-by: Chengji Yao <[email protected]>
1 parent 87b4ac5 commit b1eb4ca

2 files changed: +10 −14 lines

requirements/tpu.txt

Lines changed: 6 additions & 6 deletions
@@ -17,10 +17,10 @@ ray[data]
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250406-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250406-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250406-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250406-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250406-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250406-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
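Each of these pins is a PEP 508 direct-URL requirement guarded by an environment marker, so pip resolves exactly one torch and one torch_xla wheel for the running interpreter. A minimal sketch of how such markers evaluate, using the packaging library; the loop below is illustrative and not part of the commit:

from packaging.markers import Marker

# Each pinned wheel line ends in a marker like `python_version == "3.10"`;
# pip evaluates it against the running interpreter and keeps only the
# matching direct-URL requirement.
for version in ("3.9", "3.10", "3.11"):
    marker = Marker(f'python_version == "{version}"')
    # evaluate() checks the marker against the current environment,
    # so exactly one of the three prints True on a supported interpreter.
    print(version, marker.evaluate())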

tests/v1/tpu/test_pallas.py

Lines changed: 4 additions & 8 deletions
@@ -4,9 +4,7 @@
 import torch
 
 from vllm.attention.backends.abstract import AttentionType
-from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
-                                               NUM_QUERIES_PER_BLOCK,
-                                               PallasAttentionBackendImpl,
+from vllm.v1.attention.backends.pallas import (PallasAttentionBackendImpl,
                                                PallasMetadata)

@@ -32,8 +30,6 @@ def test_ragged_paged_attention():
         logits_soft_cap=logits_soft_cap,
         attn_type=AttentionType.DECODER,
     )
-    mock_vmem_limit_bytes = 1024
-    attn_impl.vmem_limit_bytes = mock_vmem_limit_bytes
 
     class FakeAttentionLayer:
         _k_scale_float: float
@@ -88,9 +84,9 @@ class FakeAttentionLayer:
             ANY,  # block_tables
             ANY,  # query_start_loc
             ANY,  # num_seqs
-            num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
-            num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-            vmem_limit_bytes=mock_vmem_limit_bytes,
+            num_kv_pages_per_block=None,
+            num_queries_per_block=None,
+            vmem_limit_bytes=None,
             use_kernel=True,
             sm_scale=scale,
             sliding_window=sliding_window,
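With the module-level constants removed, the test now expects None for the kernel's tuning knobs, which presumably defers block sizing and the VMEM budget to the updated kernel's own defaults. A minimal, self-contained sketch of this mock-assertion pattern; kernel and forward are hypothetical stand-ins, not vLLM's real symbols:

from unittest.mock import ANY, MagicMock

# Hypothetical stand-in for the Pallas kernel entry point; the real test
# asserts against a mocked kernel call inside PallasAttentionBackendImpl.
kernel = MagicMock(name="ragged_paged_attention")

def forward(query, kv_cache):
    # Passing None for the tuning knobs lets the kernel choose its own
    # block sizes and VMEM limit, as the updated assertions expect.
    return kernel(query, kv_cache,
                  num_kv_pages_per_block=None,
                  num_queries_per_block=None,
                  vmem_limit_bytes=None)

forward("q", "kv")
kernel.assert_called_once_with(
    ANY,  # query
    ANY,  # kv_cache
    num_kv_pages_per_block=None,
    num_queries_per_block=None,
    vmem_limit_bytes=None,
)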
