diff --git a/python/openai/README.md b/python/openai/README.md
index 51d8c5db72..60645bfa1c 100644
--- a/python/openai/README.md
+++ b/python/openai/README.md
@@ -293,6 +293,80 @@ See the [vLLM documentation](https://github.com/triton-inference-server/vllm_backend/blob/main/docs/llama_multi_lora_tutorial.md)
 on how to serve a model with LoRA adapters.
 
+### Guided Decoding
+
+The OpenAI frontend supports guided decoding to constrain model outputs to specific formats. Three types of guided decoding are available through the `guided_decoding_guide_type` parameter: `json`, `regex`, and `choice`.
+
+**JSON Schema Example:**
+```python
+from openai import OpenAI
+from pydantic import BaseModel
+from enum import Enum
+
+client = OpenAI(base_url="http://localhost:9000/v1", api_key="EMPTY")
+
+class CarType(str, Enum):
+    sedan = "sedan"
+    suv = "SUV"
+    truck = "Truck"
+
+class CarDescription(BaseModel):
+    brand: str
+    model: str
+    car_type: CarType
+
+json_schema = CarDescription.model_json_schema()
+
+completion = client.chat.completions.create(
+    model="llama-3.1-8b-instruct",  # vLLM model name
+    messages=[{"role": "user", "content": "Generate a 90's iconic car"}],
+    max_tokens=100,
+    extra_body={
+        "guided_decoding_guide_type": "json",
+        "guided_decoding_guide": json_schema,
+    },
+)
+print(completion.choices[0].message.content)
+# Output: { "brand": "Chevrolet", "model": "Silverado", "car_type": "SUV" }
+```
+
+**Regex Pattern Example:**
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:9000/v1", api_key="EMPTY")
+
+completion = client.chat.completions.create(
+    model="llama-3.1-8b-instruct",  # vLLM model name
+    messages=[{"role": "user", "content": "Generate an example email address for Alan Turing, who works in Enigma. End in .com"}],
+    extra_body={
+        "guided_decoding_guide_type": "regex",
+        "guided_decoding_guide": "\\w+@\\w+\\.com",
+    },
+)
+print(completion.choices[0].message.content)
+# Output: alan777@enigma.com
+```
+
+**Choice-based Example:**
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:9000/v1", api_key="EMPTY")
+
+completion = client.chat.completions.create(
+    model="llama-3.1-8b-instruct",
+    messages=[{"role": "user", "content": "What's the sentiment: 'I love this!'"}],
+    extra_body={
+        "guided_decoding_guide_type": "choice",
+        "guided_decoding_guide": ["positive", "negative", "neutral"],
+    },
+)
+print(completion.choices[0].message.content)
+# Output: positive
+```
+
 ## TensorRT-LLM
 
 0. Prepare your model repository for a TensorRT-LLM model, build the engine, etc. You can try any of the following options:
@@ -373,6 +447,86 @@ curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/
 
 The other examples should be the same as vLLM, except that you should set
 `MODEL="tensorrt_llm_bls"` or `MODEL="ensemble"`, everywhere applicable as seen in the example request above.
 
+### Guided Decoding
+
+The OpenAI frontend supports guided decoding to constrain model outputs to specific formats. Four types of guided decoding are available through the `guided_decoding_guide_type` parameter: `json_schema`, `json`, `regex`, and `ebnf_grammar`.
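+
+**JSON (schema-free) Example:**
+
+The `json` type constrains output to syntactically valid JSON without enforcing a particular schema. A minimal sketch, assuming the backend accepts an empty guide string for this type; the schema-bound `json_schema` example follows below.
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:9000/v1", api_key="EMPTY")
+
+completion = client.chat.completions.create(
+    model="ensemble",  # or "tensorrt_llm_bls"
+    messages=[{"role": "user", "content": "Describe a 90's iconic car as a JSON object"}],
+    max_tokens=100,
+    extra_body={
+        "guided_decoding_guide_type": "json",
+        # Assumption: the "json" type needs no schema, so an empty guide is sent
+        "guided_decoding_guide": "",
+    },
+)
+print(completion.choices[0].message.content)
+```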
+
+**JSON Schema Example:**
+```python
+from openai import OpenAI
+from pydantic import BaseModel
+from enum import Enum
+import json
+
+client = OpenAI(base_url="http://localhost:9000/v1", api_key="EMPTY")
+
+class CarType(str, Enum):
+    sedan = "sedan"
+    suv = "SUV"
+    truck = "Truck"
+
+class CarDescription(BaseModel):
+    brand: str
+    model: str
+    car_type: CarType
+
+json_schema = CarDescription.model_json_schema()
+
+completion = client.chat.completions.create(
+    model="ensemble",  # or "tensorrt_llm_bls"
+    messages=[{"role": "user", "content": "Generate a 90's iconic car"}],
+    max_tokens=100,
+    extra_body={
+        "guided_decoding_guide_type": "json_schema",
+        "guided_decoding_guide": json.dumps(json_schema),
+    },
+)
+print(completion.choices[0].message.content)
+# Output: { "brand": "Ford", "model": "Mustang", "car_type": "SUV" }
+```
+
+**Regex Pattern Example:**
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:9000/v1", api_key="EMPTY")
+
+completion = client.chat.completions.create(
+    model="ensemble",  # or "tensorrt_llm_bls"
+    messages=[{"role": "user", "content": "Generate an example email address for Alan Turing, who works in Enigma. End in .com"}],
+    extra_body={
+        "guided_decoding_guide_type": "regex",
+        "guided_decoding_guide": "\\w+@\\w+\\.com",
+    },
+)
+print(completion.choices[0].message.content)
+# Output: alan_turing@enigma.com
+```
+
+**EBNF Grammar Example:**
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:9000/v1", api_key="EMPTY")
+
+ebnf_grammar_str = """root ::= (expr "=" term)+
+expr ::= term ([-+*/] term)*
+term ::= num | "(" expr ")"
+num ::= [0-9]+"""
+
+completion = client.chat.completions.create(
+    model="ensemble",  # or "tensorrt_llm_bls"
+    messages=[{"role": "user", "content": "Generate mathematical equations with basic arithmetic operations"}],
+    max_tokens=100,
+    extra_body={
+        "guided_decoding_guide_type": "ebnf_grammar",
+        "guided_decoding_guide": ebnf_grammar_str,
+    },
+)
+print(completion.choices[0].message.content)
+# Output: 10+5=15
+```
+
 ## KServe Frontends
 
 To support serving requests through both the OpenAI-Compatible and
diff --git a/python/openai/openai_frontend/engine/utils/triton.py b/python/openai/openai_frontend/engine/utils/triton.py
index 53f1af44d6..ffdbdb8754 100644
--- a/python/openai/openai_frontend/engine/utils/triton.py
+++ b/python/openai/openai_frontend/engine/utils/triton.py
@@ -89,6 +89,23 @@ def _create_vllm_inference_request(
         )
         sampling_parameters = json.dumps(sampling_parameters_json)
+    elif request.guided_decoding_guide_type is not None:
+        from vllm.sampling_params import GuidedDecodingParams
+
+        # Map the request's guide type/guide onto vLLM's GuidedDecodingParams,
+        # then drop the frontend-only fields from the sampling parameters.
+        sampling_parameters_json = json.loads(sampling_parameters)
+        sampling_parameters_json["guided_decoding"] = json.dumps(
+            asdict(
+                GuidedDecodingParams.from_optional(
+                    **{
+                        request.guided_decoding_guide_type: request.guided_decoding_guide
+                    }
+                )
+            )
+        )
+        sampling_parameters_json.pop("guided_decoding_guide_type", None)
+        sampling_parameters_json.pop("guided_decoding_guide", None)
+        sampling_parameters = json.dumps(sampling_parameters_json)
+
     exclude_input_in_output = True
     echo = getattr(request, "echo", None)
     if echo is not None:
@@ -137,6 +154,9 @@ def _create_trtllm_inference_request(
     if guided_json is not None:
         inputs["guided_decoding_guide_type"] = [["json_schema"]]
         inputs["guided_decoding_guide"] = [[guided_json]]
+    elif request.guided_decoding_guide_type is not None:
+        inputs["guided_decoding_guide_type"] = [[request.guided_decoding_guide_type]]
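+        # Forward the guide unchanged; for "json_schema" this should be a
+        # JSON-serialized schema string (see the README examples above)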
inputs["guided_decoding_guide"] = [[request.guided_decoding_guide]] # FIXME: TRT-LLM doesn't currently support runtime changes of 'echo' and it # is configured at model load time, so we don't handle it here for now. return model.create_request(inputs=inputs) diff --git a/python/openai/openai_frontend/schemas/openai.py b/python/openai/openai_frontend/schemas/openai.py index a2438e8394..cfb6e001fa 100644 --- a/python/openai/openai_frontend/schemas/openai.py +++ b/python/openai/openai_frontend/schemas/openai.py @@ -153,6 +153,14 @@ class CreateCompletionRequest(BaseModel): description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", examples=["user-1234"], ) + guided_decoding_guide_type: Optional[str] = Field( + None, + description="The type of guided decoding to use.\n", + ) + guided_decoding_guide: Optional[str] = Field( + None, + description="The guide to use for guided decoding.\n", + ) class FinishReason(Enum): @@ -917,6 +925,14 @@ class CreateChatCompletionRequest(BaseModel): max_length=128, min_length=1, ) + guided_decoding_guide_type: Optional[str] = Field( + None, + description="The type of guided decoding to use.\n", + ) + guided_decoding_guide: Optional[Union[str, List[str], Dict[str, Any]]] = Field( + None, + description="The guide to use for guided decoding.\n", + ) # Additional Aliases for Convenience