
Commit 3cc9af8

[TPU][V1] Disable per-request seed/Generator (#16172)
Signed-off-by: NickLucche <[email protected]>
1 parent 7cd0bd7

File tree

4 files changed, +24 -18 lines

tests/v1/tpu/test_sampler.py
vllm/platforms/tpu.py
vllm/v1/sample/tpu/metadata.py
vllm/v1/worker/tpu_model_runner.py

tests/v1/tpu/test_sampler.py

Lines changed: 5 additions & 0 deletions
@@ -34,3 +34,8 @@ def test_sampler_different(model_name: str):
     sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
     output2 = llm.generate(prompts, sampling_params)
     assert output[0].outputs[0].text != output2[0].outputs[0].text
+
+    with pytest.raises(ValueError):
+        # Unsupported `seed` param.
+        sampling_params = SamplingParams(temperature=0.3, seed=42)
+        output2 = llm.generate(prompts, sampling_params)
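
As a rough usage sketch of the behavior this test locks in (the model name below is illustrative, not part of this commit): on a TPU backend, a request that asks for a per-request seed now fails fast with a ValueError instead of being silently ignored.

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # illustrative model choice
try:
    llm.generate(["Hello, world"], SamplingParams(temperature=0.3, seed=42))
except ValueError as err:
    # Raised by the TPU platform check added in vllm/platforms/tpu.py below.
    print(err)  # "Torch XLA does not support per-request seed."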

vllm/platforms/tpu.py

Lines changed: 8 additions & 5 deletions
@@ -7,7 +7,7 @@
 import vllm.envs as envs
 from vllm.inputs import PromptType
 from vllm.logger import init_logger
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType

 from .interface import Platform, PlatformEnum, _Backend

@@ -145,7 +145,10 @@ def validate_request(
         params: Union[SamplingParams, PoolingParams],
     ) -> None:
         """Raises if this request is unsupported on this platform"""
-        if isinstance(params,
-                      SamplingParams) and params.guided_decoding is not None:
-            raise ValueError("Structured output is not supported on "
-                             f"{cls.device_name}.")
+        if isinstance(params, SamplingParams):
+            if params.guided_decoding is not None:
+                raise ValueError("Structured output is not supported on "
+                                 f"{cls.device_name}.")
+            if params.sampling_type == SamplingType.RANDOM_SEED:
+                raise ValueError(
+                    "Torch XLA does not support per-request seed.")

vllm/v1/sample/tpu/metadata.py

Lines changed: 10 additions & 6 deletions
@@ -33,10 +33,6 @@ class TPUSupportedSamplingMetadata:
     # Greedy sampling flag for compiling single xla graph.
     all_greedy: bool = True

-    # Generator not supported by xla
-    generators: dict[int,
-                     torch.Generator] = field(default_factory=lambda: dict())
-
     # unsupported, you need to return an extra tensor of static size BxV
     max_num_logprobs = None

@@ -57,6 +53,15 @@ class TPUSupportedSamplingMetadata:
     allowed_token_ids_mask = None
     bad_words_token_ids = None

+    # Generator not supported by xla
+    _generators: dict[int,
+                      torch.Generator] = field(default_factory=lambda: dict())
+
+    @property
+    def generators(self) -> dict[int, torch.Generator]:
+        # Generator not supported by torch/xla. This field must be immutable.
+        return self._generators
+
     @classmethod
     def from_input_batch(
         cls,
@@ -109,5 +114,4 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
             top_p=None,  # input_batch.top_p[:padded_num_reqs],
             top_k=None,  # input_batch.top_k[:padded_num_reqs],
             min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
-                xla_device),
-            generators=input_batch.generators)
+                xla_device))
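
Moving `generators` behind a read-only property means the dataclass no longer accepts a `generators=...` constructor argument, while code that only reads the mapping keeps working and always sees an empty dict. A self-contained sketch of the same pattern (names here are illustrative, not the vLLM definitions):

from dataclasses import dataclass, field

import torch


@dataclass
class MetadataSketch:
    all_greedy: bool = True
    # Backing field; seeded requests never reach the TPU sampler, so this
    # mapping is expected to stay empty.
    _generators: dict[int, torch.Generator] = field(default_factory=dict)

    @property
    def generators(self) -> dict[int, torch.Generator]:
        # Read-only view for callers that expect a `generators` attribute.
        return self._generators


meta = MetadataSketch()
print(meta.generators)  # {}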

vllm/v1/worker/tpu_model_runner.py

Lines changed: 1 addition & 7 deletions
@@ -23,7 +23,6 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
-from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
@@ -267,11 +266,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         for new_req_data in scheduler_output.scheduled_new_reqs:
             req_id = new_req_data.req_id
             sampling_params = new_req_data.sampling_params
-            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-                generator = torch.Generator(device=self.device)
-                generator.manual_seed(sampling_params.seed)
-            else:
-                generator = None

             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
@@ -280,7 +274,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
                 mm_inputs=new_req_data.mm_inputs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
-                generator=generator,
+                generator=None,
                 block_ids=new_req_data.block_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 output_token_ids=[],
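
With seeded requests rejected by the platform check in vllm/platforms/tpu.py, the runner no longer needs to build a per-request torch.Generator: CachedRequestState is always constructed with generator=None, and the SamplingType import becomes unnecessary.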
