Commit 6431be8

[Tests] conftest: Extending VllmRunner and HfRunner to accept token_ids as input (#26295)
Signed-off-by: Yannick Schnider <[email protected]>
Signed-off-by: Yannick Schnider <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 4727a8a · commit 6431be8

File tree: 2 files changed, +65 -65 lines

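Before the per-file diffs, a hypothetical usage sketch (not part of this commit) of what the change enables: tests can now hand raw token ids (list[list[int]]) to the runners' generate_greedy instead of text prompts. The fixture names come from conftest.py below; the model name, token values, and token counts are illustrative assumptions.

# Hypothetical sketch only: token-id prompts passed directly to both runners.
def test_token_id_prompts(vllm_runner, hf_runner):
    prompt_ids = [[43] * 16]  # one prompt given as raw token ids

    with vllm_runner(model_name="facebook/opt-125m", max_model_len=2048) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(prompt_ids, 8)

    with hf_runner(model_name="facebook/opt-125m") as hf_model:
        hf_outputs = hf_model.generate_greedy(prompt_ids, 8)

    # One result per prompt from each runner.
    assert len(vllm_outputs) == len(hf_outputs) == 1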

tests/conftest.py

Lines changed: 48 additions & 31 deletions
@@ -57,7 +57,7 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.transformers_utils.utils import maybe_model_redirect
-from vllm.utils import set_default_torch_num_threads
+from vllm.utils import is_list_of, set_default_torch_num_threads
 
 logger = init_logger(__name__)
 
@@ -406,11 +406,11 @@ def _init(
 
     def get_inputs(
         self,
-        prompts: list[str],
+        prompts: Union[list[str], list[list[int]]],
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
-    ) -> list[Union[BatchFeature, BatchEncoding]]:
+    ) -> list[Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]]]:
         if images is not None:
             assert len(prompts) == len(images)
 
@@ -420,31 +420,48 @@ def get_inputs(
         if audios is not None:
             assert len(prompts) == len(audios)
 
-        all_inputs: list[Union[BatchFeature, BatchEncoding]] = []
+        all_inputs: list[
+            Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]]
+        ] = []
         for i, prompt in enumerate(prompts):
-            processor_kwargs: dict[str, Any] = {
-                "text": prompt,
-                "return_tensors": "pt",
-            }
-            if images is not None and (image := images[i]) is not None:
-                processor_kwargs["images"] = image
-            if videos is not None and (video := videos[i]) is not None:
-                processor_kwargs["videos"] = video
-            if audios is not None and (audio_inputs := audios[i]) is not None:
-                # HACK - not all processors take sampling_rate; we should
-                # clean this up in the future.
-                if len(audio_inputs) == 2:
-                    audio, sr = audio_inputs
-                    processor_kwargs["audio"] = audio
-                    processor_kwargs["sampling_rate"] = sr
-                else:
-                    processor_kwargs["audio"] = audio_inputs
-
-            inputs = self.processor(**processor_kwargs)
-            if isinstance(inputs, BatchFeature):
-                inputs = inputs.to(dtype=self.dtype)
-
-            all_inputs.append(inputs)
+            if isinstance(prompt, str):
+                processor_kwargs: dict[str, Any] = {
+                    "text": prompt,
+                    "return_tensors": "pt",
+                }
+                if images is not None and (image := images[i]) is not None:
+                    processor_kwargs["images"] = image
+                if videos is not None and (video := videos[i]) is not None:
+                    processor_kwargs["videos"] = video
+                if audios is not None and (audio_inputs := audios[i]) is not None:
+                    # HACK - not all processors take sampling_rate; we should
+                    # clean this up in the future.
+                    if len(audio_inputs) == 2:
+                        audio, sr = audio_inputs
+                        processor_kwargs["audio"] = audio
+                        processor_kwargs["sampling_rate"] = sr
+                    else:
+                        processor_kwargs["audio"] = audio_inputs
+
+                inputs = self.processor(**processor_kwargs)
+                if isinstance(inputs, BatchFeature):
+                    inputs = inputs.to(dtype=self.dtype)
+                all_inputs.append(inputs)
+            else:
+                # check that prompt is (batched) list of integers (token ids)
+                if not is_list_of(prompt, typ=int, check="all"):
+                    raise ValueError(
+                        "Prompt must be a list of ints corresponding to the prompt token ids."
+                    )
+                # check that no multimodal input is provided
+                if images or videos or audios:
+                    raise ValueError(
+                        "When providing prompt token ids multimodal inputs are not supported."
+                    )
+                input_dict = {
+                    "input_ids": torch.tensor(prompt, dtype=torch.long).unsqueeze(0),
+                }
+                all_inputs.append(input_dict)
 
         return all_inputs
 
@@ -477,7 +494,7 @@ def classify(self, prompts: list[str]) -> list[str]:
 
     def generate(
         self,
-        prompts: list[str],
+        prompts: Union[list[str], list[list[int]]],
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
@@ -505,7 +522,7 @@ def generate(
 
     def generate_greedy(
         self,
-        prompts: list[str],
+        prompts: Union[list[str], list[list[int]]],
         max_tokens: int,
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
@@ -807,7 +824,7 @@ def get_inputs(
 
     def generate(
         self,
-        prompts: Union[list[str], list[torch.Tensor]],
+        prompts: Union[list[str], list[torch.Tensor], list[list[int]]],
         sampling_params: SamplingParams,
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
@@ -877,7 +894,7 @@ def generate_w_logprobs(
 
     def generate_greedy(
         self,
-        prompts: Union[list[str], list[torch.Tensor]],
+        prompts: Union[list[str], list[torch.Tensor], list[list[int]]],
         max_tokens: int,
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
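To make the new token-id branch of get_inputs concrete, here is a standalone sketch assuming plain PyTorch (no vLLM or transformers imports; the helper name token_ids_to_hf_inputs is made up for illustration) of how a token-id prompt is validated and wrapped into the dict later fed to the HF model:

import torch

def token_ids_to_hf_inputs(prompt: list[int]) -> dict[str, torch.Tensor]:
    # Mirror of the token-id branch above: reject non-integer prompts,
    # then add a batch dimension so the tensor is shaped (1, seq_len).
    if not all(isinstance(t, int) for t in prompt):
        raise ValueError(
            "Prompt must be a list of ints corresponding to the prompt token ids."
        )
    return {"input_ids": torch.tensor(prompt, dtype=torch.long).unsqueeze(0)}

# Example: a 4-token prompt becomes a (1, 4) LongTensor.
assert token_ids_to_hf_inputs([43, 43, 43, 43])["input_ids"].shape == (1, 4)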

tests/v1/e2e/test_context_length.py

Lines changed: 17 additions & 34 deletions
@@ -23,13 +23,10 @@
 """
 
 import pytest
-import torch
-from transformers import AutoModelForCausalLM
 
+from tests.conftest import HfRunner, VllmRunner
 from tests.models.utils import check_outputs_equal
 from tests.utils import create_new_process_for_each_test
-from vllm import LLM, SamplingParams
-from vllm.inputs import TokensPrompt
 
 
 @create_new_process_for_each_test()
@@ -43,6 +40,8 @@
 )
 def test_max_context_length(
     model: str,
+    vllm_runner: type[VllmRunner],
+    hf_runner: type[HfRunner],
     prompt_len: int,
     max_tokens: int,
 ) -> None:
@@ -57,42 +56,26 @@ def test_max_context_length(
     # Construct a prompt of size prompt_len
     prompt_ids = [[43] * prompt_len]
 
-    # Generate max_tokens new tokens deterministically.
-    sampling_params = [
-        SamplingParams(max_tokens=max_tokens, temperature=0.0, ignore_eos=True)
-    ]
-
     # --- vLLM generation ---
-    llm = LLM(
-        model=model,
-        tokenizer=model,
+    with vllm_runner(
+        model_name=model,
+        tokenizer_name=model,
         max_model_len=2048,
         max_num_seqs=1,
         tensor_parallel_size=1,
-    )
-
-    vllm_token_prompts = [TokensPrompt(prompt_token_ids=prompt_ids[0])]
-    vllm_results = llm.generate(vllm_token_prompts, sampling_params)
-
-    vllm_output_ids = vllm_results[0].outputs[0].token_ids
-
-    # --- HuggingFace generation ---
-    with torch.no_grad():
-        hf_model = AutoModelForCausalLM.from_pretrained(model)
-
-        # HF expects a tensor of input ids shaped (batch, seq_len).
-        hf_input_tokens = torch.tensor(prompt_ids[0]).unsqueeze(0)
-
+    ) as vllm_model:
         # Generate max_tokens new tokens deterministically.
-        hf_generated = hf_model.generate(
-            hf_input_tokens,
-            do_sample=False,
-            min_new_tokens=max_tokens,
-            max_new_tokens=max_tokens,
-        )
+        vllm_outputs = vllm_model.generate_greedy(prompt_ids, max_tokens)
 
-    # HF returns the prompt + generated tokens. Slice off the prompt.
-    hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids[0]) :]
+    # --- HuggingFace generation ---
+    with hf_runner(
+        model_name=model,
+    ) as hf_model:
+        hf_outputs = hf_model.generate_greedy(prompt_ids, max_tokens)
+
+    # vLLM and HF runners return prompt + generated tokens. Slice off the prompt.
+    vllm_output_ids = vllm_outputs[0][0][prompt_len:]
+    hf_output_ids = hf_outputs[0][0][prompt_len:]
 
     # check that exactly max_tokens tokens were generated with vLLM and HF
    assert len(vllm_output_ids) == len(hf_output_ids) == max_tokens
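As the slicing [0][0][prompt_len:] in the updated test suggests, each runner's generate_greedy returns one (token_ids, text) pair per prompt, with token_ids covering the prompt plus the generated continuation. A tiny sketch with dummy data (not the runners' real output) illustrates why the slice isolates the newly generated tokens:

# Dummy data only: one prompt of length 3, two newly generated tokens.
prompt_len = 3
outputs = [([7, 7, 7, 11, 12], "decoded text")]  # (token_ids, text) per prompt
generated_ids = outputs[0][0][prompt_len:]
assert generated_ids == [11, 12]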
