@@ -39,9 +39,45 @@ async def load_base64_or_url(base64_or_url):
39
39
return bytes_io
40
40
41
41
42
+ def get_embedding_mode (model_path : str ):
43
+ from infinity_emb import EngineArgs
44
+ from transformers import AutoConfig
45
+ from infinity_emb .inference .select_model import get_engine_type_from_config
46
+
47
+ model_config = AutoConfig .from_pretrained (model_path , trust_remote_code = True )
48
+ model_type_text = getattr (
49
+ getattr (model_config , "text_config" , {}), "model_type" , None
50
+ )
51
+ model_type_vison = getattr (
52
+ getattr (model_config , "vision_config" , {}), "model_type" , None
53
+ )
54
+ print (model_type_vison , model_type_text )
55
+ model_type = model_type_vison or model_type_text
56
+
57
+ mode = "embedding"
58
+ engine_args = EngineArgs (
59
+ model_name_or_path = model_path ,
60
+ engine = "torch" ,
61
+ embedding_dtype = "float32" ,
62
+ dtype = "float32" ,
63
+ bettertransformer = True ,
64
+ )
65
+ engine_type = get_engine_type_from_config (engine_args )
66
+ engine_type_str = str (engine_type )
67
+
68
+ if "EmbedderEngine" in engine_type_str :
69
+ mode = "embedding"
70
+ elif "RerankEngine" in engine_type_str :
71
+ mode = "rerank"
72
+ elif "ImageEmbedEngine" in engine_type_str :
73
+ mode = model_type or "image"
74
+ elif "PredictEngine" in engine_type_str :
75
+ mode = "classify"
76
+ return mode
77
+
78
+
42
79
if __name__ == "__main__" :
43
80
44
81
# 示例用法
45
- data_url = "..."
46
- pure_base64 = extract_base64 (data_url )
47
- print (pure_base64 ) # 输出: iVBORw0KGgoAAAANSUhEUg...
82
+ r = get_embedding_mode ("BAAI/BGE-VL-MLLM-S1" )
83
+ print (r )
0 commit comments