Skip to content

Commit aa7b97e

Browse files
committed
add system react
1 parent e31dd4a commit aa7b97e

File tree

3 files changed

+64
-1
lines changed

3 files changed

+64
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
[![Docker pulls][docker-pulls]][docker-pulls]
1010
[![CI Status][ci-shield]][ci-url]
1111
[![issue resolution][closed-issues-shield]][closed-issues-url]
12-
[![open issues][open-issues-shield]][open-issues-url]
12+
1313
</div>
1414

1515
本项目依托fastchat的基础能力来提供**openai server**的能力.
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from typing import Any, Dict, List, Tuple, Union, Optional
2+
import json
3+
import uuid
4+
5+
from gpt_server.model_handler.react.prompt import (
6+
GLM4_TOOL_PROMPT,
7+
TOOL_SUFFIX_PROMPT,
8+
)
9+
10+
11+
def system_tool_formatter(
12+
tools: List[Dict[str, Any]], tool_choice_info: Optional[dict] = None
13+
) -> str:
14+
tool_text = "\n"
15+
tool_names = []
16+
for tool in tools:
17+
tool = tool["function"]
18+
tool_name = tool["name"]
19+
tool_text += f"## {tool_name}\n\n{json.dumps(tool, ensure_ascii=False, indent=4)}\n{TOOL_SUFFIX_PROMPT}\n\n"
20+
tool_names.append(tool_name)
21+
return GLM4_TOOL_PROMPT.format(
22+
tool_text=tool_text, tool_names=", ".join(tool_names)
23+
).strip()
24+
25+
26+
def system_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
27+
i = content.rfind("Action:")
28+
j = content.rfind("Action Input:")
29+
tool_name = content[i + len("Action:") : j].strip().strip(".")
30+
tool_input = content[j + len("Action Input:") :].strip()
31+
try:
32+
tool_input_obj = json.loads(tool_input)
33+
except json.JSONDecodeError:
34+
return content
35+
tool_calls = []
36+
tool_call = {
37+
"index": 0,
38+
"id": "call_{}".format(uuid.uuid4().hex),
39+
"function": {"name": tool_name, "arguments": tool_input},
40+
}
41+
tool_calls.append(tool_call)
42+
43+
return tool_calls
44+
45+
46+
if __name__ == "__main__":
47+
import json
48+
49+
tools_str = """[{'type': 'function', 'function': {'name': 'track', 'description': '追踪指定股票的实时价格', 'parameters': {'type': 'object', 'properties': {'symbol': {'description': '需要追踪的股票代码', 'type': 'integer'}}, 'required': ['symbol']}}}, {'type': 'function', 'function': {'name': 'text-to-speech', 'description': '将文本转换为语音', 'parameters': {'type': 'object', 'properties': {'text': {'description': '需要转换成语音的文本', 'type': 'string'}, 'voice': {'description': '要使用的语音类型(男声、女声等', 'default': '男声', 'type': 'string'}, 'speed': {'description': '语音的速度(快、中等、慢等', 'default': '中等', 'type': 'string'}}, 'required': ['text']}}}]"""
50+
tools_str = tools_str.replace("'", '"')
51+
tools = json.loads(tools_str)
52+
53+
res = system_tool_formatter(tools=tools)
54+
print(res)
55+
print()
56+
out = 'multiply\n{"first_int": 8, "second_int": 9}'
57+
r = system_tool_extractor(out)
58+
print(r)

gpt_server/model_handler/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from gpt_server.model_handler.react.qwen_react import qwen_tool_formatter
22
from gpt_server.model_handler.react.chatglm_react import glm4_tool_formatter
3+
from gpt_server.model_handler.react.system_react import system_tool_formatter
34
from loguru import logger
45

56

@@ -32,6 +33,10 @@ def add_tools2messages(params: dict, model_adapter: str = "default"):
3233
system_content = glm4_tool_formatter(
3334
tools=params.get("tools"), tool_choice_info=tool_choice_info
3435
)
36+
else:
37+
system_content = system_tool_formatter(
38+
tools=params.get("tools"), tool_choice_info=tool_choice_info
39+
)
3540

3641
if messages[0]["role"] != "system":
3742
messages.insert(0, {"role": "system", "content": system_content})

0 commit comments

Comments
 (0)