@@ -49,12 +49,22 @@ def __init__(
4949 logger .info (f"device: { self .client .device } " )
5050 self .client .set_processor (model_path )
5151 self .client .eval ()
52- elif "rerank" in self .mode :
52+ elif "vl_rerank" == self .mode :
53+ self .client = AutoModel .from_pretrained (
54+ model_path ,
55+ torch_dtype = "auto" ,
56+ trust_remote_code = True ,
57+ # attn_implementation="flash_attention_2",
58+ )
59+
60+ self .client .to ("cuda" ) # or 'cpu' if no GPU is available
61+ self .client .eval ()
62+ elif "rerank" == self .mode :
5363 self .client = sentence_transformers .CrossEncoder (
5464 model_name = model_path , ** model_kwargs
5565 )
5666 logger .warning ("正在使用 rerank 模型..." )
57- elif "embedding" in self .mode :
67+ elif "embedding" == self .mode :
5868 self .client = sentence_transformers .SentenceTransformer (
5969 model_path , ** model_kwargs
6070 )
@@ -79,6 +89,30 @@ async def get_embeddings(self, params):
7989 sentence_pairs = [[query , inp ] for inp in texts ]
8090 scores = self .client .predict (sentence_pairs )
8191 embedding = [[float (score )] for score in scores ]
92+ elif self .mode == "vl_rerank" :
93+ query = params .get ("query" , None )
94+ token_num = 0
95+ sentence_pairs = [[query , inp ] for inp in texts ]
96+ query_type = doc_type = "text"
97+ if (
98+ query .startswith ("http://" )
99+ or query .startswith ("https://" )
100+ or "data:" in query
101+ ):
102+ query_type = "image"
103+ if (
104+ texts [0 ].startswith ("http://" )
105+ or texts [0 ].startswith ("https://" )
106+ or "data:" in texts [0 ]
107+ ):
108+ doc_type = "image"
109+ scores = self .client .compute_score (
110+ sentence_pairs ,
111+ max_length = 1024 * 2 ,
112+ query_type = query_type ,
113+ doc_type = doc_type ,
114+ )
115+ embedding = [[float (score )] for score in scores ]
82116 elif self .mode == "clip_text_model" :
83117 token_num = 0
84118 if isinstance (texts [0 ], dict ):