 from typing import Optional
-from lmdeploy.model import MODELS, Qwen7BChat
+from lmdeploy.model import MODELS, Qwen7BChat, ChatGLM3, get_text
 import json
 
 
+@MODELS.register_module(name="glm4", force=True)
+class Glm4Chat(ChatGLM3):
+    """Chat template of the GLM-4 model."""
+
+    def __init__(
+        self,
+        system="<|system|>\n",
+        user="<|user|>\n",
+        assistant="<|assistant|>\n",
+        separator="\n",
+        tools="""\n\n你可以使用以下工具提供适当的答复和支持。\n\n# 可用工具\n\n在<tools></tools> XML标签中提供了function的签名(即函数的结构信息):\n<tools>""",
+        eotools="""\n</tools>
+Use the following format:
+
+Question: the input question you must answer
+Thought: you should always think about what to do
+Action: the action to take, should be one of [{tool_names}]
+Action Input: the input to the action
+Observation: the result of the action
+... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
+Thought: I now know the final answer
+Final Answer: the final answer to the original input question
+
+Begin!
+""",
+        stop_words=["<|user|>", "<|endoftext|>", "<|observation|>"],
+        meta_instruction="你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。",
+        **kwargs,
+    ):
+        super().__init__(
+            system=system,
+            user=user,
+            assistant=assistant,
+            stop_words=stop_words,
+            separator=separator,
+            meta_instruction=meta_instruction,
+            **kwargs,
+        )
+        # GLM-family prompts begin with the [gMASK]<sop> prefix.
+        self.start = "[gMASK]<sop>"
+        self.tools = tools
+        self.eotools = eotools
+
+    @classmethod
+    def match(cls, model_path: str) -> Optional[str]:
+        """Return the model_name that was registered to MODELS.
+
+        Args:
+            model_path (str): the model path used for matching.
+        """
+        path = model_path.lower()
+        if "glm-4" in path:
+            return "glm4"
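+        # Hypothetical spot-check (illustration only, not part of the commit):
+        #   Glm4Chat.match("THUDM/glm-4-9b-chat")  # -> "glm4"
+        #   Glm4Chat.match("internlm2-chat-7b")    # -> None (implicit)
+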
+    def messages2prompt(self, messages, sequence_start=True, tools=None, **kwargs):
+        """Return the prompt that is concatenated with other elements in the
+        chat template.
+
+        Args:
+            messages (str | List): user's input prompt
+        Returns:
+            str: the concatenated prompt
+        """
+        if isinstance(messages, str):
+            return self.get_prompt(messages, sequence_start)
+        return self.start + self.messages2prompt_base(
+            messages, sequence_start, tools=tools, **kwargs
+        )
+
+    def messages2prompt_base(self, messages, sequence_start=True, tools=None, **kwargs):
+        """Return the prompt that is concatenated with other elements in the
+        chat template.
+
+        Args:
+            messages (str | List): user's input prompt
+        Returns:
+            str: the concatenated prompt
+        """
+        if isinstance(messages, str):
+            return self.get_prompt(messages, sequence_start)
+        box_map = dict(
+            user=self.user, assistant=self.assistant, system=self.system, tool=self.tool
+        )
+        eox_map = dict(
+            user=self.eoh,
+            assistant=self.eoa + self.separator,
+            system=self.eosys,
+            tool=self.eotool,
+        )
+        ret = ""
+        if self.meta_instruction is not None and sequence_start:
+            if len(messages) and messages[0]["role"] != "system":
+                ret += f"{self.system}{self.meta_instruction}{self.eosys}"
+        if tools is not None and len(tools) > 0:
+            tool_names = ",".join(tool["function"]["name"] for tool in tools)
+            # Format into a local so repeated calls don't overwrite
+            # self.eotools with an already-substituted template.
+            eotools = self.eotools.format(tool_names=tool_names)
+            tool_prompt = ""
+            for tool in tools:
+                # `tool` already carries the {"type": "function", ...} wrapper,
+                # so dump only its "function" body to avoid double nesting.
+                tool_prompt += self.separator
+                tool_prompt += f'{{"type": "function", "function": {json.dumps(tool["function"], ensure_ascii=False)}}}'
+            if len(messages) and messages[0]["role"] == "system":
+                ret += f"{self.system}{messages[0]['content']}{self.tools}{tool_prompt}{eotools}{self.eosys}"
+                # Skip the system message without mutating the caller's list.
+                messages = messages[1:]
+            else:
+                ret += f"{self.system}{self.meta_instruction}{self.tools}{tool_prompt}{eotools}{self.eosys}"
+
+        for message in messages:
+            role = message["role"]
+            content = get_text(message["content"])
+            ret += f"{box_map[role]}{content}{eox_map[role]}"
+        if (
+            len(messages)
+            and messages[-1]["role"] == "assistant"
+            and len(eox_map["assistant"]) > 0
+        ):
+            return ret[: -len(eox_map["assistant"])]  # prefix of response
+        ret += f"{self.assistant}"
+        return ret
+
+
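+# A rough sketch (illustration, not part of the commit) of the prompt
+# Glm4Chat renders for a plain system+user exchange without tools, assuming
+# ChatGLM3's end-of-role strings (eosys/eoh/eoa) default to "":
+#
+#   [gMASK]<sop><|system|>
+#   {meta_instruction}<|user|>
+#   {user text}<|assistant|>
+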
 @MODELS.register_module(name="qwen2_5")
 class Qwen2d5Chat(Qwen7BChat):
     """Chat template for Qwen2.5-Instruct series."""
@@ -124,3 +247,52 @@ def match(cls, model_path: str) -> Optional[str]:
         lower_path = model_path.lower()
         if "qwen2.5" in lower_path or "qwen2_5" in lower_path:
             return "qwen2d5"
+
+
+if __name__ == "__main__":
+    chat_template = MODELS.module_dict["glm4"]()
+    messages = [
+        {"role": "system", "content": "我的Qwen "},
+        {"role": "user", "content": "你是谁 "},
+    ]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "City and state, e.g., 'San Francisco, CA'",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        },
+        {
+            "type": "function",
+            "function": {
+                "name": "get_weather2",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "City and state, e.g., 'San Francisco, CA'",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        },
+    ]
+    # tools = None
+    prompt = chat_template.messages2prompt(messages, True, tools)
+    print(prompt)
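+    # Expected shape of the printed prompt (a sketch under the same
+    # assumption as above, i.e. empty end-of-role strings):
+    #
+    #   [gMASK]<sop><|system|>
+    #   我的Qwen ...<tools>
+    #   {"type": "function", "function": {...get_weather...}}
+    #   {"type": "function", "function": {...get_weather2...}}
+    #   </tools>
+    #   Use the following format: ... [get_weather,get_weather2] ... Begin!
+    #   <|user|>
+    #   你是谁 <|assistant|>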