Skip to content

Commit 4a9a25a

Browse files
committed
支持jina-reranker-m0
1 parent aeb0d8a commit 4a9a25a

File tree

3 files changed: +45 −7 lines changed

gpt_server/model_worker/embedding.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,22 @@ def __init__(
4949
logger.info(f"device: {self.client.device}")
5050
self.client.set_processor(model_path)
5151
self.client.eval()
52-
elif "rerank" in self.mode:
52+
elif "vl_rerank" == self.mode:
53+
self.client = AutoModel.from_pretrained(
54+
model_path,
55+
torch_dtype="auto",
56+
trust_remote_code=True,
57+
# attn_implementation="flash_attention_2",
58+
)
59+
60+
self.client.to("cuda") # or 'cpu' if no GPU is available
61+
self.client.eval()
62+
elif "rerank" == self.mode:
5363
self.client = sentence_transformers.CrossEncoder(
5464
model_name=model_path, **model_kwargs
5565
)
5666
logger.warning("正在使用 rerank 模型...")
57-
elif "embedding" in self.mode:
67+
elif "embedding" == self.mode:
5868
self.client = sentence_transformers.SentenceTransformer(
5969
model_path, **model_kwargs
6070
)
@@ -79,6 +89,30 @@ async def get_embeddings(self, params):
7989
sentence_pairs = [[query, inp] for inp in texts]
8090
scores = self.client.predict(sentence_pairs)
8191
embedding = [[float(score)] for score in scores]
92+
elif self.mode == "vl_rerank":
93+
query = params.get("query", None)
94+
token_num = 0
95+
sentence_pairs = [[query, inp] for inp in texts]
96+
query_type = doc_type = "text"
97+
if (
98+
query.startswith("http://")
99+
or query.startswith("https://")
100+
or "data:" in query
101+
):
102+
query_type = "image"
103+
if (
104+
texts[0].startswith("http://")
105+
or texts[0].startswith("https://")
106+
or "data:" in texts[0]
107+
):
108+
doc_type = "image"
109+
scores = self.client.compute_score(
110+
sentence_pairs,
111+
max_length=1024 * 2,
112+
query_type=query_type,
113+
doc_type=doc_type,
114+
)
115+
embedding = [[float(score)] for score in scores]
82116
elif self.mode == "clip_text_model":
83117
token_num = 0
84118
if isinstance(texts[0], dict):

gpt_server/model_worker/utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ def get_embedding_mode(model_path: str):
5454
from infinity_emb.inference.select_model import get_engine_type_from_config
5555

5656
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
57+
architectures = getattr(model_config, "architectures", [])
58+
if "JinaVLForRanking" in architectures:
59+
logger.warning("model_type: JinaVLForRanking")
60+
return "vl_rerank"
5761
model_type_text = getattr(
5862
getattr(model_config, "text_config", {}), "model_type", None
5963
)
@@ -76,14 +80,13 @@ def get_embedding_mode(model_path: str):
7680
engine_type_str = str(engine_type)
7781

7882
if "EmbedderEngine" in engine_type_str:
79-
mode = "embedding"
83+
return "embedding"
8084
elif "RerankEngine" in engine_type_str:
81-
mode = "rerank"
85+
return "rerank"
8286
elif "ImageEmbedEngine" in engine_type_str:
83-
mode = model_type or "image"
87+
return model_type or "image"
8488
elif "PredictEngine" in engine_type_str:
85-
mode = "classify"
86-
return mode
89+
return "classify"
8790

8891

8992
if __name__ == "__main__":

gpt_server/serving/openai_api_server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ async def timing_tasks():
134134

135135
while True:
136136
try:
137+
# ret = await fetch_remote(controller_address + "/refresh_all_workers")
137138
models = await fetch_remote(
138139
controller_address + "/list_models", None, "models"
139140
)

0 commit comments

Comments (0)