44from loguru import logger
55
66from infinity_emb import AsyncEngineArray , EngineArgs , AsyncEmbeddingEngine
7+ from infinity_emb .inference .select_model import get_engine_type_from_config
78from gpt_server .model_worker .base .model_worker_base import ModelWorkerBase
89
910label_to_category = {
@@ -49,30 +50,26 @@ def __init__(
4950 bettertransformer = True
5051 if model_type is not None and "deberta" in model_type :
5152 bettertransformer = False
52- self .engine : AsyncEmbeddingEngine = AsyncEngineArray .from_args (
53- [
54- EngineArgs (
55- model_name_or_path = model_path ,
56- engine = "torch" ,
57- embedding_dtype = "float32" ,
58- dtype = "float32" ,
59- device = device ,
60- bettertransformer = bettertransformer ,
61- )
62- ]
63- )[0 ]
53+ engine_args = EngineArgs (
54+ model_name_or_path = model_path ,
55+ engine = "torch" ,
56+ embedding_dtype = "float32" ,
57+ dtype = "float32" ,
58+ device = device ,
59+ bettertransformer = bettertransformer ,
60+ )
61+ engine_type = get_engine_type_from_config (engine_args )
62+ engine_type_str = str (engine_type )
63+ if "EmbedderEngine" in engine_type_str :
64+ self .mode = "embedding"
65+ elif "RerankEngine" in engine_type_str :
66+ self .mode = "rerank"
67+ elif "ImageEmbedEngine" in engine_type_str :
68+ self .mode = "image"
69+ self .engine : AsyncEmbeddingEngine = AsyncEngineArray .from_args ([engine_args ])[0 ]
6470 loop = asyncio .get_running_loop ()
6571 loop .create_task (self .engine .astart ())
66- self .mode = "embedding"
67- # rerank
68- for model_name in model_names :
69- if "rerank" in model_name :
70- self .mode = "rerank"
71- break
72- if self .mode == "rerank" :
73- logger .info ("正在使用 rerank 模型..." )
74- elif self .mode == "embedding" :
75- logger .info ("正在使用 embedding 模型..." )
72+ logger .info (f"正在使用 { self .mode } 模型..." )
7673 logger .info (f"模型:{ model_names [0 ]} " )
7774
7875 async def astart (self ):
@@ -83,7 +80,7 @@ async def get_embeddings(self, params):
8380 logger .info (f"worker_id: { self .worker_id } " )
8481 self .call_ct += 1
8582 ret = {"embedding" : [], "token_num" : 0 }
86- texts = params ["input" ]
83+ texts : list = params ["input" ]
8784 if self .mode == "embedding" :
8885 texts = list (map (lambda x : x .replace ("\n " , " " ), texts ))
8986 embeddings , usage = await self .engine .embed (sentences = texts )
@@ -105,6 +102,17 @@ async def get_embeddings(self, params):
105102 embedding = [
106103 [round (float (score ["relevance_score" ]), 6 )] for score in ranking
107104 ]
105+ elif self .mode == "image" :
106+ if (
107+ isinstance (texts [0 ], bytes )
108+ or "http" in texts [0 ]
109+ or "data:image" in texts [0 ]
110+ ):
111+ embeddings , usage = await self .engine .image_embed (images = texts )
112+ else :
113+ embeddings , usage = await self .engine .embed (sentences = texts )
114+
115+ embedding = [embedding .tolist () for embedding in embeddings ]
108116 ret ["embedding" ] = embedding
109117 ret ["token_num" ] = usage
110118 return ret
0 commit comments