
Commit 34847d9

Merge pull request #246 from sbintuitions/vllm_0.10.1.1
upgrade: `vllm==0.10.2`
2 parents: 51d1770 + 7e38325 · commit: 34847d9

File tree

3 files changed: +583 −433 lines


flexeval/core/language_model/vllm_model.py

Lines changed: 6 additions & 4 deletions
@@ -187,15 +187,16 @@ def _batch_complete_text(
         )
 
         from vllm import RequestOutput, SamplingParams
+        from vllm.inputs import TokensPrompt
 
-        prompt_token_ids: list[list[int]] = model_inputs.input_ids
+        batch_prompt_token_ids: list[list[int]] = model_inputs.input_ids
         sampling_params: list[SamplingParams] = []
         skip_flag_list: list[bool] = []
         for i, input_ids in enumerate(model_inputs.input_ids):
             remaining = self.model_limit_tokens - len(input_ids)
             instance_gen_kwargs = gen_kwargs.copy()
             if remaining <= 0:
-                prompt_token_ids[i] = input_ids[:1]
+                batch_prompt_token_ids[i] = input_ids[:1]
                 instance_gen_kwargs["max_tokens"] = 1
                 msg = (
                     f"Received input that is longer than `model_limit_tokens = {self.model_limit_tokens}`. "
@@ -208,7 +209,7 @@ def _batch_complete_text(
             skip_flag_list.append(remaining <= 0)
 
         vllm_outputs: list[RequestOutput] = self.llm.generate(
-            prompt_token_ids=prompt_token_ids,
+            [TokensPrompt(prompt_token_ids=prompt_token_ids) for prompt_token_ids in batch_prompt_token_ids],
             sampling_params=sampling_params,
             use_tqdm=False,
         )
@@ -309,6 +310,7 @@ def _batch_compute_log_probs(
         sequence_length = max([len(input_ids) for input_ids in batch_input_ids])
 
         from vllm import RequestOutput, SamplingParams
+        from vllm.inputs import TokensPrompt
         from vllm.sequence import Logprob
 
         sampling_params = SamplingParams(temperature=0.0, max_tokens=1, prompt_logprobs=1)
@@ -323,7 +325,7 @@ def _batch_compute_log_probs(
             for chunk_input_ids in chunk_batch_input_ids
         ]
         chunk_batch_outputs: list[RequestOutput] = self.llm.generate(
-            prompt_token_ids=chunk_batch_input_ids,
+            [TokensPrompt(prompt_token_ids=prompt_token_ids) for prompt_token_ids in chunk_batch_input_ids],
             sampling_params=sampling_params,
             use_tqdm=False,
         )
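
For context on the change above, here is a minimal sketch of the call pattern the commit migrates to: the diff suggests that in `vllm==0.10.2` `LLM.generate` no longer takes a `prompt_token_ids=` keyword, so pre-tokenized inputs are wrapped in `TokensPrompt` objects instead. The model name and token IDs below are placeholders, not taken from this repository.

```python
# Minimal sketch of the TokensPrompt-based call style (assumes vllm>=0.10.x).
# "facebook/opt-125m" and the token IDs are illustrative placeholders only.
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

llm = LLM(model="facebook/opt-125m")
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

# Old style (removed): llm.generate(prompt_token_ids=[[1, 2, 3]], ...)
# New style: wrap each pre-tokenized prompt in a TokensPrompt.
batch_prompt_token_ids = [[1, 2, 3], [4, 5]]
outputs = llm.generate(
    [TokensPrompt(prompt_token_ids=ids) for ids in batch_prompt_token_ids],
    sampling_params=sampling_params,
    use_tqdm=False,
)
for output in outputs:
    print(output.outputs[0].text)
```

In the patched code the same pattern is used, except that `sampling_params` is a per-prompt list so that over-long inputs can be clamped to a single generated token.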
