66import os
77from PIL import Image
88import re
9+ import torch
10+ from transformers import AutoConfig
11+ from transformers import AutoModel
12+ import sentence_transformers
913
1014
1115def is_base64_image (data_string ):
@@ -63,6 +67,124 @@ async def load_base64_or_url(base64_or_url) -> io.BytesIO:
6367 return bytes_io
6468
6569
class PoolingModel:
    """Unified wrapper exposing one ``pooling(query, documents)`` API over
    several backends, chosen from the checkpoint's config/type:

    * ``JinaForRanking`` / ``JinaVLForRanking`` reranker checkpoints loaded
      via ``AutoModel`` with ``trust_remote_code=True``,
    * sentence-transformers embedding models,
    * sentence-transformers CrossEncoder rerank models.

    Every backend returns ``{"embedding": list[list[float]], "token_num": int}``.
    Rerank backends put each document's relevance score in a one-element
    "embedding" row and report ``token_num`` as 0.
    """

    def __init__(self, model_path: str):
        """Load the model at *model_path* and select a pooling strategy.

        BUGFIX vs. previous revision: the strategies were nested functions and
        ``self._pooling = self.pooling`` resolved (by attribute lookup) to the
        class-level ``pooling`` dispatcher instead of the nested function,
        so every call raised TypeError. The strategies are now private
        methods and ``_pooling`` holds the correct bound method.

        Raises:
            Exception: if the detected mode is neither embedding nor rerank.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        architectures = getattr(model_config, "architectures", []) or []
        self.model = None
        self._pooling = None  # bound strategy method, assigned below

        if "JinaForRanking" in architectures:
            self.model = AutoModel.from_pretrained(
                model_path,
                # unified on torch_dtype (previously one branch used `dtype`,
                # the other `torch_dtype`); torch_dtype is accepted by both
                # older and current transformers releases
                torch_dtype="auto",
                trust_remote_code=True,
            )
            self.model.eval()
            self.model.to(device)  # move model to GPU when available
            self._pooling = self._pool_jina_rank
        elif "JinaVLForRanking" in architectures:
            self.model = AutoModel.from_pretrained(
                model_path,
                torch_dtype="auto",
                trust_remote_code=True,
                # attn_implementation="flash_attention_2",
            )
            self.model.to(device)
            self.model.eval()
            logger.warning("model_type: JinaVLForRanking")
            self._pooling = self._pool_jina_vl_rank
        else:
            mode = get_embedding_mode(model_path=model_path)
            if mode == "embedding":
                self.model = sentence_transformers.SentenceTransformer(model_path)
                logger.warning("正在使用 embedding 模型...")
                self._encode_kwargs = {"normalize_embeddings": True, "batch_size": 64}
                self._pooling = self._pool_embedding
            elif mode == "rerank":
                self.model = sentence_transformers.CrossEncoder(model_name=model_path)
                logger.warning("正在使用 rerank 模型...")
                self._pooling = self._pool_cross_encoder
            else:
                raise Exception(f"不支持的类型 mode: {mode}")

    def _pool_jina_rank(self, query: str, documents: list):
        """Score *documents* against *query* via the model's ``rerank`` API.

        Each result item is expected to carry a ``relevance_score`` key
        (per the Jina reranker remote-code API).
        """
        results = self.model.rerank(query, documents)
        embedding = [[item["relevance_score"]] for item in results]
        return {"embedding": embedding, "token_num": 0}

    def _pool_jina_vl_rank(self, query: str, documents: list):
        """Score (query, doc) pairs, auto-detecting text vs. image inputs.

        An input counts as an image when it is an http(s) URL or a
        base64-encoded image (per ``is_base64_image``); only the first
        document is inspected, so documents are assumed homogeneous.
        """
        sentence_pairs = [[query, doc] for doc in documents]
        query_type = doc_type = "text"
        if query.startswith(("http://", "https://")) or is_base64_image(query):
            query_type = "image"
        if documents and documents[0] and (
            documents[0].startswith(("http://", "https://"))
            or is_base64_image(documents[0])
        ):
            doc_type = "image"
        scores = self.model.compute_score(
            sentence_pairs,
            max_length=1024 * 2,
            query_type=query_type,
            doc_type=doc_type,
        )
        if isinstance(scores, float):  # a single pair may come back as a bare float
            scores = [scores]
        embedding = [[float(score)] for score in scores]
        return {"embedding": embedding, "token_num": 0}

    def _pool_embedding(self, query: str, documents: list = None):
        """Encode *documents* into normalized embeddings (*query* is unused)."""
        texts = documents
        outputs = self.model.tokenize(texts)
        # NOTE(review): this counts the padded batch (batch * max_len),
        # not the actual token count
        token_num = outputs["input_ids"].size(0) * outputs["input_ids"].size(1)
        texts = [text.replace("\n", " ") for text in texts]
        embedding = self.model.encode(texts, **self._encode_kwargs).tolist()
        return {"embedding": embedding, "token_num": token_num}

    def _pool_cross_encoder(self, query: str, documents: list):
        """Score (query, doc) pairs with a sentence-transformers CrossEncoder."""
        sentence_pairs = [[query, doc] for doc in documents]
        scores = self.model.predict(sentence_pairs)
        embedding = [[float(score)] for score in scores]
        # Rerank token num not typically calculated
        return {"embedding": embedding, "token_num": 0}

    def pooling(self, query, documents):
        """Dispatch to the strategy selected in ``__init__``.

        Raises:
            Exception: if no strategy was bound (unsupported model/mode).
        """
        if self._pooling is None:
            raise Exception("Model is not initialized or mode is not supported.")
        return self._pooling(query, documents)
187+
66188def get_embedding_mode (model_path : str ):
67189 """获取模型的类型"""
68190 task_type = os .environ .get ("task_type" , "auto" )
@@ -72,20 +194,14 @@ def get_embedding_mode(model_path: str):
72194 return "rerank"
73195 elif task_type == "classify" :
74196 return "classify"
75- from transformers import AutoConfig
76197
77198 model_config = AutoConfig .from_pretrained (model_path , trust_remote_code = True )
78- architectures = getattr (model_config , "architectures" , [])
79199 model_type_text = getattr (
80200 getattr (model_config , "text_config" , {}), "model_type" , None
81201 )
82202 logger .warning (f"model_type: { model_type_text } " )
83203
84204 model_type = model_type_text
85- # TODO --------- 在这里进行大过滤 ---------
86- if "JinaVLForRanking" in architectures :
87- logger .warning ("model_type: JinaVLForRanking" )
88- return "vl_rerank"
89205 # --------- 在这里进行大过滤 ---------
90206 from infinity_emb import EngineArgs
91207
@@ -114,5 +230,5 @@ def get_embedding_mode(model_path: str):
if __name__ == "__main__":
    # Example usage: detect the task mode of a local reranker checkpoint.
    mode = get_embedding_mode("/home/dev/model/jinaai/jina-reranker-v3/")
    print(mode)
0 commit comments