Skip to content

Commit 9d2fec4

Browse files
committed
优化项目架构,以更好地支持更多的 embedding 后端
1 parent 7b4db83 commit 9d2fec4

File tree

5 files changed

+98
-13
lines changed

5 files changed

+98
-13
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import os
2+
from typing import List
3+
import asyncio
4+
from loguru import logger
5+
6+
from infinity_emb import AsyncEngineArray, EngineArgs, AsyncEmbeddingEngine
7+
from infinity_emb.inference.select_model import get_engine_type_from_config
8+
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
9+
from gpt_server.model_worker.utils import get_embedding_mode
10+
import torch
11+
import vllm
12+
from vllm import LLM
13+
14+
label_to_category = {
15+
"S": "sexual",
16+
"H": "hate",
17+
"HR": "harassment",
18+
"SH": "self-harm",
19+
"S3": "sexual/minors",
20+
"H2": "hate/threatening",
21+
"V2": "violence/graphic",
22+
"V": "violence",
23+
"OK": "OK",
24+
}
25+
26+
27+
class EmbeddingWorker(ModelWorkerBase):
28+
def __init__(
29+
self,
30+
controller_addr: str,
31+
worker_addr: str,
32+
worker_id: str,
33+
model_path: str,
34+
model_names: List[str],
35+
limit_worker_concurrency: int,
36+
conv_template: str = None, # type: ignore
37+
):
38+
super().__init__(
39+
controller_addr,
40+
worker_addr,
41+
worker_id,
42+
model_path,
43+
model_names,
44+
limit_worker_concurrency,
45+
conv_template,
46+
model_type="embedding",
47+
)
48+
tensor_parallel_size = int(os.getenv("num_gpus", "1"))
49+
max_model_len = os.getenv("max_model_len", None)
50+
gpu_memory_utilization = float(os.getenv("gpu_memory_utilization", 0.8))
51+
enable_prefix_caching = bool(os.getenv("enable_prefix_caching", False))
52+
53+
self.mode = get_embedding_mode(model_path=model_path)
54+
self.engine = LLM(
55+
model=model_path,
56+
tensor_parallel_size=tensor_parallel_size,
57+
max_model_len=max_model_len,
58+
gpu_memory_utilization=gpu_memory_utilization,
59+
enable_prefix_caching=enable_prefix_caching,
60+
)
61+
62+
logger.warning(f"模型:{model_names[0]}")
63+
logger.warning(f"正在使用 {self.mode} 模型...")
64+
65+
async def get_embeddings(self, params):
66+
self.call_ct += 1
67+
ret = {"embedding": [], "token_num": 0}
68+
texts: list = params["input"]
69+
if self.mode == "embedding":
70+
usage = None
71+
texts = list(map(lambda x: x.replace("\n", " "), texts))
72+
# ----------
73+
outputs = self.engine.embed(prompts=texts)
74+
embedding = [o.outputs.embedding for o in outputs]
75+
76+
ret["embedding"] = embedding
77+
ret["token_num"] = usage
78+
return ret
79+
80+
81+
if __name__ == "__main__":
82+
EmbeddingWorker.run()

gpt_server/script/config_example.yaml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ models:
8787
enable: true # false true
8888
model_config:
8989
model_name_or_path: /home/dev/model/Xorbits/bge-reranker-base/
90-
model_type: embedding_infinity # embedding_infinity/embedding
91-
work_mode: hf
90+
model_type: embedding
91+
work_mode: infinity # 可选 ["vllm", "infinity", "sentence_transformers"],但并不是所有后端都支持
9292
device: gpu # gpu / cpu
9393
workers:
9494
- gpus:
@@ -101,8 +101,8 @@ models:
101101
model_config:
102102
model_name_or_path: /home/dev/model/jinaai/jina-reranker-m0/
103103
task_type: auto # auto 、embedding 、 reranker 或者 classify 不设置这个参数,默认为 auto,自动识别可能会识别错误
104-
model_type: embedding # 这里仅支持 embedding
105-
work_mode: hf
104+
model_type: embedding
105+
work_mode: sentence_transformers # 可选 ["vllm", "infinity", "sentence_transformers"],但并不是所有后端都支持
106106
device: gpu
107107
workers:
108108
- gpus:
@@ -115,8 +115,8 @@ models:
115115
model_config:
116116
model_name_or_path: /home/dev/model/aspire/acge_text_embedding
117117
task_type: auto # auto 、embedding 、 reranker 或者 classify 不设置这个参数,默认为 auto,自动识别可能会识别错误
118-
model_type: embedding_infinity # embedding_infinity/embedding
119-
work_mode: hf
118+
model_type: embedding
119+
work_mode: infinity # 可选 ["vllm", "infinity", "sentence_transformers"],但并不是所有后端都支持
120120
device: gpu # gpu / cpu
121121
workers:
122122
- gpus:
@@ -128,8 +128,8 @@ models:
128128
enable: true
129129
model_config:
130130
model_name_or_path: /home/dev/model/BAAI/BGE-VL-base/
131-
model_type: embedding # 这里仅支持 embedding
132-
work_mode: hf
131+
model_type: embedding
132+
work_mode: sentence_transformers # 可选 ["vllm", "infinity", "sentence_transformers"],但并不是所有后端都支持
133133
device: gpu
134134
workers:
135135
- gpus:
@@ -141,8 +141,8 @@ models:
141141
enable: true
142142
model_config:
143143
model_name_or_path: /home/dev/model/KoalaAI/Text-Moderation
144-
model_type: embedding_infinity # embedding_infinity
145-
work_mode: hf
144+
model_type: embedding
145+
work_mode: infinity # 可选 ["vllm", "infinity", "sentence_transformers"],但并不是所有后端都支持
146146
device: gpu
147147
workers:
148148
- gpus:

gpt_server/serving/openai_api_server.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@ def parse_env_var(cls, field_name: str, raw_val: str):
127127
async def timing_tasks():
128128
"""定时任务"""
129129
global model_address_map, models_
130-
logger.info("定时任务已启动!")
131130
controller_address = app_settings.controller_address
132131

133132
while True:

gpt_server/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ def get_model_types():
127127
return model_types
128128

129129

130-
model_types = get_model_types()
130+
model_types = get_model_types() + ["embedding"]
131+
embedding_backend_type = ["vllm", "infinity", "sentence_transformers"]
131132

132133

133134
def start_model_worker(config: dict):
@@ -201,7 +202,6 @@ def start_model_worker(config: dict):
201202
f"不支持model_type: {model_type},仅支持{model_types}模型之一!"
202203
)
203204
sys.exit()
204-
py_path = f"-m gpt_server.model_worker.{model_type}"
205205

206206
model_names = model_name
207207
if model_config["alias"]:
@@ -240,7 +240,11 @@ def start_model_worker(config: dict):
240240
else:
241241
raise Exception("目前仅支持 CPU/GPU设备!")
242242
backend = model_config["work_mode"]
243+
if model_type == "embedding":
244+
assert backend in embedding_backend_type
245+
model_type = f"embedding_{backend}"
243246

247+
py_path = f"-m gpt_server.model_worker.{model_type}"
244248
cmd = (
245249
CUDA_VISIBLE_DEVICES
246250
+ run_mode

0 commit comments

Comments
 (0)