@@ -965,6 +965,7 @@ def encode(
965
965
lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
966
966
prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
967
967
pooling_task : PoolingTask = "encode" ,
968
+ tokenization_kwargs : Optional [dict [str , Any ]] = None ,
968
969
) -> list [PoolingRequestOutput ]:
969
970
...
970
971
@@ -981,6 +982,7 @@ def encode(
981
982
lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
982
983
prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
983
984
pooling_task : PoolingTask = "encode" ,
985
+ tokenization_kwargs : Optional [dict [str , Any ]] = None ,
984
986
) -> list [PoolingRequestOutput ]:
985
987
...
986
988
@@ -997,6 +999,7 @@ def encode(
997
999
lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
998
1000
prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
999
1001
pooling_task : PoolingTask = "encode" ,
1002
+ tokenization_kwargs : Optional [dict [str , Any ]] = None ,
1000
1003
) -> list [PoolingRequestOutput ]:
1001
1004
...
1002
1005
@@ -1014,6 +1017,7 @@ def encode(
1014
1017
lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
1015
1018
prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
1016
1019
pooling_task : PoolingTask = "encode" ,
1020
+ tokenization_kwargs : Optional [dict [str , Any ]] = None ,
1017
1021
) -> list [PoolingRequestOutput ]:
1018
1022
...
1019
1023
@@ -1031,6 +1035,7 @@ def encode(
1031
1035
lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
1032
1036
prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
1033
1037
pooling_task : PoolingTask = "encode" ,
1038
+ tokenization_kwargs : Optional [dict [str , Any ]] = None ,
1034
1039
) -> list [PoolingRequestOutput ]:
1035
1040
...
1036
1041
@@ -1046,6 +1051,7 @@ def encode(
1046
1051
lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
1047
1052
prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
1048
1053
pooling_task : PoolingTask = "encode" ,
1054
+ tokenization_kwargs : Optional [dict [str , Any ]] = None ,
1049
1055
) -> list [PoolingRequestOutput ]:
1050
1056
...
1051
1057
@@ -1066,6 +1072,7 @@ def encode(
1066
1072
lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
1067
1073
prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
1068
1074
pooling_task : PoolingTask = "encode" ,
1075
+ tokenization_kwargs : Optional [dict [str , Any ]] = None ,
1069
1076
) -> list [PoolingRequestOutput ]:
1070
1077
"""Apply pooling to the hidden states corresponding to the input
1071
1078
prompts.
@@ -1131,9 +1138,11 @@ def encode(
1131
1138
for pooling_param in pooling_params :
1132
1139
pooling_param .verify (pooling_task , model_config )
1133
1140
1134
- tokenization_kwargs = dict [str , Any ]()
1135
- _validate_truncation_size (model_config .max_model_len ,
1136
- truncate_prompt_tokens , tokenization_kwargs )
1141
+ if tokenization_kwargs is None :
1142
+ tokenization_kwargs = dict [str , Any ]()
1143
+ _validate_truncation_size (model_config .max_model_len ,
1144
+ truncate_prompt_tokens ,
1145
+ tokenization_kwargs )
1137
1146
1138
1147
self ._validate_and_add_requests (
1139
1148
prompts = parsed_prompts ,
0 commit comments