
Commit b842d01

Refactor the architecture to support jinaai/jina-reranker-v3
1 parent 6223df4 commit b842d01

3 files changed: 138 additions, 128 deletions


README.md: 11 additions, 6 deletions

````diff
@@ -51,6 +51,9 @@
 ## 📘 Configuration docs
 
 
+- **[GPT Server - DeepWiki docs (you can ask the AI directly about usage)](https://deepwiki.com/shell-nlp/gpt_server "DeepWiki docs")**
+<br>
+
 - **[Detailed configuration notes](https://blog.csdn.net/q506610466/article/details/151360406 "detailed configuration notes")**
 <br>
 
@@ -61,6 +64,7 @@
 <summary><b>2025</b></summary>
 
 ```plaintext
+2025-11-16 Added support for the jinaai/jina-reranker-v3 model
 2025-10-25 Added support for the qwen_image text-to-image model
 2025-9-7 Added support for image-editing models (sample code in gpt_server/tests/test_image_edit.py)
 2025-8-8 Initial support for vllm acceleration of embedding models
@@ -135,7 +139,7 @@
 * [X] Support text-to-image models
 * [X] Support image-editing models
 * [X] Support the Responses API
-* [ ] Support installation via pip install
+
 
 
 ## ⚙️ Quick start
@@ -272,8 +276,9 @@ Chat UI:
 [SGLang](https://docs.sglang.ai/supported_models/generative_models.html)
 
 #### Note:
-- **You can now set `model_type: auto` in `config.yaml`** to support all LLMs and multimodal LLMs supported by the current versions of vllm/sglang/lmdeploy, except non-language models such as embedding and reranker models.
-- The compatibility table below will be removed or reworked in the future
+- **You can now set `model_type: auto` in `config.yaml`** to support all LLMs and multimodal LLMs supported by the current versions of vllm/sglang/lmdeploy.
+
+- The compatibility table below will be removed or reworked in the future; models not listed in it may still be compatible, so please check the official upstream docs for the actual status.
 
 ### **LLM**
 
@@ -298,9 +303,8 @@ Chat UI:
 |InternVL2.5--3.5 | internvl | × | × ||| × |
 | MiniCPM-V-2.6 | minicpmv | × ||| × | × |
 | MiniCPM-V-4.5 | minicpmv | × || × | × | × |
-| Qwen2-VL | qwen | × || × |||
-| Qwen2.5-VL | qwen | × || × |||
-| QVQ | qwen | × || × | × | × |
+| Qwen-VL 2.0--3.0 | qwen | × |||||
+| QVQ | qwen | × |||||
 <br>
 
 ### Embedding/Rerank/Classify models
 
@@ -332,6 +336,7 @@ Chat UI:
 | jina-reranker-m0 || × |× |
 | bge-reranker |||× |
 | bce-reranker |||× |
+| jina-reranker-v3 || × |× |
 
 Currently **ritrieve_zh_v1** ranks first on the C-MTEB leaderboard (MTEB: https://huggingface.co/spaces/mteb/leaderboard)
 
````

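The reranker rows above, including the new jina-reranker-v3 entry, are served through the embedding worker shown in the two files below: the worker receives the documents in `params["input"]` and the query in `params["query"]`, and returns one relevance score per document wrapped as a one-element "embedding". The following client-side sketch is purely hypothetical: the base URL, the model alias, and passing the query via `extra_body` are assumptions for illustration, not the project's documented API; see the configuration guide linked above for the real usage.

```python
# Hypothetical client-side sketch only. The endpoint address, model alias, and
# the `query` field in extra_body are assumptions; the worker below ultimately
# sees params["input"] (documents) and params["query"].
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8082/v1", api_key="EMPTY")  # assumed address

resp = client.embeddings.create(
    model="jina-reranker-v3",  # assumed alias configured in config.yaml
    input=["Paris is the capital of France.", "Berlin is in Germany."],
    extra_body={"query": "What is the capital of France?"},  # assumed field name
)
# Each returned "embedding" is a single relevance score wrapped in a list.
print([item.embedding[0] for item in resp.data])
```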
gpt_server/model_worker/embedding_sentence_transformers.py: 4 additions, 115 deletions

```diff
@@ -1,16 +1,10 @@
-import asyncio
 import os
 from typing import List
 
-import sentence_transformers
-import torch
-from transformers import AutoConfig, AutoModel
 from loguru import logger
 from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
 from gpt_server.model_worker.utils import (
-    load_base64_or_url,
-    get_embedding_mode,
-    is_base64_image,
+    PoolingModel,
 )
 
 
@@ -40,119 +34,14 @@ def __init__(
         else:
             device = "cuda"
         logger.warning(f"使用{device}加载...")
-        model_kwargs = {"device": device}
-        if device == "cuda":
-            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        # TODO
-        self.mode = get_embedding_mode(model_path=model_path)
-        self.encode_kwargs = {"normalize_embeddings": True, "batch_size": 64}
-        if "clip_text_model" in self.mode:  # clip text 模型
-            self.client = AutoModel.from_pretrained(model_path, trust_remote_code=True)
-            self.client.to(device)
-            logger.info(f"device: {self.client.device}")
-            self.client.set_processor(model_path)
-            self.client.eval()
-        elif "vl_rerank" == self.mode:
-            self.client = AutoModel.from_pretrained(
-                model_path,
-                torch_dtype="auto",
-                trust_remote_code=True,
-                # attn_implementation="flash_attention_2",
-            )
-            self.client.to(device)
-            self.client.eval()
-        elif "rerank" == self.mode:
-            self.client = sentence_transformers.CrossEncoder(
-                model_name=model_path, **model_kwargs
-            )
-            logger.warning("正在使用 rerank 模型...")
-        elif "embedding" == self.mode:
-            self.client = sentence_transformers.SentenceTransformer(
-                model_path, **model_kwargs
-            )
-            logger.warning("正在使用 embedding 模型...")
+        self.pool_model = PoolingModel(model_path=model_path)
         logger.warning(f"模型:{model_names[0]}")
-        logger.warning(f"正在使用 {self.mode} 模型...")
 
     async def get_embeddings(self, params):
         self.call_ct += 1
-        ret = {"embedding": [], "token_num": 0}
         texts = params["input"]
-        embedding = []
-        token_num = 0
-        if self.mode == "embedding":
-            outputs = self.client.tokenize(texts)
-            token_num = outputs["input_ids"].size(0) * outputs["input_ids"].size(1)
-            texts = list(map(lambda x: x.replace("\n", " "), texts))
-            embedding = self.client.encode(texts, **self.encode_kwargs).tolist()
-        elif self.mode == "rerank":
-            query = params.get("query", None)
-            # outputs = self.client.tokenizer.tokenize(texts)
-            # token_num = len(outputs)
-            # TODO 暂时不计算 rerank token num
-            sentence_pairs = [[query, inp] for inp in texts]
-            scores = self.client.predict(sentence_pairs)
-            embedding = [[float(score)] for score in scores]
-        elif self.mode == "vl_rerank":
-            query = params.get("query", None)
-            sentence_pairs = [[query, inp] for inp in texts]
-            query_type = doc_type = "text"
-            if (
-                query.startswith("http://")
-                or query.startswith("https://")
-                or is_base64_image(query)
-            ):
-                query_type = "image"
-            if (
-                texts[0].startswith("http://")
-                or texts[0].startswith("https://")
-                or is_base64_image(texts[0])
-            ):
-                doc_type = "image"
-            scores = self.client.compute_score(
-                sentence_pairs,
-                max_length=1024 * 2,
-                query_type=query_type,
-                doc_type=doc_type,
-            )
-            if isinstance(scores, float):
-                scores = [scores]
-            embedding = [[float(score)] for score in scores]
-        elif self.mode == "clip_text_model":
-            if isinstance(texts[0], dict):
-                text = [i["text"] for i in texts]
-                text = list(map(lambda x: x.replace("\n", " "), text))
-
-                images = [i["image"] for i in texts]
-                coro_list = []
-                for i in images:
-                    coro = load_base64_or_url(base64_or_url=i)
-                    coro_list.append(coro)
-                result_images = await asyncio.gather(*coro_list)
-
-                embedding = self.client.encode(
-                    images=result_images,
-                    text=text,
-                ).tolist()
-            elif isinstance(texts[0], str):
-                if "http" in texts[0] or is_base64_image(texts[0]):  # 图片
-                    images = texts
-                    coro_list = []
-                    for i in images:
-                        coro = load_base64_or_url(base64_or_url=i)
-                        coro_list.append(coro)
-                    result_images = await asyncio.gather(*coro_list)
-                    embedding = self.client.encode(
-                        images=result_images,
-                    ).tolist()
-                else:  # 文本
-                    embedding = self.client.encode(
-                        text=texts,
-                    ).tolist()
-        else:
-            raise Exception(f"不支持的类型 mode: {self.mode}")
-        ret["embedding"] = embedding
-        ret["token_num"] = token_num
+        query = params.get("query", None)
+        ret = self.pool_model.pooling(query=query, documents=texts)
         return ret
 
 
```

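After this refactor, `get_embeddings` is a thin pass-through: all model-type branching now lives in `PoolingModel` (defined in utils.py below). A minimal sketch of the request/response shape at this boundary, using only the keys visible in the diff above; the values are illustrative placeholders.

```python
# Sketch of the worker-level contract after the refactor, based on the keys
# visible in the diff above. Values are illustrative placeholders.
params = {
    "input": ["first document ...", "second document ..."],  # documents to embed or rerank
    "query": "user query",  # present for rerank-style requests, may be None otherwise
}

# Inside get_embeddings (see the diff) this collapses to roughly:
#   texts = params["input"]
#   query = params.get("query", None)
#   ret = self.pool_model.pooling(query=query, documents=texts)
#
# and ret always has the same shape, regardless of model type:
ret = {
    "embedding": [[0.93], [0.07]],  # rerank: one score per doc; embedding: full vectors
    "token_num": 0,                 # only counted for embedding models in this commit
}
```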
gpt_server/model_worker/utils.py: 123 additions, 7 deletions

```diff
@@ -6,6 +6,10 @@
 import os
 from PIL import Image
 import re
+import torch
+from transformers import AutoConfig
+from transformers import AutoModel
+import sentence_transformers
 
 
 def is_base64_image(data_string):
@@ -63,6 +67,124 @@ async def load_base64_or_url(base64_or_url) -> io.BytesIO:
     return bytes_io
 
 
+class PoolingModel:
+    def __init__(self, model_path: str):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+        architectures = getattr(model_config, "architectures", [])
+        self.model = None
+        self._pooling = None
+        if "JinaForRanking" in architectures:
+            self.model = AutoModel.from_pretrained(
+                model_path,
+                dtype="auto",
+                trust_remote_code=True,
+            )
+            self.model.eval()
+            self.model.to(device)  # Move model to device
+
+            def pooling(self, query: str, documents: list):
+                results = self.model.rerank(query, documents)
+                embedding = [[i["relevance_score"]] for i in results]
+                ret = {}
+                ret["embedding"] = embedding
+                ret["token_num"] = 0
+                return ret
+
+            self._pooling = pooling
+        elif "JinaVLForRanking" in architectures:
+            self.model = AutoModel.from_pretrained(
+                model_path,
+                torch_dtype="auto",
+                trust_remote_code=True,
+                # attn_implementation="flash_attention_2",
+            )
+            self.model.to(device)
+            self.model.eval()
+            logger.warning("model_type: JinaVLForRanking")
+
+            def pooling(self, query: str, documents: list):
+                texts = documents
+                sentence_pairs = [[query, inp] for inp in texts]
+                query_type = doc_type = "text"
+
+                if (
+                    query.startswith("http://")
+                    or query.startswith("https://")
+                    or is_base64_image(query)
+                ):
+                    query_type = "image"
+                if (
+                    texts
+                    and texts[0]
+                    and (
+                        texts[0].startswith("http://")
+                        or texts[0].startswith("https://")
+                        or is_base64_image(texts[0])
+                    )
+                ):
+                    doc_type = "image"
+                scores = self.model.compute_score(
+                    sentence_pairs,
+                    max_length=1024 * 2,
+                    query_type=query_type,
+                    doc_type=doc_type,
+                )
+                if isinstance(scores, float):
+                    scores = [scores]
+                embedding = [[float(score)] for score in scores]
+                ret = {}
+                ret["embedding"] = embedding
+                ret["token_num"] = 0
+                return ret
+
+            self._pooling = pooling
+        else:
+            mode = get_embedding_mode(model_path=model_path)
+            if "embedding" == mode:
+                self.model = sentence_transformers.SentenceTransformer(model_path)
+                logger.warning("正在使用 embedding 模型...")
+                encode_kwargs = {"normalize_embeddings": True, "batch_size": 64}
+
+                def pooling(self, query: str, documents: list = None):
+                    texts = documents
+                    outputs = self.model.tokenize(texts)
+                    token_num = outputs["input_ids"].size(0) * outputs[
+                        "input_ids"
+                    ].size(1)
+                    texts = list(map(lambda x: x.replace("\n", " "), texts))
+                    embedding = self.model.encode(texts, **encode_kwargs).tolist()
+                    ret = {}
+                    ret["embedding"] = embedding
+                    ret["token_num"] = token_num
+                    return ret
+
+                self._pooling = pooling
+
+            elif "rerank" == mode:
+                self.model = sentence_transformers.CrossEncoder(model_name=model_path)
+                logger.warning("正在使用 rerank 模型...")
+
+                def pooling(self, query: str, documents: list):
+                    sentence_pairs = [[query, doc] for doc in documents]
+                    scores = self.model.predict(sentence_pairs)
+                    embedding = [[float(score)] for score in scores]
+                    ret = {}
+                    ret["embedding"] = embedding
+                    ret["token_num"] = 0  # Rerank token num not typically calculated
+                    return ret
+
+                self._pooling = pooling
+
+            else:
+                raise Exception(f"不支持的类型 mode: {mode}")
+
+    def pooling(self, query, documents):
+        if self._pooling is None:
+            raise Exception("Model is not initialized or mode is not supported.")
+        return self._pooling(self, query, documents)
+
+
 def get_embedding_mode(model_path: str):
     """获取模型的类型"""
     task_type = os.environ.get("task_type", "auto")
@@ -72,20 +194,14 @@ def get_embedding_mode(model_path: str):
         return "rerank"
     elif task_type == "classify":
         return "classify"
-    from transformers import AutoConfig
 
     model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-    architectures = getattr(model_config, "architectures", [])
     model_type_text = getattr(
         getattr(model_config, "text_config", {}), "model_type", None
     )
     logger.warning(f"model_type: {model_type_text}")
 
     model_type = model_type_text
-    # TODO --------- 在这里进行大过滤 ---------
-    if "JinaVLForRanking" in architectures:
-        logger.warning("model_type: JinaVLForRanking")
-        return "vl_rerank"
     # --------- 在这里进行大过滤 ---------
     from infinity_emb import EngineArgs
 
@@ -114,5 +230,5 @@ def get_embedding_mode(model_path: str):
 if __name__ == "__main__":
 
     # 示例用法
-    r = get_embedding_mode("/home/dev/model/jinaai/jina-reranker-m0/")
+    r = get_embedding_mode("/home/dev/model/jinaai/jina-reranker-v3/")
     print(r)
```

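`PoolingModel` picks its backend once, in `__init__`: it reads `architectures` from the model's `AutoConfig`, handles the `JinaForRanking` architecture (jina-reranker-v3) and `JinaVLForRanking` (jina-reranker-m0) directly through `AutoModel`, and otherwise falls back to `get_embedding_mode`, which can also be forced through the `task_type` environment variable ("embedding", "rerank", or "classify"). A minimal usage sketch based on the class above; the local model path is a placeholder.

```python
# Minimal usage sketch of the PoolingModel defined above.
# The model path is a placeholder; any model whose config lists the
# JinaForRanking architecture (e.g. jinaai/jina-reranker-v3) takes the first branch.
import os

from gpt_server.model_worker.utils import PoolingModel

# Optional: force the fallback detection instead of architecture-based dispatch.
# os.environ["task_type"] = "rerank"

model = PoolingModel(model_path="/models/jinaai/jina-reranker-v3")
ret = model.pooling(
    query="What is the capital of France?",
    documents=["Paris is the capital of France.", "Berlin is in Germany."],
)
print(ret["embedding"])  # relevance scores, each wrapped as a one-element list
print(ret["token_num"])  # 0 for rerankers; tokens are only counted in the embedding branch
```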