
Commit 287ab59

Add repetition_penalty, top_k, top_p to Completion (#295)
* add repetition_penalty, top_k, top_p
* add frequency_penalty, presence_penalty, add lightllm
* add comments
* fix
* fix Optional, add params validation
* remove repetition_penalty
* add back optional, update validation function
* type check
1 parent 471ede3 commit 287ab59

File tree

4 files changed: +203 -3 lines changed

clients/python/llmengine/completion.py
clients/python/llmengine/data_types.py
model-engine/model_engine_server/common/dtos/llms.py
model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

clients/python/llmengine/completion.py

Lines changed: 64 additions & 0 deletions
@@ -33,6 +33,10 @@ async def acreate(
         temperature: float = 0.2,
         stop_sequences: Optional[List[str]] = None,
         return_token_log_probs: Optional[bool] = False,
+        presence_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
         timeout: int = COMPLETION_TIMEOUT,
         stream: bool = False,
     ) -> Union[CompletionSyncResponse, AsyncIterable[CompletionStreamResponse]]:
@@ -72,6 +76,26 @@ async def acreate(
                 Whether to return the log probabilities of generated tokens.
                 When True, the response will include a list of tokens and their log probabilities.
 
+            presence_penalty (Optional[float]):
+                Only supported in vllm, lightllm
+                Penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
+                https://platform.openai.com/docs/guides/gpt/parameter-details
+                Range: [0.0, 2.0]. Higher values encourage the model to use new tokens.
+
+            frequency_penalty (Optional[float]):
+                Only supported in vllm, lightllm
+                Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
+                https://platform.openai.com/docs/guides/gpt/parameter-details
+                Range: [0.0, 2.0]. Higher values encourage the model to use new tokens.
+
+            top_k (Optional[int]):
+                Integer that controls the number of top tokens to consider.
+                Range: [1, infinity). -1 means consider all tokens.
+
+            top_p (Optional[float]):
+                Float that controls the cumulative probability of the top tokens to consider.
+                Range: (0.0, 1.0]. 1.0 means consider all tokens.
+
             timeout (int):
                 Timeout in seconds. This is the maximum amount of time you are willing to wait for a response.
@@ -164,6 +188,10 @@ async def _acreate_stream(
             temperature=temperature,
             stop_sequences=stop_sequences,
             return_token_log_probs=return_token_log_probs,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            top_k=top_k,
+            top_p=top_p,
             timeout=timeout,
         )
 
@@ -184,6 +212,10 @@ async def _acreate_sync(**kwargs) -> CompletionSyncResponse:
             temperature=temperature,
             stop_sequences=stop_sequences,
             return_token_log_probs=return_token_log_probs,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            top_k=top_k,
+            top_p=top_p,
         )
 
     @classmethod
@@ -195,6 +227,10 @@ def create(
         temperature: float = 0.2,
         stop_sequences: Optional[List[str]] = None,
        return_token_log_probs: Optional[bool] = False,
+        presence_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
         timeout: int = COMPLETION_TIMEOUT,
         stream: bool = False,
     ) -> Union[CompletionSyncResponse, Iterator[CompletionStreamResponse]]:
@@ -235,6 +271,26 @@ def create(
                 Whether to return the log probabilities of generated tokens.
                 When True, the response will include a list of tokens and their log probabilities.
 
+            presence_penalty (Optional[float]):
+                Only supported in vllm, lightllm
+                Penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
+                https://platform.openai.com/docs/guides/gpt/parameter-details
+                Range: [0.0, 2.0]. Higher values encourage the model to use new tokens.
+
+            frequency_penalty (Optional[float]):
+                Only supported in vllm, lightllm
+                Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
+                https://platform.openai.com/docs/guides/gpt/parameter-details
+                Range: [0.0, 2.0]. Higher values encourage the model to use new tokens.
+
+            top_k (Optional[int]):
+                Integer that controls the number of top tokens to consider.
+                Range: [1, infinity). -1 means consider all tokens.
+
+            top_p (Optional[float]):
+                Float that controls the cumulative probability of the top tokens to consider.
+                Range: (0.0, 1.0]. 1.0 means consider all tokens.
+
             timeout (int):
                 Timeout in seconds. This is the maximum amount of time you are willing to wait for a response.
@@ -317,6 +373,10 @@ def _create_stream(**kwargs):
                 temperature=temperature,
                 stop_sequences=stop_sequences,
                 return_token_log_probs=return_token_log_probs,
+                presence_penalty=presence_penalty,
+                frequency_penalty=frequency_penalty,
+                top_k=top_k,
+                top_p=top_p,
             )
 
         else:
@@ -326,6 +386,10 @@ def _create_stream(**kwargs):
                 temperature=temperature,
                 stop_sequences=stop_sequences,
                 return_token_log_probs=return_token_log_probs,
+                presence_penalty=presence_penalty,
+                frequency_penalty=frequency_penalty,
+                top_k=top_k,
+                top_p=top_p,
             ).dict()
             response = cls.post_sync(
                 resource_name=f"v1/llm/completions-sync?model_endpoint_name={model}",
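For context, a minimal sketch of how a caller could exercise the new client parameters. The endpoint name and prompt are illustrative, and the example assumes a vLLM or LightLLM endpoint, since presence_penalty and frequency_penalty are only honored there (and top_k / top_p only take effect when temperature > 0, per the validation added in this commit):

from llmengine import Completion

response = Completion.create(
    model="llama-2-7b",        # hypothetical endpoint name; use one you have deployed
    prompt="Suggest three names for a hiking club.",
    max_new_tokens=64,
    temperature=0.7,           # must be > 0 for top_k / top_p to apply
    presence_penalty=0.5,      # [0.0, 2.0]; vllm / lightllm only
    frequency_penalty=0.5,     # [0.0, 2.0]; vllm / lightllm only
    top_k=50,                  # strictly positive, or -1 / None to disable
    top_p=0.9,                 # (0.0, 1.0]; 1.0 considers all tokens
)
print(response.output.text)

The async variant, Completion.acreate, accepts the same keyword arguments.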

clients/python/llmengine/data_types.py

Lines changed: 8 additions & 0 deletions
@@ -269,6 +269,10 @@ class CompletionSyncV1Request(BaseModel):
     temperature: float = Field(..., ge=0.0)
     stop_sequences: Optional[List[str]] = Field(default=None)
     return_token_log_probs: Optional[bool] = Field(default=False)
+    presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
+    frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
+    top_k: Optional[int] = Field(default=None, ge=-1)
+    top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)
 
 
 class TokenOutput(BaseModel):
@@ -330,6 +334,10 @@ class CompletionStreamV1Request(BaseModel):
     temperature: float = Field(..., ge=0.0)
     stop_sequences: Optional[List[str]] = Field(default=None)
     return_token_log_probs: Optional[bool] = Field(default=False)
+    presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
+    frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
+    top_k: Optional[int] = Field(default=None, ge=-1)
+    top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)
 
 
 class CompletionStreamOutput(BaseModel):
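The Field bounds above reject out-of-range values on the client before a request is ever sent. A rough sketch of the intended behavior, assuming prompt, max_new_tokens, and temperature are the only other required fields on CompletionSyncV1Request:

from pydantic import ValidationError
from llmengine.data_types import CompletionSyncV1Request

# Accepted: every value sits inside its declared range.
CompletionSyncV1Request(
    prompt="hello",
    max_new_tokens=10,
    temperature=0.5,
    presence_penalty=1.0,   # ge=0.0, le=2.0
    top_k=40,               # ge=-1
    top_p=0.95,             # gt=0.0, le=1.0
)

# Rejected: top_p is constrained to (0.0, 1.0], so 1.5 fails validation.
try:
    CompletionSyncV1Request(prompt="hello", max_new_tokens=10, temperature=0.5, top_p=1.5)
except ValidationError as err:
    print(err)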

model-engine/model_engine_server/common/dtos/llms.py

Lines changed: 38 additions & 2 deletions
@@ -104,7 +104,7 @@ class CompletionSyncV1Request(BaseModel):
 
     prompt: str
     max_new_tokens: int
-    temperature: float = Field(ge=0, le=1)
+    temperature: float = Field(ge=0.0, le=1.0)
     """
     Temperature of the sampling. Setting to 0 equals to greedy sampling.
     """
@@ -116,6 +116,24 @@ class CompletionSyncV1Request(BaseModel):
     """
     Whether to return the log probabilities of the tokens.
     """
+    presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
+    """
+    Only supported in vllm, lightllm
+    Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty
+    """
+    frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
+    """
+    Only supported in vllm, lightllm
+    Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty
+    """
+    top_k: Optional[int] = Field(default=None, ge=-1)
+    """
+    Controls the number of top tokens to consider. -1 means consider all tokens.
+    """
+    top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)
+    """
+    Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens.
+    """
 
 
 class TokenOutput(BaseModel):
@@ -145,7 +163,7 @@ class CompletionStreamV1Request(BaseModel):
 
     prompt: str
     max_new_tokens: int
-    temperature: float = Field(ge=0, le=1)
+    temperature: float = Field(ge=0.0, le=1.0)
     """
     Temperature of the sampling. Setting to 0 equals to greedy sampling.
     """
@@ -157,6 +175,24 @@ class CompletionStreamV1Request(BaseModel):
     """
     Whether to return the log probabilities of the tokens. Only affects behavior for text-generation-inference models
     """
+    presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
+    """
+    Only supported in vllm, lightllm
+    Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty
+    """
+    frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
+    """
+    Only supported in vllm, lightllm
+    Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty
+    """
+    top_k: Optional[int] = Field(default=None, ge=-1)
+    """
+    Controls the number of top tokens to consider. -1 means consider all tokens.
+    """
+    top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)
+    """
+    Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens.
+    """
 
 
 class CompletionStreamOutput(BaseModel):
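As a reminder of what these knobs mean at sampling time, here is a toy, framework-agnostic sketch of top-k and nucleus (top-p) filtering over a token distribution. It is purely illustrative and is not the code any of these backends actually run:

import numpy as np

def filter_top_k_top_p(probs: np.ndarray, top_k: int = -1, top_p: float = 1.0) -> np.ndarray:
    """Zero out tokens outside the top-k / top-p (nucleus) set, then renormalize."""
    order = np.argsort(probs)[::-1]            # token indices, most to least probable
    keep = np.ones_like(probs, dtype=bool)
    if top_k not in (-1, None):                # -1 / None: consider all tokens
        keep[order[top_k:]] = False
    if top_p is not None and top_p < 1.0:      # 1.0: consider all tokens
        cumulative = np.cumsum(probs[order])
        cutoff = np.searchsorted(cumulative, top_p) + 1   # smallest prefix covering top_p mass
        keep[order[cutoff:]] = False
    filtered = np.where(keep, probs, 0.0)
    return filtered / filtered.sum()

probs = np.array([0.5, 0.25, 0.15, 0.1])
print(filter_top_k_top_p(probs, top_k=2))    # -> [0.667, 0.333, 0.0, 0.0]
print(filter_top_k_top_p(probs, top_p=0.8))  # keeps the first three tokens (0.9 cumulative mass)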

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 93 additions & 1 deletion
@@ -8,7 +8,7 @@
 import math
 import os
 from dataclasses import asdict
-from typing import Any, AsyncIterable, Dict, List, Optional
+from typing import Any, AsyncIterable, Dict, List, Optional, Union
 from uuid import uuid4
 
 from model_engine_server.common.config import hmi_config
@@ -839,6 +839,54 @@ def deepspeed_result_to_tokens(result: Dict[str, Any]) -> List[TokenOutput]:
     return tokens
 
 
+def validate_and_update_completion_params(
+    inference_framework: LLMInferenceFramework,
+    request: Union[CompletionSyncV1Request, CompletionStreamV1Request],
+) -> Union[CompletionSyncV1Request, CompletionStreamV1Request]:
+    # top_k, top_p
+    if inference_framework in [
+        LLMInferenceFramework.TEXT_GENERATION_INFERENCE,
+        LLMInferenceFramework.VLLM,
+        LLMInferenceFramework.LIGHTLLM,
+    ]:
+        if request.temperature == 0:
+            if request.top_k not in [-1, None] or request.top_p not in [1.0, None]:
+                raise ObjectHasInvalidValueException(
+                    "top_k and top_p can't be enabled when temperature is 0."
+                )
+        if request.top_k == 0:
+            raise ObjectHasInvalidValueException(
+                "top_k needs to be strictly positive, or set it to be -1 / None to disable top_k."
+            )
+        if inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE:
+            request.top_k = None if request.top_k == -1 else request.top_k
+            request.top_p = None if request.top_p == 1.0 else request.top_p
+        if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]:
+            request.top_k = -1 if request.top_k is None else request.top_k
+            request.top_p = 1.0 if request.top_p is None else request.top_p
+    else:
+        if request.top_k or request.top_p:
+            raise ObjectHasInvalidValueException(
+                "top_k and top_p are only supported in text-generation-inference, vllm, lightllm."
+            )
+
+    # presence_penalty, frequency_penalty
+    if inference_framework in [LLMInferenceFramework.VLLM, LLMInferenceFramework.LIGHTLLM]:
+        request.presence_penalty = (
+            0.0 if request.presence_penalty is None else request.presence_penalty
+        )
+        request.frequency_penalty = (
+            0.0 if request.frequency_penalty is None else request.frequency_penalty
+        )
+    else:
+        if request.presence_penalty or request.frequency_penalty:
+            raise ObjectHasInvalidValueException(
+                "presence_penalty and frequency_penalty are only supported in vllm, lightllm."
+            )
+
+    return request
+
+
 class CompletionSyncV1UseCase:
     """
     Use case for running a prompt completion on an LLM endpoint.
@@ -983,6 +1031,15 @@ async def execute(
             endpoint_id=model_endpoint.record.id
         )
         endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint)
+        validated_request = validate_and_update_completion_params(
+            endpoint_content.inference_framework, request
+        )
+        if not isinstance(validated_request, CompletionSyncV1Request):
+            raise ValueError(
+                f"request has type {validated_request.__class__.__name__}, expected type CompletionSyncV1Request"
+            )
+        request = validated_request
+
         if endpoint_content.inference_framework == LLMInferenceFramework.DEEPSPEED:
             args: Any = {
                 "prompts": [request.prompt],
@@ -1036,6 +1093,10 @@ async def execute(
             if request.temperature > 0:
                 tgi_args["parameters"]["temperature"] = request.temperature
                 tgi_args["parameters"]["do_sample"] = True
+                tgi_args["parameters"]["top_k"] = request.top_k
+                tgi_args["parameters"]["top_p"] = request.top_p
+            else:
+                tgi_args["parameters"]["do_sample"] = False
 
             inference_request = SyncEndpointPredictV1Request(
                 args=tgi_args,
@@ -1064,10 +1125,15 @@ async def execute(
             vllm_args: Any = {
                 "prompt": request.prompt,
                 "max_tokens": request.max_new_tokens,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
             }
             if request.stop_sequences is not None:
                 vllm_args["stop"] = request.stop_sequences
             vllm_args["temperature"] = request.temperature
+            if request.temperature > 0:
+                vllm_args["top_k"] = request.top_k
+                vllm_args["top_p"] = request.top_p
             if request.return_token_log_probs:
                 vllm_args["logprobs"] = 1
 
@@ -1098,12 +1164,16 @@ async def execute(
                 "inputs": request.prompt,
                 "parameters": {
                     "max_new_tokens": request.max_new_tokens,
+                    "presence_penalty": request.presence_penalty,
+                    "frequency_penalty": request.frequency_penalty,
                 },
             }
             # TODO: implement stop sequences
             if request.temperature > 0:
                 lightllm_args["parameters"]["temperature"] = request.temperature
                 lightllm_args["parameters"]["do_sample"] = True
+                lightllm_args["top_k"] = request.top_k
+                lightllm_args["top_p"] = request.top_p
             else:
                 lightllm_args["parameters"]["do_sample"] = False
             if request.return_token_log_probs:
@@ -1172,6 +1242,7 @@ async def execute(
 
         request_id = str(uuid4())
         add_trace_request_id(request_id)
+
         model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints(
             owner=user.team_id, name=model_endpoint_name, order_by=None
         )
@@ -1209,6 +1280,14 @@ async def execute(
             )
 
         model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint)
+        validated_request = validate_and_update_completion_params(
+            model_content.inference_framework, request
+        )
+        if not isinstance(validated_request, CompletionStreamV1Request):
+            raise ValueError(
+                f"request has type {validated_request.__class__.__name__}, expected type CompletionStreamV1Request"
+            )
+        request = validated_request
 
         args: Any = None
         if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED:
@@ -1237,14 +1316,23 @@ async def execute(
             if request.temperature > 0:
                 args["parameters"]["temperature"] = request.temperature
                 args["parameters"]["do_sample"] = True
+                args["parameters"]["top_k"] = request.top_k
+                args["parameters"]["top_p"] = request.top_p
+            else:
+                args["parameters"]["do_sample"] = False
         elif model_content.inference_framework == LLMInferenceFramework.VLLM:
             args = {
                 "prompt": request.prompt,
                 "max_tokens": request.max_new_tokens,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
             }
             if request.stop_sequences is not None:
                 args["stop"] = request.stop_sequences
             args["temperature"] = request.temperature
+            if request.temperature > 0:
+                args["top_k"] = request.top_k
+                args["top_p"] = request.top_p
             if request.return_token_log_probs:
                 args["logprobs"] = 1
             args["stream"] = True
@@ -1253,12 +1341,16 @@ async def execute(
                 "inputs": request.prompt,
                 "parameters": {
                     "max_new_tokens": request.max_new_tokens,
+                    "presence_penalty": request.presence_penalty,
+                    "frequency_penalty": request.frequency_penalty,
                 },
             }
             # TODO: stop sequences
             if request.temperature > 0:
                 args["parameters"]["temperature"] = request.temperature
                 args["parameters"]["do_sample"] = True
+                args["parameters"]["top_k"] = request.top_k
+                args["parameters"]["top_p"] = request.top_p
             else:
                 args["parameters"]["do_sample"] = False
             if request.return_token_log_probs:
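To make the normalization in validate_and_update_completion_params concrete, a rough sketch of how unset sampling fields are treated per backend. Import paths and the construction of CompletionSyncV1Request are approximate; the assertions follow directly from the logic in the diff:

from model_engine_server.common.dtos.llms import CompletionSyncV1Request
from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import (
    LLMInferenceFramework,  # assumed to be importable from this module's namespace
    validate_and_update_completion_params,
)

# vLLM / LightLLM: unset fields are filled with the framework's "disabled" sentinels.
req = CompletionSyncV1Request(prompt="hi", max_new_tokens=8, temperature=0.7)
req = validate_and_update_completion_params(LLMInferenceFramework.VLLM, req)
assert req.top_k == -1 and req.top_p == 1.0
assert req.presence_penalty == 0.0 and req.frequency_penalty == 0.0

# text-generation-inference: -1 / 1.0 are translated back to None, and setting
# presence_penalty or frequency_penalty would raise ObjectHasInvalidValueException.
req = CompletionSyncV1Request(prompt="hi", max_new_tokens=8, temperature=0.7, top_k=-1, top_p=1.0)
req = validate_and_update_completion_params(
    LLMInferenceFramework.TEXT_GENERATION_INFERENCE, req
)
assert req.top_k is None and req.top_p is None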
