import math
import os
from dataclasses import asdict
-from typing import Any, AsyncIterable, Dict, List, Optional
+from typing import Any, AsyncIterable, Dict, List, Optional, Union
from uuid import uuid4

from model_engine_server.common.config import hmi_config
@@ -839,6 +839,54 @@ def deepspeed_result_to_tokens(result: Dict[str, Any]) -> List[TokenOutput]:
    return tokens


+def validate_and_update_completion_params(
+    inference_framework: LLMInferenceFramework,
+    request: Union[CompletionSyncV1Request, CompletionStreamV1Request],
+) -> Union[CompletionSyncV1Request, CompletionStreamV1Request]:
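+    """
+    Validate and normalize sampling parameters for the target inference framework.
+
+    For text-generation-inference, vllm, and lightllm: reject top_k / top_p when
+    temperature is 0, reject top_k == 0, and map the "disabled" sentinels to each
+    framework's convention (None for TGI, -1 / 1.0 for vLLM and LightLLM).
+    For vllm and lightllm, default presence_penalty / frequency_penalty to 0.0;
+    other frameworks reject these parameters entirely. Returns the updated request.
+    """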
+    # top_k, top_p
+    if inference_framework in [
+        LLMInferenceFramework.TEXT_GENERATION_INFERENCE,
+        LLMInferenceFramework.VLLM,
+        LLMInferenceFramework.LIGHTLLM,
+    ]:
+        if request.temperature == 0:
+            if request.top_k not in [-1, None] or request.top_p not in [1.0, None]:
+                raise ObjectHasInvalidValueException(
+                    "top_k and top_p can't be enabled when temperature is 0."
+                )
+        if request.top_k == 0:
+            raise ObjectHasInvalidValueException(
+                "top_k needs to be strictly positive, or set it to be -1 / None to disable top_k."
+            )
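+        # TGI interprets top_k / top_p as disabled when they are None, so translate the
+        # -1 / 1.0 sentinels for that framework.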
+        if inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE:
+            request.top_k = None if request.top_k == -1 else request.top_k
+            request.top_p = None if request.top_p == 1.0 else request.top_p
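+        # vLLM and LightLLM take explicit sentinels instead of None: top_k=-1 and
+        # top_p=1.0 mean "no restriction".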
+        if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]:
+            request.top_k = -1 if request.top_k is None else request.top_k
+            request.top_p = 1.0 if request.top_p is None else request.top_p
+    else:
+        if request.top_k or request.top_p:
+            raise ObjectHasInvalidValueException(
+                "top_k and top_p are only supported in text-generation-inference, vllm, lightllm."
+            )
+
+    # presence_penalty, frequency_penalty
+    if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]:
+        request.presence_penalty = (
+            0.0 if request.presence_penalty is None else request.presence_penalty
+        )
+        request.frequency_penalty = (
+            0.0 if request.frequency_penalty is None else request.frequency_penalty
+        )
+    else:
+        if request.presence_penalty or request.frequency_penalty:
+            raise ObjectHasInvalidValueException(
+                "presence_penalty and frequency_penalty are only supported in vllm, lightllm."
+            )
+
+    return request
+
+
class CompletionSyncV1UseCase:
    """
    Use case for running a prompt completion on an LLM endpoint.
@@ -983,6 +1031,15 @@ async def execute(
            endpoint_id=model_endpoint.record.id
        )
        endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint)
+        validated_request = validate_and_update_completion_params(
+            endpoint_content.inference_framework, request
+        )
+        if not isinstance(validated_request, CompletionSyncV1Request):
+            raise ValueError(
+                f"request has type {validated_request.__class__.__name__}, expected type CompletionSyncV1Request"
+            )
+        request = validated_request
+
        if endpoint_content.inference_framework == LLMInferenceFramework.DEEPSPEED:
            args: Any = {
                "prompts": [request.prompt],
@@ -1036,6 +1093,10 @@ async def execute(
            if request.temperature > 0:
                tgi_args["parameters"]["temperature"] = request.temperature
                tgi_args["parameters"]["do_sample"] = True
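+                # top_k / top_p arrive already mapped to TGI's convention (None disables
+                # them) by validate_and_update_completion_params above.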
+                tgi_args["parameters"]["top_k"] = request.top_k
+                tgi_args["parameters"]["top_p"] = request.top_p
+            else:
+                tgi_args["parameters"]["do_sample"] = False

            inference_request = SyncEndpointPredictV1Request(
                args=tgi_args,
@@ -1064,10 +1125,15 @@ async def execute(
            vllm_args: Any = {
                "prompt": request.prompt,
                "max_tokens": request.max_new_tokens,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
            }
            if request.stop_sequences is not None:
                vllm_args["stop"] = request.stop_sequences
            vllm_args["temperature"] = request.temperature
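+            # With greedy decoding (temperature == 0) the validator has already rejected
+            # non-default top_k / top_p, so they are only forwarded when sampling.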
+            if request.temperature > 0:
+                vllm_args["top_k"] = request.top_k
+                vllm_args["top_p"] = request.top_p
            if request.return_token_log_probs:
                vllm_args["logprobs"] = 1

@@ -1098,12 +1164,16 @@ async def execute(
                "inputs": request.prompt,
                "parameters": {
                    "max_new_tokens": request.max_new_tokens,
+                    "presence_penalty": request.presence_penalty,
+                    "frequency_penalty": request.frequency_penalty,
                },
            }
            # TODO: implement stop sequences
            if request.temperature > 0:
                lightllm_args["parameters"]["temperature"] = request.temperature
                lightllm_args["parameters"]["do_sample"] = True
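+                # LightLLM expects sampling parameters inside the nested "parameters"
+                # dict, matching the streaming path below.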
+                lightllm_args["parameters"]["top_k"] = request.top_k
+                lightllm_args["parameters"]["top_p"] = request.top_p
            else:
                lightllm_args["parameters"]["do_sample"] = False
            if request.return_token_log_probs:
@@ -1172,6 +1242,7 @@ async def execute(

        request_id = str(uuid4())
        add_trace_request_id(request_id)
+
        model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints(
            owner=user.team_id, name=model_endpoint_name, order_by=None
        )
@@ -1209,6 +1280,14 @@ async def execute(
            )

        model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint)
+        validated_request = validate_and_update_completion_params(
+            model_content.inference_framework, request
+        )
+        if not isinstance(validated_request, CompletionStreamV1Request):
+            raise ValueError(
+                f"request has type {validated_request.__class__.__name__}, expected type CompletionStreamV1Request"
+            )
+        request = validated_request

        args: Any = None
        if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED:
@@ -1237,14 +1316,23 @@ async def execute(
            if request.temperature > 0:
                args["parameters"]["temperature"] = request.temperature
                args["parameters"]["do_sample"] = True
+                args["parameters"]["top_k"] = request.top_k
+                args["parameters"]["top_p"] = request.top_p
+            else:
+                args["parameters"]["do_sample"] = False
        elif model_content.inference_framework == LLMInferenceFramework.VLLM:
            args = {
                "prompt": request.prompt,
                "max_tokens": request.max_new_tokens,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
            }
            if request.stop_sequences is not None:
                args["stop"] = request.stop_sequences
            args["temperature"] = request.temperature
+            if request.temperature > 0:
+                args["top_k"] = request.top_k
+                args["top_p"] = request.top_p
            if request.return_token_log_probs:
                args["logprobs"] = 1
            args["stream"] = True
@@ -1253,12 +1341,16 @@ async def execute(
                "inputs": request.prompt,
                "parameters": {
                    "max_new_tokens": request.max_new_tokens,
+                    "presence_penalty": request.presence_penalty,
+                    "frequency_penalty": request.frequency_penalty,
                },
            }
            # TODO: stop sequences
            if request.temperature > 0:
                args["parameters"]["temperature"] = request.temperature
                args["parameters"]["do_sample"] = True
+                args["parameters"]["top_k"] = request.top_k
+                args["parameters"]["top_p"] = request.top_p
            else:
                args["parameters"]["do_sample"] = False
            if request.return_token_log_probs: