Skip to content

Commit 64f4b84

Browse files
committed
添加可以指定 任务类型 task_type
1 parent c94dc18 commit 64f4b84

File tree

4 files changed

+17
-2
lines changed

4 files changed

+17
-2
lines changed

gpt_server/model_worker/base/model_worker_base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,8 @@ def run(cls):
249249
parser.add_argument("--punc_model", type=str, default="")
250250
# log_level
251251
parser.add_argument("--log_level", type=str, default="WARNING")
252+
# task_type
253+
parser.add_argument("--task_type", type=str, default="auto")
252254
args = parser.parse_args()
253255
os.environ["num_gpus"] = str(args.num_gpus)
254256
if args.backend == "vllm":
@@ -276,6 +278,7 @@ def run(cls):
276278
os.environ["kv_cache_quant_policy"] = args.kv_cache_quant_policy
277279
os.environ["dtype"] = args.dtype
278280
os.environ["log_level"] = args.log_level
281+
os.environ["task_type"] = args.task_type
279282
logger.remove(0)
280283
log_level = os.getenv("log_level", "WARNING")
281284
logger.add(sys.stderr, level=log_level)

gpt_server/model_worker/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from fastapi import HTTPException
44
import base64
55
import io
6-
6+
import os
77
from PIL.Image import Image
88

99

@@ -53,6 +53,14 @@ def get_embedding_mode(model_path: str):
5353
from transformers import AutoConfig
5454
from infinity_emb.inference.select_model import get_engine_type_from_config
5555

56+
task_type = os.environ.get("task_type", "auto")
57+
if task_type == "embedding":
58+
return "embedding"
59+
elif task_type == "reranker":
60+
return "rerank"
61+
elif task_type == "classify":
62+
return "classify"
63+
5664
model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
5765
architectures = getattr(model_config, "architectures", [])
5866
if "JinaVLForRanking" in architectures:

gpt_server/script/config_example.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,12 @@ models:
9494
- 2
9595

9696
- jina-reranker:
97-
# 多模态多语言的重排模型
97+
# 多模态多语言的重排模型,这个模型task_type 只能是 auto
9898
alias: null
9999
enable: true
100100
model_config:
101101
model_name_or_path: /home/dev/model/jinaai/jina-reranker-m0/
102+
task_type: auto # auto 、embedding 、 reranker 或者 classify 不设置这个参数,默认为 auto,自动识别可能会识别错误
102103
model_type: embedding # 这里仅支持 embedding
103104
work_mode: hf
104105
device: gpu
@@ -112,6 +113,7 @@ models:
112113
enable: true # false true
113114
model_config:
114115
model_name_or_path: /home/dev/model/aspire/acge_text_embedding
116+
task_type: auto # auto 、embedding 、 reranker 或者 classify 不设置这个参数,默认为 auto,自动识别可能会识别错误
115117
model_type: embedding_infinity # embedding_infinity/embedding
116118
work_mode: hf
117119
device: gpu # gpu / cpu

gpt_server/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def start_model_worker(config: dict):
167167
)
168168
vad_model = engine_config.get("vad_model", "")
169169
punc_model = engine_config.get("punc_model", "")
170+
task_type = engine_config.get("task_type", "auto")
170171

171172
else:
172173
logger.error(
@@ -252,6 +253,7 @@ def start_model_worker(config: dict):
252253
+ f" --gpu_memory_utilization {gpu_memory_utilization}" # 占用GPU比例
253254
+ f" --kv_cache_quant_policy {kv_cache_quant_policy}" # kv cache 量化策略
254255
+ f" --log_level {log_level}" # 日志水平
256+
+ f" --task_type {task_type}" # 任务类型 (copy-paste fix: comment previously said "日志水平"/log level)
255257
)
256258
# 处理为 None的情况
257259
if lora:

0 commit comments

Comments
 (0)