Commit 1dba906

Added support for phi-4
1 parent 8e40e32 commit 1dba906

File tree

README.md
gpt_server/model_backend/lmdeploy_backend.py
gpt_server/model_worker/phi.py

3 files changed: +96 −2 lines changed


README.md

Lines changed: 2 additions & 1 deletion
@@ -42,6 +42,7 @@
 ## Update History

 ```plaintext
+2024-12-14 Added support for phi-4
 2024-12-7 Added support for the /v1/rerank endpoint
 2024-12-1 Added support for QWQ-32B-Preview
 2024-10-15 Added support for Qwen2-VL
@@ -263,7 +264,7 @@ streamlit run server_ui.py
 | Llama-3 |llama|||||
 | Baichuan-2 |baichuan|||||
 | QWQ-32B-Preview |qwen|||||
-
+| Phi-4 |phi||| × | × |
 ### **VLM** (vision LLM leaderboard: https://rank.opencompass.org.cn/leaderboard-multimodal)

 | Models / BackEnd |model_type| HF | vllm | LMDeploy TurboMind | LMDeploy PyTorch |
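
With this row in place, Phi-4 should be reachable through gpt_server's OpenAI-compatible API like any other model in the table. A minimal client sketch using the official openai package; the base_url, api_key, and registered model name are placeholders I am assuming, not values fixed by this commit:

```python
# Hypothetical client call: adjust base_url / api_key / model to your
# own gpt_server deployment; none of these values come from the commit.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8082/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="phi-4",  # must match one of the worker's model_names
    messages=[{"role": "user", "content": "Briefly introduce yourself."}],
)
print(resp.choices[0].message.content)
```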

gpt_server/model_backend/lmdeploy_backend.py

Lines changed: 2 additions & 1 deletion
@@ -107,7 +107,8 @@ async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
             response_format=params["response_format"],
         )
         logger.info(f"request_id {int(request_id)}")
-        messages = prompt or messages  # TODO: may affect inference performance
+        if params.get("tools", None):
+            messages = prompt or messages  # work around lmdeploy chat templates not supporting tools
         results_generator = self.async_engine.generate(
             messages=messages, session_id=int(request_id), gen_config=gen_config
         )
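
The new guard narrows the earlier blanket substitution to requests that actually carry tools: lmdeploy's built-in chat template cannot render tool schemas, so only then does the backend fall back to the prompt string the worker pre-rendered. A standalone sketch of that dispatch decision (the helper name and structure are illustrative, not from the repo):

```python
# Illustrative helper mirroring the branch above; not part of gpt_server.
def pick_engine_input(params: dict, prompt, messages):
    if params.get("tools"):
        # Tool calling: use the worker's pre-rendered prompt string,
        # since lmdeploy's own template has no tools support.
        return prompt or messages
    # No tools: keep structured messages so lmdeploy applies its own
    # chat template (avoiding the performance cost noted in the old TODO).
    return messages


assert pick_engine_input({"tools": [{"type": "function"}]}, "<rendered>", []) == "<rendered>"
assert pick_engine_input({}, "<rendered>", [{"role": "user", "content": "hi"}])[0]["role"] == "user"
```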

gpt_server/model_worker/phi.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
import json
from typing import List

from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
from loguru import logger
import torch

from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase


class PhiWorker(ModelWorkerBase):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,  # type: ignore
    ):
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template,
            model_type="AutoModelForCausalLM",
        )
        # Stop token ids, taken from tokenizer_config.json
        self.stop_words_ids = [
            100257,  # eos
            100265,  # eos
        ]
        self.stop = [
            self.tokenizer.decode(stop_word) for stop_word in self.stop_words_ids
        ]
        logger.info(f"{model_names[0]} stop words: {self.stop}")

    async def generate_stream_gate(self, params):
        self.call_ct += 1
        logger.info(f"params {params}")
        logger.info(f"worker_id: {self.worker_id}")
        try:
            messages = params["messages"]
            if isinstance(messages, list):
                task = "chat"
            elif isinstance(messages, str):
                task = "completion"
            else:
                # Neither chat messages nor a raw prompt string.
                raise ValueError(f"Unsupported messages type: {type(messages)}")
            if task == "chat":
                # Kept for now, for handling special cases
                text = self.tokenizer.apply_chat_template(
                    conversation=messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            elif task == "completion":
                text = messages

            input_ids = self.tokenizer([text], return_tensors="pt").input_ids
            # --------------- add extra parameters ------------------------
            params["messages"] = messages
            params["prompt"] = text
            params["stop"].extend(self.stop)
            params["stop_words_ids"] = self.stop_words_ids
            params["input_ids"] = input_ids
            # --------------- add extra parameters ------------------------
            async for ret in self.backend.stream_chat(params=params):
                yield json.dumps(ret).encode() + b"\0"

        except torch.cuda.OutOfMemoryError as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
            }
            yield json.dumps(ret).encode() + b"\0"
        except (ValueError, RuntimeError) as e:
            logger.info(e)
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
            yield json.dumps(ret).encode() + b"\0"

    def get_embeddings(self, params):
        return super().get_embeddings(params)


if __name__ == "__main__":
    PhiWorker.run()
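
PhiWorker hard-codes its two stop token ids from tokenizer_config.json and decodes them into stop strings at startup. A quick standalone check of what those ids decode to, assuming the Hugging Face model id microsoft/phi-4 (the model id and the expected strings are my assumptions, not stated in this commit):

```python
# Standalone sanity check of the stop ids used by PhiWorker.
# "microsoft/phi-4" is an assumed Hugging Face model id.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-4")
for token_id in [100257, 100265]:
    print(token_id, repr(tokenizer.decode(token_id)))
# Presumably these decode to special tokens such as <|endoftext|> and
# <|im_end|>, which the worker then appends to params["stop"].
```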
