Skip to content

Commit 3c391e4

Browse files
committed
fix glm
1 parent 2335760 commit 3c391e4

File tree

3 files changed

+274
-61
lines changed

3 files changed

+274
-61
lines changed

gpt_server/model_handler/prompts.py

Lines changed: 173 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,131 @@
11
from typing import Optional
2-
from lmdeploy.model import MODELS, Qwen7BChat
2+
from lmdeploy.model import MODELS, Qwen7BChat, ChatGLM3, get_text
33
import json
44

55

6+
@MODELS.register_module(name="glm4", force=True)
class Glm4Chat(ChatGLM3):
    """Chat template of glm-4 model, with ReAct-style tool-calling support."""

    def __init__(
        self,
        system="<|system|>\n",
        user="<|user|>\n",
        assistant="<|assistant|>\n",
        separator="\n",
        tools="""\n\n你可以使用以下工具提供适当的答复和支持。\n\n# 可用工具\n\n在<tools></tools> XML标签中提供了function的签名(即函数的结构信息):\n<tools>""",
        eotools="""\n</tools>
Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!
""",
        stop_words=["<|user|>", "<|endoftext|>", "<|observation|>"],
        meta_instruction="你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。",
        **kwargs,
    ):
        super().__init__(
            system=system,
            user=user,
            assistant=assistant,
            stop_words=stop_words,
            separator=separator,
            meta_instruction=meta_instruction,
            **kwargs,
        )
        self.start = "[gMASK]<sop>"
        # Prompt fragments wrapping the JSON tool schemas. ``eotools`` keeps
        # its ``{tool_names}`` placeholder so it can be formatted per request
        # (see messages2prompt_base) without being consumed.
        self.tools = tools
        self.eotools = eotools

    @classmethod
    def match(cls, model_path: str) -> Optional[str]:
        """Return the model_name that was registered to MODELS.

        Args:
            model_path (str): the model path used for matching.
        """
        path = model_path.lower()
        if "glm-4" in path:
            return "glm4"

    def messages2prompt(self, messages, sequence_start=True, tools=None, **kwargs):
        """Return the prompt that is concatenated with other elements in the
        chat template.

        Args:
            messages (str | List): user's input prompt
        Returns:
            str: the concatenated prompt, prefixed with ``[gMASK]<sop>``
        """
        if isinstance(messages, str):
            return self.get_prompt(messages, sequence_start)
        return self.start + self.messages2prompt_base(
            messages, sequence_start, tools=tools, **kwargs
        )

    def messages2prompt_base(self, messages, sequence_start=True, tools=None, **kwargs):
        """Return the prompt that is concatenated with other elements in the
        chat template.

        Args:
            messages (str | List): user's input prompt
            tools (List | None): OpenAI-style tool schemas to advertise.
        Returns:
            str: the concatenated prompt
        """
        if isinstance(messages, str):
            return self.get_prompt(messages, sequence_start)
        box_map = dict(
            user=self.user, assistant=self.assistant, system=self.system, tool=self.tool
        )
        eox_map = dict(
            user=self.eoh,
            assistant=self.eoa + self.separator,
            system=self.eosys,
            tool=self.eotool,
        )
        ret = ""
        if self.meta_instruction is not None and sequence_start:
            if len(messages) and messages[0]["role"] != "system":
                ret += f"{self.system}{self.meta_instruction}{self.eosys}"
        start = 0  # index of the first message not yet rendered
        if tools is not None and len(tools) > 0:
            tool_names = ",".join(tool["function"]["name"] for tool in tools)
            # BUGFIX: format into a local. The original mutated self.eotools,
            # which destroyed the {tool_names} placeholder after the first
            # call and leaked that request's tool names into every later
            # request served by this template instance.
            eotools = self.eotools.format(tool_names=tool_names)
            tool_prompt = "".join(
                f'{self.separator}{{"type": "function", "function": {json.dumps(tool, ensure_ascii=False)}}}'
                for tool in tools
            )
            if len(messages) and messages[0]["role"] == "system":
                ret += f"{self.system}{messages[0]['content']}{self.tools}{tool_prompt}{eotools}{self.eosys}"
                # BUGFIX: skip the system message instead of pop(0) so the
                # caller's message list is not mutated.
                start = 1
            else:
                # NOTE(review): when tools are given and there is no system
                # message, meta_instruction is emitted both here and by the
                # sequence_start block above — confirm this duplication is
                # intended before changing it.
                ret += f"{self.system}{self.meta_instruction}{self.tools}{tool_prompt}{eotools}{self.eosys}"
        for message in messages[start:]:
            role = message["role"]
            content = get_text(message["content"])
            ret += f"{box_map[role]}{content}{eox_map[role]}"
        if (
            len(messages) > start
            and messages[-1]["role"] == "assistant"
            and len(eox_map["assistant"]) > 0
        ):
            return ret[: -len(eox_map["assistant"])]  # prefix of response
        ret += f"{self.assistant}"
        return ret
