@@ -51,7 +51,7 @@
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Counter, Device, is_list_of
+from vllm.utils import Counter, Device, as_iter, is_list_of
 from vllm.v1.sample.logits_processor import LogitsProcessor
 
 if TYPE_CHECKING:
@@ -364,14 +364,6 @@ def generate(
             # Use default sampling params.
             sampling_params = self.get_default_sampling_params()
 
-        tokenization_kwargs: dict[str, Any] = {}
-        truncate_prompt_tokens = None
-        if isinstance(sampling_params, SamplingParams):
-            truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
-
-        _validate_truncation_size(model_config.max_model_len,
-                                  truncate_prompt_tokens, tokenization_kwargs)
-
         # Add any modality specific loras to the corresponding prompts
         lora_request = self._get_modality_specific_lora_reqs(
             prompts, lora_request)
@@ -381,7 +373,6 @@ def generate(
             params=sampling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
-            tokenization_kwargs=tokenization_kwargs,
             priority=priority,
         )
 
@@ -871,6 +862,8 @@ def encode(
                 If `False`, no progress bar is created.
             lora_request: LoRA request to use for generation, if any.
             pooling_task: Override the pooling task to use.
+            tokenization_kwargs: Overrides the `tokenization_kwargs` set in
+                `pooling_params`.
 
         Returns:
             A list of `PoolingRequestOutput` objects containing the
@@ -916,24 +909,17 @@ def encode(
             # Use default pooling params.
             pooling_params = PoolingParams()
 
-        if isinstance(pooling_params, PoolingParams):
-            pooling_params.verify(pooling_task, model_config)
-        else:
-            for pooling_param in pooling_params:
-                pooling_param.verify(pooling_task, model_config)
-
-        if tokenization_kwargs is None:
-            tokenization_kwargs = dict[str, Any]()
-        _validate_truncation_size(model_config.max_model_len,
-                                  truncate_prompt_tokens,
-                                  tokenization_kwargs)
+        for param in as_iter(pooling_params):
+            param.verify(pooling_task, model_config)
+            # For backwards compatibility with the top-level argument
+            if truncate_prompt_tokens is not None:
+                param.truncate_prompt_tokens = truncate_prompt_tokens
 
         self._validate_and_add_requests(
             prompts=prompts,
             params=pooling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
-            tokenization_kwargs=tokenization_kwargs,
         )
 
         outputs = self._run_engine(use_tqdm=use_tqdm)
@@ -1385,7 +1371,6 @@ def _validate_and_add_requests(
         *,
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
         priority: Optional[list[int]] = None,
     ) -> None:
         if isinstance(prompts, (str, dict)):
@@ -1412,7 +1397,17 @@ def _validate_and_add_requests(
             tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
             it = tqdm_func(it, desc="Adding requests")
 
+        model_config = self.llm_engine.model_config
+
         for i, prompt in enumerate(it):
+
+            param = params[i] if isinstance(params, Sequence) else params
+
+            tokenization_kwargs: dict[str, Any] = {}
+            _validate_truncation_size(model_config.max_model_len,
+                                      param.truncate_prompt_tokens,
+                                      tokenization_kwargs)
+
             self._add_request(
                 prompt,
                 params[i] if isinstance(params, Sequence) else params,
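
Net effect of the diff: `truncate_prompt_tokens` now travels with each request's params object and is validated per request inside `_validate_and_add_requests`, rather than once up front in `generate()`/`encode()`. A minimal usage sketch of the resulting API, assuming `PoolingParams` accepts `truncate_prompt_tokens` as a constructor argument (as its use in the diff suggests) and that the placeholder model name resolves to a pooling-capable checkpoint:

```python
from vllm import LLM, PoolingParams

# Placeholder embedding model; any pooling-capable checkpoint works here.
llm = LLM(model="intfloat/e5-small-v2")

doc = "some very long document ..."

# Preferred path after this change: truncation is carried by the params
# object and checked against model_config.max_model_len per request.
outputs = llm.encode(
    [doc], pooling_params=PoolingParams(truncate_prompt_tokens=128))

# Deprecated path: the top-level argument still works because encode()
# copies it onto each params object in the backwards-compatibility branch.
outputs = llm.encode([doc], truncate_prompt_tokens=128)
```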