Skip to content

Commit d27f2df

Browse files
authored
guided decoding with grammar (#488)
* support guided decoding with grammar
* 0.4.1 fixes
1 parent b7284df commit d27f2df

File tree

6 files changed

+42
-6
lines changed

6 files changed

+42
-6
lines changed

docs/guides/completions.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,22 @@ print(response.json())
246246
# {"request_id":"5b184654-96b6-4932-9eb6-382a51fdb3d5","output":{"text":"{\"myString\" : \"John Doe","num_prompt_tokens":6,"num_completion_tokens":10,"tokens":null}}
247247
```
248248

249+
=== "Guided decoding with Context-Free Grammar"
250+
251+
```python
252+
from llmengine import Completion
253+
254+
response = Completion.create(
255+
model="llama-2-7b",
256+
prompt="Hello, my name is",
257+
max_new_tokens=10,
258+
temperature=0.2,
259+
guided_grammar="start: \"John\""
260+
)
261+
262+
print(response.json())
263+
# {"request_id": "34621b44-c655-402c-a459-f108b3e49b12", "output": {"text": "John", "num_prompt_tokens": 6, "num_completion_tokens": 4, "tokens": None}}
264+
249265
## Which model should I use?
250266

251267
See the [Model Zoo](../../model_zoo) for more information on best practices for which model to use for Completions.

model-engine/model_engine_server/common/dtos/llms.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,15 +186,19 @@ class CompletionSyncV1Request(BaseModel):
186186
"""
187187
guided_json: Optional[Dict[str, Any]] = None
188188
"""
189-
JSON schema for guided decoding.
189+
JSON schema for guided decoding. Only supported in vllm.
190190
"""
191191
guided_regex: Optional[str] = None
192192
"""
193-
Regex for guided decoding.
193+
Regex for guided decoding. Only supported in vllm.
194194
"""
195195
guided_choice: Optional[List[str]] = None
196196
"""
197-
Choices for guided decoding.
197+
Choices for guided decoding. Only supported in vllm.
198+
"""
199+
guided_grammar: Optional[str] = None
200+
"""
201+
Context-free grammar for guided decoding. Only supported in vllm.
198202
"""
199203

200204

@@ -272,6 +276,10 @@ class CompletionStreamV1Request(BaseModel):
272276
"""
273277
Choices for guided decoding. Only supported in vllm.
274278
"""
279+
guided_grammar: Optional[str] = None
280+
"""
281+
Context-free grammar for guided decoding. Only supported in vllm.
282+
"""
275283

276284

277285
class CompletionStreamOutput(BaseModel):

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1381,16 +1381,19 @@ def validate_and_update_completion_params(
13811381
guided_count += 1
13821382
if request.guided_regex is not None:
13831383
guided_count += 1
1384+
if request.guided_grammar is not None:
1385+
guided_count += 1
13841386

13851387
if guided_count > 1:
13861388
raise ObjectHasInvalidValueException(
1387-
"Only one of guided_json, guided_choice, guided_regex can be enabled."
1389+
"Only one of guided_json, guided_choice, guided_regex, guided_grammar can be enabled."
13881390
)
13891391

13901392
if (
13911393
request.guided_choice is not None
13921394
or request.guided_regex is not None
13931395
or request.guided_json is not None
1396+
or request.guided_grammar is not None
13941397
) and not inference_framework == LLMInferenceFramework.VLLM:
13951398
raise ObjectHasInvalidValueException("Guided decoding is only supported in vllm.")
13961399

@@ -1691,6 +1694,8 @@ async def execute(
16911694
vllm_args["guided_regex"] = request.guided_regex
16921695
if request.guided_json is not None:
16931696
vllm_args["guided_json"] = request.guided_json
1697+
if request.guided_grammar is not None:
1698+
vllm_args["guided_grammar"] = request.guided_grammar
16941699

16951700
inference_request = SyncEndpointPredictV1Request(
16961701
args=vllm_args,
@@ -1959,6 +1964,8 @@ async def execute(
19591964
args["guided_regex"] = request.guided_regex
19601965
if request.guided_json is not None:
19611966
args["guided_json"] = request.guided_json
1967+
if request.guided_grammar is not None:
1968+
args["guided_grammar"] = request.guided_grammar
19621969
args["stream"] = True
19631970
elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM:
19641971
args = {
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
vllm==0.4.0.post1
1+
vllm==0.4.1
22
pydantic>=2.0

model-engine/model_engine_server/inference/vllm/vllm_server.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ async def generate(request: Request) -> Response:
4545
guided_json = request_dict.pop("guided_json", None)
4646
guided_regex = request_dict.pop("guided_regex", None)
4747
guided_choice = request_dict.pop("guided_choice", None)
48+
guided_grammar = request_dict.pop("guided_grammar", None)
4849
sampling_params = SamplingParams(**request_dict)
4950

5051
# Dummy request to get guided decode logit processor
@@ -56,15 +57,17 @@ async def generate(request: Request) -> Response:
5657
"guided_json": guided_json,
5758
"guided_regex": guided_regex,
5859
"guided_choice": guided_choice,
60+
"guided_grammar": guided_grammar,
5961
}
6062
)
6163
except Exception:
6264
raise HTTPException(
6365
status_code=400, detail="Bad request: failed to parse guided decoding parameters."
6466
)
6567

68+
guided_decoding_backend = engine.engine.decoding_config.guided_decoding_backend
6669
guided_decode_logit_processor = await get_guided_decoding_logits_processor(
67-
partial_openai_request, engine.get_tokenizer()
70+
guided_decoding_backend, partial_openai_request, await engine.get_tokenizer()
6871
)
6972
if guided_decode_logit_processor is not None:
7073
if sampling_params.logits_processors is None:

model-engine/tests/unit/domain/test_llm_use_cases.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,11 +1108,13 @@ async def test_validate_and_update_completion_params():
11081108
completion_sync_request.guided_regex = ""
11091109
completion_sync_request.guided_json = {}
11101110
completion_sync_request.guided_choice = [""]
1111+
completion_sync_request.guided_grammar = ""
11111112
with pytest.raises(ObjectHasInvalidValueException):
11121113
validate_and_update_completion_params(LLMInferenceFramework.VLLM, completion_sync_request)
11131114

11141115
completion_sync_request.guided_regex = None
11151116
completion_sync_request.guided_choice = None
1117+
completion_sync_request.guided_grammar = None
11161118
with pytest.raises(ObjectHasInvalidValueException):
11171119
validate_and_update_completion_params(
11181120
LLMInferenceFramework.TEXT_GENERATION_INFERENCE, completion_sync_request

0 commit comments

Comments
 (0)