Commit 79dc5fa

Guided decoding (#476)

* Guided decoding
* endpoints
* fix
* update client
* unit tests
* fix test
* coverage
* coverage
* fix
* try to bump coverage
* more tests!
* lint
1 parent 44fe4e8 commit 79dc5fa

12 files changed: +474 -7 lines changed

clients/python/llmengine/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.0.0b27"
+__version__ = "0.0.0b28"
 
 import os
 from typing import Sequence

clients/python/llmengine/completion.py

Lines changed: 41 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import AsyncIterable, Iterator, List, Optional, Union
+from typing import Any, AsyncIterable, Dict, Iterator, List, Optional, Union
 
 from llmengine.api_engine import APIEngine
 from llmengine.data_types import (
@@ -43,6 +43,10 @@ async def acreate(
         frequency_penalty: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        include_stop_str_in_output: Optional[bool] = False,
+        guided_json: Optional[Dict[str, Any]] = None,
+        guided_regex: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
         timeout: int = COMPLETION_TIMEOUT,
         stream: bool = False,
     ) -> Union[CompletionSyncResponse, AsyncIterable[CompletionStreamResponse]]:
@@ -102,6 +106,18 @@ async def acreate(
                 Float that controls the cumulative probability of the top tokens to consider.
                 Range: (0.0, 1.0]. 1.0 means consider all tokens.
 
+            include_stop_str_in_output (Optional[bool]):
+                Whether to include the stop sequence in the output. Defaults to False.
+
+            guided_json (Optional[Dict[str, Any]]):
+                If specified, the output will follow the JSON schema. For examples, see https://json-schema.org/learn/miscellaneous-examples.
+
+            guided_regex (Optional[str]):
+                If specified, the output will follow the regex pattern.
+
+            guided_choice (Optional[List[str]]):
+                If specified, the output will be exactly one of the choices.
+
             timeout (int):
                 Timeout in seconds. This is the maximum amount of time you are willing to wait for a response.
 
@@ -198,6 +214,10 @@ async def _acreate_stream(
                 frequency_penalty=frequency_penalty,
                 top_k=top_k,
                 top_p=top_p,
+                include_stop_str_in_output=include_stop_str_in_output,
+                guided_json=guided_json,
+                guided_regex=guided_regex,
+                guided_choice=guided_choice,
                 timeout=timeout,
             )
 
@@ -237,6 +257,10 @@ def create(
         frequency_penalty: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        include_stop_str_in_output: Optional[bool] = False,
+        guided_json: Optional[Dict[str, Any]] = None,
+        guided_regex: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
         timeout: int = COMPLETION_TIMEOUT,
         stream: bool = False,
     ) -> Union[CompletionSyncResponse, Iterator[CompletionStreamResponse]]:
@@ -297,6 +321,18 @@ def create(
                 Float that controls the cumulative probability of the top tokens to consider.
                 Range: (0.0, 1.0]. 1.0 means consider all tokens.
 
+            include_stop_str_in_output (Optional[bool]):
+                Whether to include the stop sequence in the output. Defaults to False.
+
+            guided_json (Optional[Dict[str, Any]]):
+                If specified, the output will follow the JSON schema.
+
+            guided_regex (Optional[str]):
+                If specified, the output will follow the regex pattern.
+
+            guided_choice (Optional[List[str]]):
+                If specified, the output will be exactly one of the choices.
+
             timeout (int):
                 Timeout in seconds. This is the maximum amount of time you are willing to wait for a response.
 
@@ -396,6 +432,10 @@ def _create_stream(**kwargs):
                 frequency_penalty=frequency_penalty,
                 top_k=top_k,
                 top_p=top_p,
+                include_stop_str_in_output=include_stop_str_in_output,
+                guided_json=guided_json,
+                guided_regex=guided_regex,
+                guided_choice=guided_choice,
             ).dict()
             response = cls.post_sync(
                 resource_name=f"v1/llm/completions-sync?model_endpoint_name={model}",
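
For reference, a minimal client-side sketch using the new parameters (the model name and prompt are placeholders, and only one of guided_json, guided_regex, or guided_choice may be set per request, as the server-side validation further down enforces):

```python
from llmengine import Completion

# Hypothetical deployed vLLM endpoint name; substitute your own.
response = Completion.create(
    model="llama-2-7b",
    prompt="Give me a JSON object describing a person:",
    max_new_tokens=64,
    temperature=0.2,
    # Constrain the completion to match this JSON schema.
    guided_json={
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    },
)
print(response.json())
```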

clients/python/llmengine/data_types.py

Lines changed: 8 additions & 0 deletions
@@ -279,6 +279,10 @@ class CompletionSyncV1Request(BaseModel):
     frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
     top_k: Optional[int] = Field(default=None, ge=-1)
     top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)
+    include_stop_str_in_output: Optional[bool] = Field(default=False)
+    guided_json: Optional[Dict[str, Any]] = Field(default=None)
+    guided_regex: Optional[str] = Field(default=None)
+    guided_choice: Optional[List[str]] = Field(default=None)
 
 
 class TokenOutput(BaseModel):
@@ -349,6 +353,10 @@ class CompletionStreamV1Request(BaseModel):
     frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
     top_k: Optional[int] = Field(default=None, ge=-1)
     top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0)
+    include_stop_str_in_output: Optional[bool] = Field(default=False)
+    guided_json: Optional[Dict[str, Any]] = Field(default=None)
+    guided_regex: Optional[str] = Field(default=None)
+    guided_choice: Optional[List[str]] = Field(default=None)
 
 
 class CompletionStreamOutput(BaseModel):
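
As a standalone illustration of the field semantics above (this mirrors only the newly added fields, not the full request models), the defaults behave as follows:

```python
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field


class GuidedDecodingFields(BaseModel):
    # Mirrors just the fields added in this commit, for illustration.
    include_stop_str_in_output: Optional[bool] = Field(default=False)
    guided_json: Optional[Dict[str, Any]] = Field(default=None)
    guided_regex: Optional[str] = Field(default=None)
    guided_choice: Optional[List[str]] = Field(default=None)


fields = GuidedDecodingFields(guided_choice=["Sean", "Brian", "Tim"])
assert fields.include_stop_str_in_output is False
assert fields.guided_json is None and fields.guided_regex is None
```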

clients/python/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "scale-llm-engine"
-version = "0.0.0.beta27"
+version = "0.0.0.beta28"
 description = "Scale LLM Engine Python client"
 license = "Apache-2.0"
 authors = ["Phil Chen <[email protected]>"]

clients/python/setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,6 +3,6 @@
 setup(
     name="scale-llm-engine",
     python_requires=">=3.7",
-    version="0.0.0.beta27",
+    version="0.0.0.beta28",
     packages=find_packages(),
 )

docs/guides/completions.md

Lines changed: 53 additions & 0 deletions
@@ -193,6 +193,59 @@ response = Completion.batch_create(
 print(response.json())
 ```
 
+## Guided decoding
+
+Guided decoding is supported by vLLM and backed by [Outlines](https://github.com/outlines-dev/outlines).
+It constrains token generation to a given pattern by modifying the sampling logits.
+
+=== "Guided decoding with regex"
+    ```python
+    from llmengine import Completion
+
+    response = Completion.create(
+        model="llama-2-7b",
+        prompt="Hello, my name is",
+        max_new_tokens=10,
+        temperature=0.2,
+        guided_regex="Sean.*",
+    )
+
+    print(response.json())
+    # {"request_id":"c19f0fae-317e-4f69-8e06-c04189299b9c","output":{"text":"Sean. I'm a 2","num_prompt_tokens":6,"num_completion_tokens":10,"tokens":null}}
+    ```
+
+=== "Guided decoding with choice"
+    ```python
+    from llmengine import Completion
+
+    response = Completion.create(
+        model="llama-2-7b",
+        prompt="Hello, my name is",
+        max_new_tokens=10,
+        temperature=0.2,
+        guided_choice=["Sean", "Brian", "Tim"],
+    )
+
+    print(response.json())
+    # {"request_id":"641e2af3-a3e3-4493-98b9-d38115ba0d22","output":{"text":"Sean","num_prompt_tokens":6,"num_completion_tokens":4,"tokens":null}}
+    ```
+
+=== "Guided decoding with JSON schema"
+    ```python
+    from llmengine import Completion
+
+    response = Completion.create(
+        model="llama-2-7b",
+        prompt="Hello, my name is",
+        max_new_tokens=10,
+        temperature=0.2,
+        guided_json={"properties":{"myString":{"type":"string"}},"required":["myString"]},
+    )
+
+    print(response.json())
+    # {"request_id":"5b184654-96b6-4932-9eb6-382a51fdb3d5","output":{"text":"{\"myString\" : \"John Doe","num_prompt_tokens":6,"num_completion_tokens":10,"tokens":null}}
+    ```
+
 ## Which model should I use?
 
 See the [Model Zoo](../../model_zoo) for more information on best practices for which model to use for Completions.

model-engine/model_engine_server/common/dtos/llms.py

Lines changed: 24 additions & 0 deletions
@@ -184,6 +184,18 @@ class CompletionSyncV1Request(BaseModel):
     """
     Whether to include the stop strings in output text.
     """
+    guided_json: Optional[Dict[str, Any]] = None
+    """
+    JSON schema for guided decoding.
+    """
+    guided_regex: Optional[str] = None
+    """
+    Regex for guided decoding.
+    """
+    guided_choice: Optional[List[str]] = None
+    """
+    Choices for guided decoding.
+    """
 
 
 class TokenOutput(BaseModel):
@@ -248,6 +260,18 @@ class CompletionStreamV1Request(BaseModel):
     """
     Whether to include the stop strings in output text.
     """
+    guided_json: Optional[Dict[str, Any]] = None
+    """
+    JSON schema for guided decoding. Only supported in vllm.
+    """
+    guided_regex: Optional[str] = None
+    """
+    Regex for guided decoding. Only supported in vllm.
+    """
+    guided_choice: Optional[List[str]] = None
+    """
+    Choices for guided decoding. Only supported in vllm.
+    """
 
 
 class CompletionStreamOutput(BaseModel):

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 32 additions & 0 deletions
@@ -1365,6 +1365,26 @@ def validate_and_update_completion_params(
            "include_stop_str_in_output is only supported in vllm."
        )
 
+    guided_count = 0
+    if request.guided_choice is not None:
+        guided_count += 1
+    if request.guided_json is not None:
+        guided_count += 1
+    if request.guided_regex is not None:
+        guided_count += 1
+
+    if guided_count > 1:
+        raise ObjectHasInvalidValueException(
+            "Only one of guided_json, guided_choice, guided_regex can be enabled."
+        )
+
+    if (
+        request.guided_choice is not None
+        or request.guided_regex is not None
+        or request.guided_json is not None
+    ) and not inference_framework == LLMInferenceFramework.VLLM:
+        raise ObjectHasInvalidValueException("Guided decoding is only supported in vllm.")
+
     return request
 
 
@@ -1656,6 +1676,12 @@ async def execute(
                vllm_args["logprobs"] = 1
            if request.include_stop_str_in_output is not None:
                vllm_args["include_stop_str_in_output"] = request.include_stop_str_in_output
+           if request.guided_choice is not None:
+               vllm_args["guided_choice"] = request.guided_choice
+           if request.guided_regex is not None:
+               vllm_args["guided_regex"] = request.guided_regex
+           if request.guided_json is not None:
+               vllm_args["guided_json"] = request.guided_json
 
            inference_request = SyncEndpointPredictV1Request(
                args=vllm_args,
@@ -1918,6 +1944,12 @@ async def execute(
                args["logprobs"] = 1
            if request.include_stop_str_in_output is not None:
                args["include_stop_str_in_output"] = request.include_stop_str_in_output
+           if request.guided_choice is not None:
+               args["guided_choice"] = request.guided_choice
+           if request.guided_regex is not None:
+               args["guided_regex"] = request.guided_regex
+           if request.guided_json is not None:
+               args["guided_json"] = request.guided_json
            args["stream"] = True
        elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM:
            args = {
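
A minimal sketch of the validation rule introduced above, pulled out of its surrounding context (the request object, exception type, and framework string here are stand-ins for the real ones in model-engine):

```python
from typing import Any, Dict, List, Optional


class ObjectHasInvalidValueException(ValueError):
    """Stand-in for model-engine's domain exception of the same name."""


def check_guided_params(
    guided_json: Optional[Dict[str, Any]],
    guided_regex: Optional[str],
    guided_choice: Optional[List[str]],
    inference_framework: str,
) -> None:
    # At most one guided-decoding mode may be set on a single request.
    enabled = [p for p in (guided_json, guided_regex, guided_choice) if p is not None]
    if len(enabled) > 1:
        raise ObjectHasInvalidValueException(
            "Only one of guided_json, guided_choice, guided_regex can be enabled."
        )
    # Guided decoding is only wired up for the vLLM inference framework.
    if enabled and inference_framework != "vllm":
        raise ObjectHasInvalidValueException("Guided decoding is only supported in vllm.")


check_guided_params(None, "Sean.*", None, "vllm")  # passes
try:
    check_guided_params({"type": "object"}, None, ["Sean", "Brian"], "vllm")
except ObjectHasInvalidValueException as e:
    print(e)  # Only one of guided_json, guided_choice, guided_regex can be enabled.
```
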
Lines changed: 1 addition & 2 deletions
@@ -1,3 +1,2 @@
-ray>=2.9
-vllm==0.3.2
+vllm==0.3.3
 pydantic>=2.0

model-engine/model_engine_server/inference/vllm/vllm_server.py

Lines changed: 31 additions & 1 deletion
@@ -7,10 +7,12 @@
 from typing import AsyncGenerator
 
 import uvicorn
-from fastapi import BackgroundTasks, FastAPI, Request
+from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
 from fastapi.responses import Response, StreamingResponse
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest
+from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
 
@@ -38,7 +40,35 @@ async def generate(request: Request) -> Response:
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
     stream = request_dict.pop("stream", False)
+    guided_json = request_dict.pop("guided_json", None)
+    guided_regex = request_dict.pop("guided_regex", None)
+    guided_choice = request_dict.pop("guided_choice", None)
     sampling_params = SamplingParams(**request_dict)
+
+    # Dummy request to get guided decode logit processor
+    try:
+        partial_openai_request = OpenAICompletionRequest.model_validate(
+            {
+                "model": "",
+                "prompt": "",
+                "guided_json": guided_json,
+                "guided_regex": guided_regex,
+                "guided_choice": guided_choice,
+            }
+        )
+    except Exception:
+        raise HTTPException(
+            status_code=400, detail="Bad request: failed to parse guided decoding parameters."
+        )
+
+    guided_decode_logit_processor = await get_guided_decoding_logits_processor(
+        partial_openai_request, engine.get_tokenizer()
+    )
+    if guided_decode_logit_processor is not None:
+        if sampling_params.logits_processors is None:
+            sampling_params.logits_processors = []
+        sampling_params.logits_processors.append(guided_decode_logit_processor)
+
     request_id = random_uuid()
     results_generator = engine.generate(prompt, sampling_params, request_id)
 
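
The docs above describe guided decoding as constraining the sampling logits. As a rough illustration of that mechanism (not Outlines' actual implementation, which recomputes the allowed tokens from an FSM or JSON-schema state at every step), a vLLM-style logits processor is a callable that takes the already-generated token ids and the next-token logits and returns adjusted logits. Assuming that callable signature and a torch dependency, a toy processor that only ever allows a fixed set of token ids looks like this:

```python
from typing import List

import torch


def make_allowed_token_processor(allowed_token_ids: List[int]):
    """Toy logits processor: mask every token outside a fixed allowed set.

    Real guided decoding updates the allowed set each step based on the
    regex / JSON-schema / choice constraint; this fixed mask only sketches
    how a processor biases sampling.
    """

    def processor(generated_token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        mask = torch.full_like(logits, float("-inf"))
        mask[allowed_token_ids] = 0.0
        return logits + mask

    return processor


# Hypothetical usage, mirroring how the server above appends the
# guided-decoding processor to SamplingParams:
#   sampling_params.logits_processors = [make_allowed_token_processor([42, 1337])]
```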