57
57
from vllm .outputs import RequestOutput
58
58
from vllm .sampling_params import BeamSearchParams
59
59
from vllm .transformers_utils .utils import maybe_model_redirect
60
- from vllm .utils import set_default_torch_num_threads
60
+ from vllm .utils import is_list_of , set_default_torch_num_threads
61
61
62
62
logger = init_logger (__name__ )
63
63
@@ -406,11 +406,11 @@ def _init(
406
406
407
407
def get_inputs(
    self,
    prompts: Union[list[str], list[list[int]]],
    images: Optional[PromptImageInput] = None,
    videos: Optional[PromptVideoInput] = None,
    audios: Optional[PromptAudioInput] = None,
) -> list[Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]]]:
    """Convert a batch of prompts (plus optional multimodal data) into
    per-prompt HF processor inputs.

    Args:
        prompts: Either text prompts, or pre-tokenized prompts given as
            flat lists of token ids. Token-id prompts bypass
            ``self.processor`` and are wrapped directly into an
            ``input_ids`` tensor of shape ``(1, len(prompt))``.
        images: Optional per-prompt image inputs; when given, must have
            the same length as ``prompts`` and may contain ``None`` for
            prompts without an image.
        videos: Optional per-prompt video inputs; same length rule.
        audios: Optional per-prompt audio inputs; same length rule. Each
            entry may be a ``(waveform, sampling_rate)`` pair — only then
            is ``sampling_rate`` forwarded to the processor.

    Returns:
        One entry per prompt: the processor output (``BatchFeature`` /
        ``BatchEncoding``) for text prompts, or a plain
        ``{"input_ids": tensor}`` dict for token-id prompts.

    Raises:
        ValueError: If a non-string prompt is not a flat list of ints,
            or if multimodal data is supplied for a token-id prompt.
    """
    if images is not None:
        assert len(prompts) == len(images)

    if videos is not None:
        assert len(prompts) == len(videos)

    if audios is not None:
        assert len(prompts) == len(audios)

    all_inputs: list[
        Union[BatchFeature, BatchEncoding, dict[str, torch.Tensor]]
    ] = []
    for i, prompt in enumerate(prompts):
        if isinstance(prompt, str):
            processor_kwargs: dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and (image := images[i]) is not None:
                processor_kwargs["images"] = image
            if videos is not None and (video := videos[i]) is not None:
                processor_kwargs["videos"] = video
            if audios is not None and (audio_inputs := audios[i]) is not None:
                # HACK - not all processors take sampling_rate; we should
                # clean this up in the future.
                if len(audio_inputs) == 2:
                    audio, sr = audio_inputs
                    processor_kwargs["audio"] = audio
                    processor_kwargs["sampling_rate"] = sr
                else:
                    processor_kwargs["audio"] = audio_inputs

            inputs = self.processor(**processor_kwargs)
            if isinstance(inputs, BatchFeature):
                # Only BatchFeature supports dtype casting via .to().
                inputs = inputs.to(dtype=self.dtype)
            all_inputs.append(inputs)
        else:
            # check that prompt is (batched) list of integers (token ids)
            if not is_list_of(prompt, typ=int, check="all"):
                raise ValueError(
                    "Prompt must be a list of ints corresponding to the prompt token ids."
                )
            # Only reject multimodal data actually supplied for THIS
            # prompt: a batch may mix text prompts (with images etc.)
            # and token-id prompts whose multimodal entries are None.
            # Checking list truthiness here would wrongly raise for a
            # list of all-None entries.
            if any(
                mm is not None and mm[i] is not None
                for mm in (images, videos, audios)
            ):
                raise ValueError(
                    "When providing prompt token ids multimodal inputs are not supported."
                )
            input_dict = {
                "input_ids": torch.tensor(prompt, dtype=torch.long).unsqueeze(0),
            }
            all_inputs.append(input_dict)

    return all_inputs
450
467
@@ -477,7 +494,7 @@ def classify(self, prompts: list[str]) -> list[str]:
477
494
478
495
def generate (
479
496
self ,
480
- prompts : list [str ],
497
+ prompts : Union [ list [str ], list [ list [ int ]] ],
481
498
images : Optional [PromptImageInput ] = None ,
482
499
videos : Optional [PromptVideoInput ] = None ,
483
500
audios : Optional [PromptAudioInput ] = None ,
@@ -505,7 +522,7 @@ def generate(
505
522
506
523
def generate_greedy (
507
524
self ,
508
- prompts : list [str ],
525
+ prompts : Union [ list [str ], list [ list [ int ]] ],
509
526
max_tokens : int ,
510
527
images : Optional [PromptImageInput ] = None ,
511
528
videos : Optional [PromptVideoInput ] = None ,
@@ -807,7 +824,7 @@ def get_inputs(
807
824
808
825
def generate (
809
826
self ,
810
- prompts : Union [list [str ], list [torch .Tensor ]],
827
+ prompts : Union [list [str ], list [torch .Tensor ], list [ list [ int ]] ],
811
828
sampling_params : SamplingParams ,
812
829
images : Optional [PromptImageInput ] = None ,
813
830
videos : Optional [PromptVideoInput ] = None ,
@@ -877,7 +894,7 @@ def generate_w_logprobs(
877
894
878
895
def generate_greedy (
879
896
self ,
880
- prompts : Union [list [str ], list [torch .Tensor ]],
897
+ prompts : Union [list [str ], list [torch .Tensor ], list [ list [ int ]] ],
881
898
max_tokens : int ,
882
899
images : Optional [PromptImageInput ] = None ,
883
900
videos : Optional [PromptVideoInput ] = None ,
0 commit comments