
Commit 337fa4e

Optimize embedding
1 parent ee47775 commit 337fa4e

File tree

3 files changed: +31 -38 lines changed


gpt_server/model_worker/embedding.py

Lines changed: 15 additions & 29 deletions
@@ -7,7 +7,7 @@
 from transformers import AutoConfig, AutoModel
 from loguru import logger
 from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
-from gpt_server.model_worker.utils import load_base64_or_url
+from gpt_server.model_worker.utils import load_base64_or_url, get_embedding_mode


 class EmbeddingWorker(ModelWorkerBase):
@@ -38,41 +38,27 @@ def __init__(
         logger.warning(f"使用{device}加载...")
         model_kwargs = {"device": device}
         # TODO
-        self.mode = "embedding"
-        model_type = getattr(
-            getattr(self.model_config, "text_config", {}), "model_type", None
-        )
-        logger.warning(f"model_type: {model_type}")
-        if "clip_text_model" in model_type:  # clip text 模型
-            self.mode = "clip_text_model"
-            self.client = AutoModel.from_pretrained(
-                model_path, trust_remote_code=True
-            )  # You must set trust_remote_code=True
+        self.mode = get_embedding_mode(model_path=model_path)
+        self.encode_kwargs = {"normalize_embeddings": True, "batch_size": 64}
+        if "clip_text_model" in self.mode:  # clip text 模型
+            self.client = AutoModel.from_pretrained(model_path, trust_remote_code=True)
             if device == "cuda":
                 self.client.to(
                     torch.device("cuda" if torch.cuda.is_available() else "cpu")
                 )
             logger.info(f"device: {self.client.device}")
             self.client.set_processor(model_path)
             self.client.eval()
-        else:
-            self.encode_kwargs = {"normalize_embeddings": True, "batch_size": 64}
-
-        # rerank
-        for model_name in model_names:
-            if "rerank" in model_name:
-                self.mode = "rerank"
-                break
-        if self.mode == "rerank":
-            self.client = sentence_transformers.CrossEncoder(
-                model_name=model_path, **model_kwargs
-            )
-            logger.warning("正在使用 rerank 模型...")
-        elif self.mode == "embedding":
-            self.client = sentence_transformers.SentenceTransformer(
-                model_path, **model_kwargs
-            )
-            logger.warning("正在使用 embedding 模型...")
+        elif "rerank" in self.mode:
+            self.client = sentence_transformers.CrossEncoder(
+                model_name=model_path, **model_kwargs
+            )
+            logger.warning("正在使用 rerank 模型...")
+        elif "embedding" in self.mode:
+            self.client = sentence_transformers.SentenceTransformer(
+                model_path, **model_kwargs
+            )
+            logger.warning("正在使用 embedding 模型...")
         logger.warning(f"模型:{model_names[0]}")

     async def get_embeddings(self, params):
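The refactor collapses the old in-constructor detection into a single mode string returned by get_embedding_mode, and __init__ then only dispatches on that string: AutoModel for CLIP text models, CrossEncoder for rerankers, SentenceTransformer otherwise. A minimal sketch of that dispatch pattern follows; detect_mode and the returned description strings are hypothetical stand-ins, not the repository's backends.

# Minimal sketch of the mode-dispatch pattern in EmbeddingWorker.__init__.
# detect_mode() and the returned description strings are hypothetical
# stand-ins; the real code calls get_embedding_mode() and instantiates
# AutoModel / CrossEncoder / SentenceTransformer.

def detect_mode(model_path: str) -> str:
    # Illustrative only: key off the path name instead of the HF config.
    name = model_path.lower()
    if "clip" in name:
        return "clip_text_model"
    if "rerank" in name:
        return "rerank"
    return "embedding"


def build_client(model_path: str) -> tuple[str, str]:
    # Mirror the elif-chain in the diff above, one backend per mode.
    mode = detect_mode(model_path)
    if "clip_text_model" in mode:
        client = f"AutoModel({model_path})"
    elif "rerank" in mode:
        client = f"CrossEncoder({model_path})"
    else:
        client = f"SentenceTransformer({model_path})"
    return mode, client


if __name__ == "__main__":
    print(build_client("/models/bge-reranker-v2-m3"))  # ('rerank', 'CrossEncoder(...)')
    print(build_client("/models/bge-m3"))              # ('embedding', 'SentenceTransformer(...)')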

gpt_server/model_worker/utils.py

Lines changed: 7 additions & 6 deletions
@@ -48,11 +48,12 @@ def get_embedding_mode(model_path: str):
     model_type_text = getattr(
         getattr(model_config, "text_config", {}), "model_type", None
     )
-    model_type_vison = getattr(
-        getattr(model_config, "vision_config", {}), "model_type", None
-    )
-    print(model_type_vison, model_type_text)
-    model_type = model_type_vison or model_type_text
+    logger.warning(f"model_type: {model_type_text}")
+    # model_type_vison = getattr(
+    #     getattr(model_config, "vision_config", {}), "model_type", None
+    # )
+    # print(model_type_vison, model_type_text)
+    model_type = model_type_text

     mode = "embedding"
     engine_args = EngineArgs(
@@ -79,5 +80,5 @@ def get_embedding_mode(model_path: str):

 if __name__ == "__main__":

     # 示例用法
-    r = get_embedding_mode("BAAI/BGE-VL-MLLM-S1")
+    r = get_embedding_mode("/home/dev/model/BAAI/bge-m3/")
     print(r)
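With the vision_config branch commented out, get_embedding_mode now keys only off text_config.model_type from the Hugging Face config. A minimal sketch of reading that field with AutoConfig follows; the fallback to the top-level model_type is an assumption for plain text models whose config carries no nested text_config, and the example path is the one used in the diff.

# Sketch: read model_type the way get_embedding_mode does, assuming a local
# or hub model path. The top-level fallback is an addition for configs that
# have no nested text_config.
from transformers import AutoConfig


def read_text_model_type(model_path: str):
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    text_config = getattr(config, "text_config", None)
    if text_config is not None:
        return getattr(text_config, "model_type", None)
    return getattr(config, "model_type", None)


if __name__ == "__main__":
    # Example path only; prints the config's model_type string.
    print(read_text_model_type("/home/dev/model/BAAI/bge-m3/"))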

tests/test_openai_embedding.py

Lines changed: 9 additions & 3 deletions
@@ -1,9 +1,15 @@
 from openai import OpenAI
 from rich import print
+import numpy as np

 # 新版本 opnai
 client = OpenAI(api_key="EMPTY", base_url="http://localhost:8082/v1")
 # model: acge_text_embedding yinka zpoint
-data = client.embeddings.create(model="piccolo-base-zh", input=["你是谁", "你是谁"])
-
-print(data.data)
+response = client.embeddings.create(model="bge-m3", input=["我喜欢你", "我也喜欢你"])
+print(response.data)
+embeddings = [np.array(item.embedding) for item in response.data]  # 转为NumPy数组
+v_a = embeddings[0].reshape(1, -1)  # 向量a
+v_b = embeddings[1].reshape(-1, 1)  # 向量b
+# 计算余弦相似度
+similarity = np.dot(v_a, v_b)[0][0]
+print(f"余弦相似度: {similarity:.4f}")
