
Commit 4098b72

[Bugfix][TPU][V1] Fix recompilation (#15553)
Signed-off-by: NickLucche <[email protected]>
1 parent 46450b8 commit 4098b72

File tree: 4 files changed, +15 -74 lines changed

.buildkite/run-tpu-v1-test.sh

Lines changed: 3 additions & 1 deletion
@@ -32,7 +32,9 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_5 \
     && python3 /workspace/vllm/examples/offline_inference/tpu.py \
     && echo TEST_6 \
-    && pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py" \
+    && pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py \
+    && echo TEST_7 \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \


 # TODO: This test fails because it uses RANDOM_SEED sampling

tests/v1/tpu/test_sampler.py

Lines changed: 5 additions & 64 deletions
@@ -1,7 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import tempfile
-from time import time
-
 import pytest

 from vllm import LLM, envs
@@ -15,60 +12,6 @@
 )


-@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
-@pytest.mark.skipif(not current_platform.is_tpu(),
-                    reason="This test needs a TPU")
-def test_sampler_compilation(model_name: str, monkeypatch):
-    """
-    Check that no recompilation happens despite changing sampling parameters.
-    We can't read XLA metrics from the engine process, hence we measure time.
-    """
-    with tempfile.TemporaryDirectory() as temp_dir:
-        monkeypatch.setenv("VLLM_XLA_CACHE_PATH", temp_dir)
-        # Compiling model init may still take some time, enforce_eager to skip.
-        llm = LLM(model_name,
-                  enforce_eager=True,
-                  max_num_seqs=16,
-                  max_model_len=1024,
-                  gpu_memory_utilization=0.5)
-        prompts = [
-            "A robot may not injure a human being",
-            "It is only with the heart that one can see rightly;",
-        ]
-        # First inference should be slow
-        sampling_params = SamplingParams(
-            temperature=0.7,
-            # top_p=0.6, # TODO too slow!
-            top_k=10,
-            min_p=0.2,
-            max_tokens=16)
-        s = time()
-        _ = llm.generate(prompts, sampling_params)
-        run1 = time() - s
-
-        # Second request with different params, but for which we
-        # compiled for in previous eager iteration.
-        sampling_params = SamplingParams(temperature=0.1,
-                                         top_k=12,
-                                         min_p=0.8,
-                                         max_tokens=24)
-        s = time()
-        _ = llm.generate(prompts, sampling_params)
-        run2 = time() - s
-        # Much faster after compiling
-        assert run1 * 0.1 > run2
-        print("TIMES", run1, run2)
-
-        # Third request with min_p set to "None". It will not trigger
-        # recompilation as a default 0 value will be used.
-        sampling_params = SamplingParams(max_tokens=24, temperature=0.0)
-        s = time()
-        _ = llm.generate(prompts, sampling_params)
-        run3 = time() - s
-        assert run1 * 0.1 > run3
-        print("TIMES", run1, run3)
-
-
 @pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
 @pytest.mark.skipif(not current_platform.is_tpu(),
                     reason="This test needs a TPU")
@@ -77,13 +20,11 @@ def test_sampler_different(model_name: str):
     Test significantly different sampling params to assert the model produces
     different results.
     """
-    llm = LLM(
-        model_name,
-        enforce_eager=True,
-        max_num_seqs=1,
-        max_model_len=64,
-        # TODO: setting to 0.5 or it will go OOM
-        gpu_memory_utilization=0.5)
+    llm = LLM(model_name,
+              enforce_eager=False,
+              max_num_seqs=1,
+              max_model_len=512,
+              max_num_batched_tokens=512)
     prompts = [
         "Write a short story about a robot that dreams for the first time."
     ]
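Note on the removed test: it detected recompilation indirectly by comparing wall-clock times, because, as its docstring said, XLA metrics cannot be read from the separate engine process. For in-process experimentation, a rough sketch of the same check could query torch_xla's debug metrics instead of timing. The model choice and parameters below are illustrative, and the sketch assumes the TPU model executes in the calling process, which the engine-based setup above does not guarantee.

    # Hypothetical in-process variant of the removed timing check: read the
    # XLA compile counter directly. Assumes the model runs in this very
    # process rather than in a separate engine process.
    import torch_xla.debug.metrics as met

    from vllm import LLM, SamplingParams

    def compile_count() -> int:
        # "CompileTime" is a standard torch_xla metric; its sample count
        # grows every time a new XLA graph is compiled.
        data = met.metric_data("CompileTime")
        return data[0] if data is not None else 0

    llm = LLM("Qwen/Qwen2.5-1.5B-Instruct", max_num_seqs=16, max_model_len=512)
    prompts = ["A robot may not injure a human being"]

    llm.generate(prompts, SamplingParams(temperature=0.7, top_k=10, min_p=0.2))
    after_first = compile_count()  # the first run is allowed to compile

    # Same padded shapes, different sampling values: graphs should be reused.
    llm.generate(prompts, SamplingParams(temperature=0.1, top_k=12, min_p=0.8))
    assert compile_count() == after_first, "sampling params caused recompilation"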

vllm/v1/sample/tpu/metadata.py

Lines changed: 1 addition & 7 deletions
@@ -88,6 +88,7 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
             # Copy slice from CPU to corresponding TPU pre-allocated tensor.
             # Pad value is the default one.
             cpu_tensor[num_reqs:padded_num_reqs] = fill_val
+            # Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
             tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]

         # NOTE NickLucche The sync CPU-TPU graph we produce here must be
@@ -101,13 +102,6 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
         copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
                    DEFAULT_SAMPLING_PARAMS["min_p"])

-        # copy_slice(input_batch.frequency_penalties_cpu_tensor,
-        #            input_batch.frequency_penalties)
-        # copy_slice(input_batch.presence_penalties_cpu_tensor,
-        #            input_batch.presence_penalties)
-        # copy_slice(input_batch.repetition_penalties_cpu_tensor,
-        #            input_batch.repetition_penalties)
-
         xm.mark_step()
         xm.wait_device_ops()

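The new comment in copy_slice captures the core of the fix: the CPU staging tensor is padded up to padded_num_reqs with the parameter's default value, and only that bucketed slice is copied into a pre-allocated device tensor, so the copy's shape depends on the padding bucket rather than on the live request count. A minimal standalone sketch of the pattern, with illustrative names and sizes rather than the module's real buffers:

    import torch

    # Pad-then-slice copy: the destination buffer is pre-allocated at the
    # maximum size and the copied length is the *padded* request count, so
    # the traced shape only ever takes a handful of bucketed values.
    MAX_NUM_REQS = 128     # assumed upper bound, mirroring max_num_reqs
    DEFAULT_MIN_P = 0.0    # default fill value for padded slots

    min_p_cpu = torch.zeros(MAX_NUM_REQS)   # CPU staging tensor
    min_p_dev = torch.zeros(MAX_NUM_REQS)   # stands in for the TPU-resident tensor

    def copy_slice(num_reqs: int, padded_num_reqs: int) -> None:
        # Fill the padded tail with the default so the extra slots are
        # harmless; len(min_p_dev) must be >= padded_num_reqs, hence the
        # MAX_NUM_REQS allocation above.
        min_p_cpu[num_reqs:padded_num_reqs] = DEFAULT_MIN_P
        min_p_dev[:padded_num_reqs] = min_p_cpu[:padded_num_reqs]

    copy_slice(num_reqs=3, padded_num_reqs=8)   # 3 live requests, bucket of 8
    copy_slice(num_reqs=5, padded_num_reqs=8)   # same bucket, same shape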
vllm/v1/worker/tpu_model_runner.py

Lines changed: 6 additions & 2 deletions
@@ -88,6 +88,8 @@ def __init__(
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
+        # InputBatch needs to work with sampling tensors greater than padding
+        # to avoid dynamic shapes. Also, avoid suboptimal alignment.
         self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)

         # Model-related.
@@ -788,6 +790,7 @@ def capture_model(self) -> None:
         dummy_hidden = torch.randn((num_tokens, hsize),
                                    device=device,
                                    dtype=torch.bfloat16)
+        # Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
         while True:
             indices = torch.zeros(
                 num_reqs_to_sample,
@@ -804,7 +807,9 @@
             out = out.cpu()
             if num_reqs_to_sample >= self.max_num_reqs:
                 break
-            num_reqs_to_sample *= 2
+            # Make sure to compile the `max_num_reqs` upper-limit case
+            num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
+                num_reqs_to_sample + 1, self.max_num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
         logger.info("Compilation finished in in %.2f [secs].", end - start)
@@ -897,7 +902,6 @@ def forward(

         return hidden_states

-    # @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def sample_from_hidden(
         self,
         hidden_states: torch.Tensor,

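In capture_model, the precompilation loop now advances through the same padded request counts the runtime will later request and, unlike the previous num_reqs_to_sample *= 2, it ends on the exact max_num_reqs bucket instead of overshooting past it. The helper's body is not part of this diff; a plausible sketch of the bucketing suggested by the added comment (next power of two, floored at MIN_NUM_SEQS and capped at the limit) is shown below.

    # Assumed sketch of the bucketing implied by "Compile for [8, 16, ..,
    # 128, .., max_num_reqs]"; the real _get_padded_num_reqs_with_upper_limit
    # may differ in detail.
    MIN_NUM_SEQS = 8

    def padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
        # Round up to the next power of two, never below MIN_NUM_SEQS and
        # never above the configured maximum number of requests.
        res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
        return min(res, upper_limit)

    # Replaying the capture loop for max_num_reqs=200 walks the buckets
    # 8, 16, 32, 64, 128 and finally the exact upper limit 200.
    num_reqs, max_num_reqs = MIN_NUM_SEQS, 200
    sizes = []
    while True:
        sizes.append(num_reqs)
        if num_reqs >= max_num_reqs:
            break
        num_reqs = padded_num_reqs_with_upper_limit(num_reqs + 1, max_num_reqs)
    print(sizes)  # [8, 16, 32, 64, 128, 200]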