
Commit c3071cd

Support the vLLM backend for the Qwen3 reranker
1 parent ada0069 commit c3071cd

File tree

5 files changed: +63 -14 lines changed


gpt_server/model_worker/base/model_worker_base.py

Lines changed: 4 additions & 0 deletions
@@ -311,6 +311,8 @@ def run(cls):
         parser.add_argument("--port", type=int, default=None)
         # model_type
         parser.add_argument("--model_type", type=str, default="auto")
+        # hf_overrides
+        parser.add_argument("--hf_overrides", type=str, default="")
         args = parser.parse_args()
         os.environ["num_gpus"] = str(args.num_gpus)
         if args.backend == "vllm":
@@ -332,6 +334,8 @@ def run(cls):
             os.environ["vad_model"] = args.vad_model
         if args.punc_model:
             os.environ["punc_model"] = args.punc_model
+        if args.hf_overrides:
+            os.environ["hf_overrides"] = args.hf_overrides
 
         os.environ["model_type"] = args.model_type
         os.environ["enable_prefix_caching"] = args.enable_prefix_caching

gpt_server/model_worker/embedding_vllm.py

Lines changed: 39 additions & 14 deletions
@@ -1,14 +1,12 @@
 import os
 from typing import List
-import asyncio
 from loguru import logger
 
-from infinity_emb import AsyncEngineArray, EngineArgs, AsyncEmbeddingEngine
-from infinity_emb.inference.select_model import get_engine_type_from_config
 from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
 from gpt_server.model_worker.utils import get_embedding_mode
 import numpy as np
-from vllm import LLM
+from vllm import LLM, EmbeddingRequestOutput, ScoringRequestOutput
+from gpt_server.settings import get_model_config
 
 label_to_category = {
     "S": "sexual",
@@ -23,6 +21,24 @@
 }
 
 
+def template_format(queries: List[str], documents: List[str]):
+    model_config = get_model_config()
+    hf_overrides = model_config.hf_overrides
+    if hf_overrides:
+        if hf_overrides["architectures"][0] == "Qwen3ForSequenceClassification":
+            logger.info("Formatting with the Qwen3ForSequenceClassification template...")
+            prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
+            suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+            instruction = "Given a web search query, retrieve relevant passages that answer the query"
+
+            query_template = f"{prefix}<Instruct>: {instruction}\n<Query>: {{query}}\n"
+            document_template = f"<Document>: {{doc}}{suffix}"
+            queries = [query_template.format(query=query) for query in queries]
+            documents = [document_template.format(doc=doc) for doc in documents]
+            return queries, documents
+    return queries, documents
+
+
 class EmbeddingWorker(ModelWorkerBase):
     def __init__(
         self,
@@ -44,18 +60,20 @@ def __init__(
             conv_template,
             model_type="embedding",
         )
-        tensor_parallel_size = int(os.getenv("num_gpus", "1"))
-        max_model_len = os.getenv("max_model_len", None)
-        gpu_memory_utilization = float(os.getenv("gpu_memory_utilization", 0.6))
-        enable_prefix_caching = bool(os.getenv("enable_prefix_caching", False))
-
+        model_config = get_model_config()
+        hf_overrides = model_config.hf_overrides
         self.mode = get_embedding_mode(model_path=model_path)
+        runner = "auto"
+        if self.mode == "rerank":
+            runner = "pooling"
         self.engine = LLM(
             model=model_path,
-            tensor_parallel_size=tensor_parallel_size,
-            max_model_len=max_model_len,
-            gpu_memory_utilization=gpu_memory_utilization,
-            enable_prefix_caching=enable_prefix_caching,
+            tensor_parallel_size=model_config.num_gpus,
+            max_model_len=model_config.max_model_len,
+            gpu_memory_utilization=model_config.gpu_memory_utilization,
+            enable_prefix_caching=model_config.enable_prefix_caching,
+            runner=runner,
+            hf_overrides=hf_overrides,
         )
 
         logger.warning(f"Model: {model_names[0]}")
@@ -69,13 +87,20 @@ async def get_embeddings(self, params):
         if self.mode == "embedding":
            texts = list(map(lambda x: x.replace("\n", " "), texts))
            # ----------
-            outputs = self.engine.embed(texts)
+            outputs: list[EmbeddingRequestOutput] = self.engine.embed(texts)
             embedding = [o.outputs.embedding for o in outputs]
             embeddings_np = np.array(embedding)
             # ------ L2-normalize (along axis=1, i.e. per row) -------
             norm = np.linalg.norm(embeddings_np, ord=2, axis=1, keepdims=True)
             normalized_embeddings_np = embeddings_np / norm
             embedding = normalized_embeddings_np.tolist()
+        elif self.mode == "rerank":
+            query = params.get("query", None)
+            data_1 = [query] * len(texts)
+            data_2 = texts
+            data_1, data_2 = template_format(queries=data_1, documents=data_2)
+            scores: list[ScoringRequestOutput] = self.engine.score(data_1, data_2)
+            embedding = [[score.outputs.score] for score in scores]
 
         ret["embedding"] = embedding
         return ret
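
For context, a standalone sketch of the new rerank path: it assumes a Qwen3-Reranker checkpoint, the same `hf_overrides` as the example config below, and that the worker environment is set up so `template_format`'s call to `get_model_config()` can see those overrides. vLLM's `LLM.score` returns one `ScoringRequestOutput` per (query, document) pair.

    from vllm import LLM

    # Model path is illustrative; point it at a local Qwen3-Reranker checkpoint.
    engine = LLM(
        model="Qwen/Qwen3-Reranker-0.6B",
        runner="pooling",  # rerank models run under the pooling runner
        hf_overrides={
            "architectures": ["Qwen3ForSequenceClassification"],
            "classifier_from_token": ["no", "yes"],
            "is_original_qwen3_reranker": True,
        },
    )

    query = "What is the capital of China?"
    docs = ["The capital of China is Beijing.", "Giant pandas live in Sichuan."]

    # Wrap the pairs in the chat-style template defined in template_format().
    queries, documents = template_format(queries=[query] * len(docs), documents=docs)
    scores = engine.score(queries, documents)
    print([s.outputs.score for s in scores])  # one relevance score per document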

gpt_server/script/config_example.yaml

Lines changed: 15 additions & 0 deletions
@@ -93,6 +93,21 @@ models:
       workers:
       - gpus:
         - 2
+  # Example: deploying qwen3-reranker
+  - qwen3-reranker:
+      alias: null
+      enable: true
+      model_config:
+        model_name_or_path: /home/dev/model/Qwen/Qwen3-Reranker-0___6B/
+        dtype: auto
+        task_type: reranker
+        hf_overrides: { "architectures": [ "Qwen3ForSequenceClassification" ], "classifier_from_token": [ "no", "yes" ], "is_original_qwen3_reranker": True }
+      model_type: embedding
+      work_mode: vllm
+      device: gpu
+      workers:
+      - gpus:
+        - 6
 
   - jina-reranker:
       # Multimodal, multilingual rerank model; for this model task_type can only be auto
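
A quick sanity check for the new block (a sketch assuming PyYAML and the list-of-single-key-mappings layout this config uses); note that bare `True` parses as a boolean under PyYAML's YAML 1.1 rules, so `is_original_qwen3_reranker` arrives as a Python bool:

    import yaml

    with open("gpt_server/script/config_example.yaml") as f:
        config = yaml.safe_load(f)

    # Find the qwen3-reranker entry among the single-key model mappings.
    entry = next(m for m in config["models"] if "qwen3-reranker" in m)
    overrides = entry["qwen3-reranker"]["model_config"]["hf_overrides"]
    print(overrides["architectures"])  # ['Qwen3ForSequenceClassification']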

gpt_server/settings.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ class ModelConfig(BaseSettings):
     dtype: str = "auto"
     num_gpus: int = 1
     lora: str | None = None
+    hf_overrides: dict | None = None
+    """HuggingFace config override parameters."""
 
 
 def get_model_config() -> ModelConfig:
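
Typing the field as `dict | None` is what makes the environment hand-off work: assuming `ModelConfig` is a pydantic-settings `BaseSettings`, complex-typed fields are decoded from JSON-encoded environment variables, so the string written by `run()` comes back as a dict. A minimal sketch:

    import os

    from pydantic_settings import BaseSettings

    class DemoConfig(BaseSettings):
        hf_overrides: dict | None = None

    # The JSON string that model_worker_base.run() put into the environment.
    os.environ["hf_overrides"] = '{"architectures": ["Qwen3ForSequenceClassification"]}'
    print(DemoConfig().hf_overrides)  # {'architectures': ['Qwen3ForSequenceClassification']}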

gpt_server/utils.py

Lines changed: 3 additions & 0 deletions
@@ -211,6 +211,7 @@ def start_model_worker(config: dict):
         vad_model = engine_config.get("vad_model", "")
         punc_model = engine_config.get("punc_model", "")
         task_type = engine_config.get("task_type", "auto")
+        hf_overrides = engine_config.get("hf_overrides", "")
 
     else:
         logger.error(
@@ -315,6 +316,8 @@ def start_model_worker(config: dict):
         cmd += f" --vad_model '{vad_model}'"
     if punc_model:
         cmd += f" --punc_model '{punc_model}'"
+    if hf_overrides:
+        cmd += f" --hf_overrides '{json.dumps(hf_overrides)}'"
     p = Process(target=run_cmd, args=(cmd,))
     # p.start()
     process.append(p)
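
For the example config above, the appended fragment would look roughly like this (illustrative); wrapping the JSON in single quotes keeps the shell from splitting on the spaces and double quotes inside it, though it would break if the overrides themselves contained a single quote:

    import json

    hf_overrides = {
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
        "is_original_qwen3_reranker": True,
    }
    print(f" --hf_overrides '{json.dumps(hf_overrides)}'")
    #  --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"], ...}'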
