Commit 292dfd9
feat: add tool calling to the OpenAI frontend (#8134)

Add tool-calling parser implementations to the OpenAI frontend; the available parsers are `llama3` and `mistral`, and most of the implementation comes from vLLM. Users can pass the `--tool-call-parser` argument to select a parser, and the `--chat-template {chat template file path}` argument to supply a customized template that better tunes the prompt for tool calling. Guided-decoding backend integration with tool calling enables the named and required tool-calling functionality. See the changes to README.md for more detail. All changes in python/openai/openai_frontend/engine/utils/tool_call_parsers are taken from vLLM with minor compatibility changes.
1 parent 5eb09ce commit 292dfd9

24 files changed (+2495 -87 lines)

.gitignore

Lines changed: 2 additions & 0 deletions

```diff
@@ -10,6 +10,8 @@ test_results.txt
 artifacts
 cprofile
 *.prof
+.venv
+**/.venv

 # Test exclusions
 qa/L0_openai/openai
```

python/openai/README.md

Lines changed: 250 additions & 0 deletions
- Set the following environment variable: `export TRTLLM_ORCHESTRATOR=1`
- [ ] TensorRT-LLM ([Leader Mode](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/README.md#leader-mode))
  - Not currently supported
## Tool Calling

The OpenAI frontend supports `tools` and `tool_choice` in the `v1/chat/completions` API. Refer to the OpenAI API reference for more details about these parameters:
[tools](https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools),
[tool_choice](https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice)

To enable tool calling, add the `--tool-call-parser {parser_name}` flag when starting the server. The two available parsers are `llama3` and `mistral`: the `llama3` parser supports tool calling for the LLaMA 3.1, 3.2, and 3.3 models, while the `mistral` parser supports tool calling for the Mistral Instruct models.

Example of launching the OpenAI frontend with a tool call parser:
```bash
python3 openai_frontend/main.py \
    --model-repository tests/vllm_models \
    --tokenizer meta-llama/Meta-Llama-3.1-8B-Instruct \
    --tool-call-parser llama3
```
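
The commit also introduces a `--chat-template` argument that points the frontend at a customized chat template file, which can help tune the prompt for tool calling. A hypothetical invocation; the flag comes from the commit message, while the template path `./tool_chat_template_llama3.jinja` is purely illustrative:

```bash
python3 openai_frontend/main.py \
    --model-repository tests/vllm_models \
    --tokenizer meta-llama/Meta-Llama-3.1-8B-Instruct \
    --tool-call-parser llama3 \
    --chat-template ./tool_chat_template_llama3.jinja
```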

Example of making a tool-calling request:

```python
import json

from openai import OpenAI


def get_current_weather(city: str, state: str, unit: str):
    # Stubbed tool implementation used to demonstrate the round trip.
    return (
        "The weather in Dallas, Texas is 85 degrees Fahrenheit. It is "
        "partly cloudy, with highs in the 90's."
    )


available_tools = {"get_current_weather": get_current_weather}

openai_api_key = "EMPTY"
openai_api_base = "http://localhost:9000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

model = "llama-3.1-8b-instruct"  # change this to the model in the repository

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. 'San Francisco'",
                    },
                    "state": {
                        "type": "string",
                        "description": "The two-letter abbreviation for the state that the city is"
                        " in, e.g. 'CA' which would mean 'California'",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    }
]

messages = [
    {
        "role": "system",
        "content": "You're a helpful assistant! Answer the user's questions as best you can.",
    },
    {"role": "user", "content": "What is the weather in Dallas, Texas in Fahrenheit?"},
]

response = client.chat.completions.create(
    messages=messages, model=model, tools=tools, max_tokens=128
)
function_name = response.choices[0].message.tool_calls[0].function.name
function_arguments = response.choices[0].message.tool_calls[0].function.arguments

print(f"function name: {function_name}")
print(f"function arguments: {function_arguments}")
print(f"tool calling result: {available_tools[function_name](**json.loads(function_arguments))}")
```

Example output:
```
function name: get_current_weather
function arguments: {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
tool calling result: The weather in Dallas, Texas is 85 degrees Fahrenheit. It is partly cloudy, with highs in the 90's.
```
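
Because the frontend exposes a standard OpenAI-compatible HTTP endpoint, the same request can also be issued without the client library. A hypothetical `curl` sketch of the request above, assuming the server from the launch example is listening on port 9000 (the property descriptions are omitted here for brevity):

```bash
curl -s http://localhost:9000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "llama-3.1-8b-instruct",
    "messages": [
      {"role": "system", "content": "You are a helpful assistant!"},
      {"role": "user", "content": "What is the weather in Dallas, Texas in Fahrenheit?"}
    ],
    "tools": [
      {
        "type": "function",
        "function": {
          "name": "get_current_weather",
          "description": "Get the current weather in a given location",
          "parameters": {
            "type": "object",
            "properties": {
              "city": {"type": "string"},
              "state": {"type": "string"},
              "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}
            },
            "required": ["city", "state", "unit"]
          }
        }
      }
    ],
    "max_tokens": 128
  }'
```

As in the Python example, the tool call appears under `choices[0].message.tool_calls` in the JSON response.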

<!-- TODO: Remove this warning when the openai api supports the max_completion_tokens instead of max_tokens -->
> [!WARNING]
> When using LangChain to call the `v1/chat/completions` endpoint, you might encounter an exception related to `max_completion_tokens` if you have specified `max_tokens` in the request.
>
> Example: `openai.BadRequestError: Error code: 400 - {'object': 'error', 'message': "[{'type': 'extra_forbidden', 'loc': ('body', 'max_completion_tokens'), 'msg': 'Extra inputs are not permitted', 'input': 800}]", 'type': 'BadRequestError', 'param': None, 'code': 400}`
>
> This issue is due to an incompatibility between Triton's OpenAI API frontend and the latest OpenAI API. We are actively working to close this gap. As a workaround, add `max_tokens` to the `model_kwargs` of the LangChain OpenAI request.
>
> Example:
```python
from langchain.llms import OpenAI

llm = OpenAI(
    model_name="llama-3.1-8b-instruct",
    temperature=0.0,
    model_kwargs={"max_tokens": 4096},
)

response = llm("Write a short poem about a sunset.")
print(response)
```

#### Named Tool Calling

The OpenAI frontend supports named function calling, utilizing guided decoding in the vLLM and TensorRT-LLM backends. Users can specify one of the tools in `tool_choice` to force the model to call a specific function.

> [!NOTE]
> The latest release of TensorRT-LLM (v0.18.0) does not yet support guided decoding. To enable this feature, use a build from the main branch of TensorRT-LLM.
> For instructions on enabling guided decoding in the TensorRT-LLM backend, refer to [this guide](https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/docs/guided_decoding.md).

Example of making a named tool-calling request:

```python
import json

from openai import OpenAI


def get_current_weather(city: str, state: str, unit: str):
    return (
        "The weather in Dallas, Texas is 85 degrees Fahrenheit. It is "
        "partly cloudy, with highs in the 90's."
    )


def get_n_day_weather_forecast(city: str, state: str, unit: str, num_days: int):
    return f"The weather in Dallas, Texas is 85 degrees Fahrenheit for the next {num_days} day(s)."


available_tools = {
    "get_current_weather": get_current_weather,
    "get_n_day_weather_forecast": get_n_day_weather_forecast,
}

openai_api_key = "EMPTY"
openai_api_base = "http://localhost:9000/v1"
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
model = "llama-3.1-8b-instruct"  # change this to the model in the repository
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. 'San Francisco'",
                    },
                    "state": {
                        "type": "string",
                        "description": "The two-letter abbreviation for the state that the city is"
                        " in, e.g. 'CA' which would mean 'California'",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_n_day_weather_forecast",
            "description": "Get an N-day weather forecast",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. 'San Francisco'",
                    },
                    "state": {
                        "type": "string",
                        "description": "The two-letter abbreviation for the state that the city is"
                        " in, e.g. 'CA' which would mean 'California'",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                    "num_days": {
                        "type": "integer",
                        "description": "The number of days to forecast",
                    },
                },
                "required": ["city", "state", "unit", "num_days"],
            },
        },
    },
]

# Force the model to call get_n_day_weather_forecast.
tool_choice = {"function": {"name": "get_n_day_weather_forecast"}, "type": "function"}

messages = [
    {
        "role": "system",
        "content": "You're a helpful assistant! Answer the user's questions as best you can.",
    },
    {"role": "user", "content": "What is the weather in Dallas, Texas in Fahrenheit?"},
]

response = client.chat.completions.create(
    messages=messages, model=model, tools=tools, tool_choice=tool_choice, max_tokens=128
)
function_name = response.choices[0].message.tool_calls[0].function.name
function_arguments = response.choices[0].message.tool_calls[0].function.arguments

print(f"function name: {function_name}")
print(f"function arguments: {function_arguments}")
print(f"tool calling result: {available_tools[function_name](**json.loads(function_arguments))}")
```

Example output:
```
function name: get_n_day_weather_forecast
function arguments: {"city": "Dallas", "state": "TX", "unit": "fahrenheit", "num_days": 1}
tool calling result: The weather in Dallas, Texas is 85 degrees Fahrenheit for the next 1 day(s).
```
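
The commit message also mentions required tool calling enabled by the guided-decoding integration. In the OpenAI API this corresponds to `tool_choice="required"`, which forces the model to call at least one tool without naming a specific one. A minimal sketch, assuming the frontend accepts the standard string form of `tool_choice` (reusing `client`, `model`, `messages`, and `tools` from the example above):

```python
# Hypothetical: assumes the frontend accepts tool_choice="required" as in the
# OpenAI API, forcing the model to call at least one of the provided tools.
response = client.chat.completions.create(
    messages=messages,
    model=model,
    tools=tools,
    tool_choice="required",
    max_tokens=128,
)

# With "required", the model chooses which tool(s) to call; print each call.
for call in response.choices[0].message.tool_calls:
    print(call.function.name, call.function.arguments)
```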