6129
@MODELS.register_module(name="qwen2_5")
7130
class Qwen2d5Chat(Qwen7BChat):
8131
"""Chat template for Qwen2.5-Instruct series."""
@@ -124,3 +247,52 @@ def match(cls, model_path: str) -> Optional[str]:
124247
lower_path = model_path.lower()
125248
if "qwen2.5" in lower_path or "qwen2_5" in lower_path:
126249
return "qwen2d5"
250+
251+
252+
if __name__ == "__main__":
    # Smoke test: render a GLM-4 prompt with a system message and two tools.

    def _weather_tool(name: str) -> dict:
        """Build a sample weather-lookup tool schema with the given name."""
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "City and state, e.g., 'San Francisco, CA'",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }

    chat_template = MODELS.module_dict["glm4"]()
    messages = [
        {"role": "system", "content": "我的Qwen "},
        {"role": "user", "content": "你是谁 "},
    ]
    tools = [_weather_tool("get_weather"), _weather_tool("get_weather2")]
    prompt = chat_template.messages2prompt(messages, True, tools)
    print(prompt)

gpt_server/model_handler/tool_parser.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import uuid
23
from loguru import logger
34
import re
45
from typing import Dict, List, Literal, Sequence, Union, Optional
@@ -41,6 +42,52 @@ class ExtractedToolCallInformation(BaseModel):
4142
content: Optional[str] = None
4243

4344

45+
@ToolParserManager.register_module(["glm"])
class GLMToolParser(ToolParser):
    """Parse GLM ReAct-style output (``Action:`` / ``Action Input:``) into tool calls."""

    def __init__(self, tokenizer: object = None):
        super().__init__(tokenizer)
        # Streaming cursor; not used by extract_tool_calls below.
        self.position = 0

    def get_argments(self, obj):
        """Return the argument payload of a tool schema/call dict.

        Accepts either the ``parameters`` or the ``arguments`` key; returns
        None when neither is present.
        """
        if "parameters" in obj:
            return obj.get("parameters")
        elif "arguments" in obj:
            return obj.get("arguments")
        return None

    def extract_tool_calls(
        self,
        model_output: str,
        tools,
    ) -> ExtractedToolCallInformation:
        """Extract a single tool call from ReAct-formatted model output.

        Args:
            model_output (str): the full model generation.
            tools: tool schemas passed for context (unused here).
        Returns:
            ExtractedToolCallInformation: tools_called=False with the raw
            text as content when no Action/Action Input pair is found.
        """
        text = model_output
        action_idx = text.rfind("Action:")
        input_idx = text.rfind("Action Input:")
        # BUGFIX: the original had no marker guard — rfind() returning -1
        # made the slice arithmetic below silently produce a garbage tool
        # call. Treat marker-less output as plain content instead.
        if action_idx == -1 or input_idx == -1 or input_idx < action_idx:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=text
            )
        try:
            name = text[action_idx + len("Action:") : input_idx].strip().strip(".")
            if "Observation" in model_output:
                # Arguments end where the (last) Observation section begins.
                obs_idx = text.rfind("Observation")
                arguments = text[input_idx + len("Action Input:") : obs_idx].strip()
            else:
                arguments = text[input_idx + len("Action Input:") :].strip()
            tool_calls = [
                ToolCall(function=FunctionCall(name=name, arguments=arguments))
            ]
        except Exception:
            # Any malformed payload falls back to returning the raw text.
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=text
            )
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=text,
        )
90+
4491
@ToolParserManager.register_module(["qwen2_5"])
4592
class Qwen2d5ToolParser(ToolParser):
4693

@@ -271,11 +318,18 @@ def tool_parser(full_text: str, tool_parser: ToolParser, tools, ret):
271318
text, tool_calls = tool_call_info.content, tool_call_info.tool_calls
272319
tool_calls = [i.model_dump() for i in tool_calls]
273320
if tools and tools_called: # 如果传入tools
274-
logger.debug(f"工具解析成功, tool_calls: {tool_calls}")
321+
logger.info(f"工具解析成功, tool_calls: {tool_calls}")
275322
ret["text"] = ""
276323
ret["tool_calls"] = tool_calls
277324
ret["finish_reason"] = "tool_calls"
278325
return json.dumps(ret).encode() + b"\0"
279326
else:
280327
ret["text"] = ""
281328
return json.dumps(ret).encode() + b"\0"
329+
330+
331+
if __name__ == "__main__":
    # Smoke test: parse a ReAct-style tool call with the GLM parser.
    sample_output = """Action: get_weather
Action Input: {"location": "Nanjing", "unit": "celsius"}"""
    glm_parser = ToolParserManager.module_dict["glm"]()
    tool_parser(full_text=sample_output, tool_parser=glm_parser, tools=True, ret={})

0 commit comments

Comments
 (0)