Skip to content

Commit c94dc18

Browse files
committed
优化embedding 代码
1 parent d41c8b2 commit c94dc18

File tree

3 files changed

+3
-11
lines changed

3 files changed

+3
-11
lines changed

gpt_server/model_worker/embedding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def __init__(
6868
)
6969
logger.warning("正在使用 embedding 模型...")
7070
logger.warning(f"模型:{model_names[0]}")
71+
logger.warning(f"正在使用 {self.mode} 模型...")
7172

7273
async def get_embeddings(self, params):
7374
self.call_ct += 1

gpt_server/model_worker/embedding_infinity.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from infinity_emb import AsyncEngineArray, EngineArgs, AsyncEmbeddingEngine
77
from infinity_emb.inference.select_model import get_engine_type_from_config
88
from gpt_server.model_worker.base.model_worker_base import ModelWorkerBase
9+
from gpt_server.model_worker.utils import get_embedding_mode
910

1011
label_to_category = {
1112
"S": "sexual",
@@ -58,16 +59,7 @@ def __init__(
5859
device=device,
5960
bettertransformer=bettertransformer,
6061
)
61-
engine_type = get_engine_type_from_config(engine_args)
62-
engine_type_str = str(engine_type)
63-
if "EmbedderEngine" in engine_type_str:
64-
self.mode = "embedding"
65-
elif "RerankEngine" in engine_type_str:
66-
self.mode = "rerank"
67-
elif "ImageEmbedEngine" in engine_type_str:
68-
self.mode = "image"
69-
elif "PredictEngine" in engine_type_str:
70-
self.mode = "classify"
62+
self.mode = get_embedding_mode(model_path=model_path)
7163
self.engine: AsyncEmbeddingEngine = AsyncEngineArray.from_args([engine_args])[0]
7264
loop = asyncio.get_running_loop()
7365
loop.create_task(self.engine.astart())

gpt_server/model_worker/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def get_embedding_mode(model_path: str):
6868
# print(model_type_vison, model_type_text)
6969
model_type = model_type_text
7070

71-
mode = "embedding"
7271
engine_args = EngineArgs(
7372
model_name_or_path=model_path,
7473
engine="torch",

0 commit comments

Comments
 (0)