77
88from pymilvus import Collection , CollectionSchema , DataType , FieldSchema , connections , db , utility
99
10+ from src import config
1011from src .knowledge .base import KnowledgeBase
1112from src .knowledge .indexing import process_file_to_markdown
1213from src .knowledge .utils .kb_utils import (
@@ -91,10 +92,14 @@ async def _create_kb_instance(self, db_id: str, kb_config: dict) -> Any:
9192 """创建 Milvus 集合"""
9293 logger .info (f"Creating Milvus collection for { db_id } " )
9394
94- if db_id not in self .databases_meta :
95+ if not ( metadata := self .databases_meta . get ( db_id )) :
9596 raise ValueError (f"Database { db_id } not found" )
9697
97- embed_info = self .databases_meta [db_id ].get ("embed_info" , {})
98+ # embed_info = metadata.get("embed_info", {})
99+ if not (embed_info := metadata .get ("embed_info" )):
100+ logger .error (f"Embedding info not found for database { db_id } , using default model" )
101+ embed_info = config .embed_model_names [config .embed_model ]
102+
98103 collection_name = db_id
99104
100105 try :
@@ -117,8 +122,8 @@ async def _create_kb_instance(self, db_id: str, kb_config: dict) -> Any:
117122
118123 except Exception :
119124 # 创建新集合
120- embedding_dim = getattr ( embed_info , "dimension" , 1024 ) if embed_info else 1024
121- model_name = getattr ( embed_info , "name" , "default" ) if embed_info else "default"
125+ embedding_dim = embed_info . get ( "dimension" , 1024 )
126+ model_name = embed_info . get ( "name" , "default" )
122127
123128 # 定义集合Schema
124129 fields = [
@@ -142,7 +147,7 @@ async def _create_kb_instance(self, db_id: str, kb_config: dict) -> Any:
142147 index_params = {"metric_type" : "COSINE" , "index_type" : "IVF_FLAT" , "params" : {"nlist" : 1024 }}
143148 collection .create_index ("embedding" , index_params )
144149
145- logger .info (f"Created new Milvus collection: { collection_name } " )
150+ logger .info (f"Created new Milvus collection: { collection_name } : { model_name = } , { embedding_dim = } " )
146151
147152 return collection
148153
@@ -154,25 +159,29 @@ async def _initialize_kb_instance(self, instance: Any) -> None:
154159 except Exception as e :
155160 logger .warning (f"Failed to load collection into memory: { e } " )
156161
157- def _get_async_embedding_function (self , embed_info : dict ):
162+ def _get_async_embedding (self , embed_info : dict ):
158163 """获取 embedding 函数"""
164+ # 检查是否有 model_id 字段,优先使用 select_embedding_model
165+ if embed_info and "model_id" in embed_info :
166+ from src .models .embed import select_embedding_model
167+ return select_embedding_model (embed_info ["model_id" ])
168+
169+ # 使用原有的逻辑(兼容模式))
159170 config_dict = get_embedding_config (embed_info )
160- embedding_model = OtherEmbedding (
171+ return OtherEmbedding (
161172 model = config_dict .get ("model" ),
162173 base_url = config_dict .get ("base_url" ),
163174 api_key = config_dict .get ("api_key" ),
164175 )
165176
177+ def _get_async_embedding_function (self , embed_info : dict ):
178+ """获取 embedding 函数"""
179+ embedding_model = self ._get_async_embedding (embed_info )
166180 return partial (embedding_model .abatch_encode , batch_size = 40 )
167181
168182 def _get_embedding_function (self , embed_info : dict ):
169183 """获取 embedding 函数"""
170- config_dict = get_embedding_config (embed_info )
171- embedding_model = OtherEmbedding (
172- model = config_dict .get ("model" ),
173- base_url = config_dict .get ("base_url" ),
174- api_key = config_dict .get ("api_key" ),
175- )
184+ embedding_model = self ._get_async_embedding (embed_info )
176185
177186 return partial (embedding_model .batch_encode , batch_size = 40 )
178187
0 commit comments