 import os
 from typing import List
-import asyncio
 from loguru import logger

-from infinity_emb import AsyncEngineArray, EngineArgs, AsyncEmbeddingEngine
-from infinity_emb.inference.select_model import get_engine_type_from_config
 from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
 from gpt_server.model_worker.utils import get_embedding_mode
 import numpy as np
-from vllm import LLM
+from vllm import LLM, EmbeddingRequestOutput, ScoringRequestOutput
+from gpt_server.settings import get_model_config

 label_to_category = {
     "S": "sexual",
     ...
 }


+def template_format(queries: List[str], documents: List[str]):
+    model_config = get_model_config()
+    hf_overrides = model_config.hf_overrides
+    if hf_overrides:
+        if hf_overrides["architectures"][0] == "Qwen3ForSequenceClassification":
+            logger.info("Formatting with the Qwen3ForSequenceClassification template...")
+            prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
+            suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+            instruction = "Given a web search query, retrieve relevant passages that answer the query"
+
+            query_template = f"{prefix}<Instruct>: {instruction}\n<Query>: {{query}}\n"
+            document_template = f"<Document>: {{doc}}{suffix}"
+            queries = [query_template.format(query=query) for query in queries]
+            documents = [document_template.format(doc=doc) for doc in documents]
+            return queries, documents
+    return queries, documents
+
+
 class EmbeddingWorker(ModelWorkerBase):
     def __init__(
         self,
@@ -44,18 +60,20 @@ def __init__(
             conv_template,
             model_type="embedding",
         )
-        tensor_parallel_size = int(os.getenv("num_gpus", "1"))
-        max_model_len = os.getenv("max_model_len", None)
-        gpu_memory_utilization = float(os.getenv("gpu_memory_utilization", 0.6))
-        enable_prefix_caching = bool(os.getenv("enable_prefix_caching", False))
-
+        model_config = get_model_config()
+        hf_overrides = model_config.hf_overrides
         self.mode = get_embedding_mode(model_path=model_path)
+        runner = "auto"
+        if self.mode == "rerank":
+            runner = "pooling"
         self.engine = LLM(
             model=model_path,
-            tensor_parallel_size=tensor_parallel_size,
-            max_model_len=max_model_len,
-            gpu_memory_utilization=gpu_memory_utilization,
-            enable_prefix_caching=enable_prefix_caching,
+            tensor_parallel_size=model_config.num_gpus,
+            max_model_len=model_config.max_model_len,
+            gpu_memory_utilization=model_config.gpu_memory_utilization,
+            enable_prefix_caching=model_config.enable_prefix_caching,
+            runner=runner,
+            hf_overrides=hf_overrides,
         )

         logger.warning(f"Model: {model_names[0]}")
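
For reference, the engine arguments above now come from one config object instead of individual environment variables. The sketch below only illustrates the fields this worker reads from `get_model_config()`; the real object in `gpt_server.settings` may carry more, and the defaults shown simply mirror the old env-var fallbacks removed in this hunk.

# Hypothetical shape of the config consumed by this worker (not the actual
# gpt_server.settings implementation); fields inferred from the diff above.
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelConfigSketch:
    num_gpus: int = 1                      # passed as tensor_parallel_size
    max_model_len: Optional[int] = None    # context length override
    gpu_memory_utilization: float = 0.6    # fraction of GPU memory vLLM may use
    enable_prefix_caching: bool = False    # vLLM prefix caching toggle
    hf_overrides: Optional[dict] = None    # e.g. {"architectures": ["Qwen3ForSequenceClassification"]}
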
@@ -69,13 +87,20 @@ async def get_embeddings(self, params):
         if self.mode == "embedding":
             texts = list(map(lambda x: x.replace("\n", " "), texts))
             # ----------
-            outputs = self.engine.embed(texts)
+            outputs: list[EmbeddingRequestOutput] = self.engine.embed(texts)
             embedding = [o.outputs.embedding for o in outputs]
             embeddings_np = np.array(embedding)
             # ------ L2 normalization (along axis=1, i.e. normalize each row) -------
             norm = np.linalg.norm(embeddings_np, ord=2, axis=1, keepdims=True)
             normalized_embeddings_np = embeddings_np / norm
             embedding = normalized_embeddings_np.tolist()
+        elif self.mode == "rerank":
+            query = params.get("query", None)
+            data_1 = [query] * len(texts)
+            data_2 = texts
+            data_1, data_2 = template_format(queries=data_1, documents=data_2)
+            scores: list[ScoringRequestOutput] = self.engine.score(data_1, data_2)
+            embedding = [[score.outputs.score] for score in scores]

         ret["embedding"] = embedding
         return ret
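
Putting the rerank branch together: it wraps each query/document pair in the Qwen3-Reranker chat template and lets vLLM's pooling runner score the pairs. A minimal standalone sketch of that flow, assuming a local reranker checkpoint (placeholder path) and a vLLM build that supports runner="pooling" and LLM.score:

from vllm import LLM

# Same Qwen3-Reranker template pieces as template_format() above.
PREFIX = (
    "<|im_start|>system\nJudge whether the Document meets the requirements based on "
    'the Query and the Instruct provided. Note that the answer can only be "yes" or "no".'
    "<|im_end|>\n<|im_start|>user\n"
)
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
INSTRUCTION = "Given a web search query, retrieve relevant passages that answer the query"

query = "What is the capital of France?"
documents = ["Paris is the capital of France.", "The Great Wall is in China."]

# Pair the templated query with every templated document, mirroring get_embeddings().
queries = [f"{PREFIX}<Instruct>: {INSTRUCTION}\n<Query>: {query}\n"] * len(documents)
docs = [f"<Document>: {doc}{SUFFIX}" for doc in documents]

engine = LLM(model="/path/to/Qwen3-Reranker", runner="pooling")  # placeholder model path
outputs = engine.score(queries, docs)        # one ScoringRequestOutput per (query, doc) pair
scores = [o.outputs.score for o in outputs]  # higher score = more relevant document
print(scores)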