File tree Expand file tree Collapse file tree 4 files changed +17
-2
lines changed Expand file tree Collapse file tree 4 files changed +17
-2
lines changed Original file line number Diff line number Diff line change @@ -249,6 +249,8 @@ def run(cls):
249
249
parser .add_argument ("--punc_model" , type = str , default = "" )
250
250
# log_level
251
251
parser .add_argument ("--log_level" , type = str , default = "WARNING" )
252
+ # task_type
253
+ parser .add_argument ("--task_type" , type = str , default = "auto" )
252
254
args = parser .parse_args ()
253
255
os .environ ["num_gpus" ] = str (args .num_gpus )
254
256
if args .backend == "vllm" :
@@ -276,6 +278,7 @@ def run(cls):
276
278
os .environ ["kv_cache_quant_policy" ] = args .kv_cache_quant_policy
277
279
os .environ ["dtype" ] = args .dtype
278
280
os .environ ["log_level" ] = args .log_level
281
+ os .environ ["task_type" ] = args .task_type
279
282
logger .remove (0 )
280
283
log_level = os .getenv ("log_level" , "WARNING" )
281
284
logger .add (sys .stderr , level = log_level )
Original file line number Diff line number Diff line change 3
3
from fastapi import HTTPException
4
4
import base64
5
5
import io
6
-
6
+ import os
7
7
from PIL .Image import Image
8
8
9
9
@@ -53,6 +53,14 @@ def get_embedding_mode(model_path: str):
53
53
from transformers import AutoConfig
54
54
from infinity_emb .inference .select_model import get_engine_type_from_config
55
55
56
+ task_type = os .environ .get ("task_type" , "auto" )
57
+ if task_type == "embedding" :
58
+ return "embedding"
59
+ elif task_type == "reranker" :
60
+ return "rerank"
61
+ elif task_type == "classify" :
62
+ return "classify"
63
+
56
64
model_config = AutoConfig .from_pretrained (model_path , trust_remote_code = True )
57
65
architectures = getattr (model_config , "architectures" , [])
58
66
if "JinaVLForRanking" in architectures :
Original file line number Diff line number Diff line change @@ -94,11 +94,12 @@ models:
94
94
- 2
95
95
96
96
- jina-reranker :
97
- # 多模态多语言的重排模型
97
+ # 多模态多语言的重排模型,这个模型task_type 只能是 auto
98
98
alias : null
99
99
enable : true
100
100
model_config :
101
101
model_name_or_path : /home/dev/model/jinaai/jina-reranker-m0/
102
+ task_type : auto # auto、embedding、reranker 或者 classify。不设置这个参数时默认为 auto,自动识别可能会识别错误
102
103
model_type : embedding # 这里仅支持 embedding
103
104
work_mode : hf
104
105
device : gpu
@@ -112,6 +113,7 @@ models:
112
113
enable : true # false true
113
114
model_config :
114
115
model_name_or_path : /home/dev/model/aspire/acge_text_embedding
116
+ task_type : auto # auto、embedding、reranker 或者 classify。不设置这个参数时默认为 auto,自动识别可能会识别错误
115
117
model_type : embedding_infinity # embedding_infinity/embedding
116
118
work_mode : hf
117
119
device : gpu # gpu / cpu
Original file line number Diff line number Diff line change @@ -167,6 +167,7 @@ def start_model_worker(config: dict):
167
167
)
168
168
vad_model = engine_config .get ("vad_model" , "" )
169
169
punc_model = engine_config .get ("punc_model" , "" )
170
+ task_type = engine_config .get ("task_type" , "auto" )
170
171
171
172
else :
172
173
logger .error (
@@ -252,6 +253,7 @@ def start_model_worker(config: dict):
252
253
+ f" --gpu_memory_utilization { gpu_memory_utilization } " # 占用GPU比例
253
254
+ f" --kv_cache_quant_policy { kv_cache_quant_policy } " # kv cache 量化策略
254
255
+ f" --log_level { log_level } " # 日志水平
256
+ + f " --task_type { task_type } " # 任务类型
255
257
)
256
258
# 处理为 None的情况
257
259
if lora :
You can’t perform that action at this time.
0 commit comments