Skip to content

Commit b8e83b7

Browse files
committed
fix(embedding): 修复嵌入模型配置中的兼容性问题
- 修复 get_embedding_config 中对 model/name 字段的兼容性处理
- 改进错误处理和日志记录
- 支持多种配置格式以保持向后兼容性
- 修复 lightrag.py 中 embedding 函数的模型名称获取逻辑
1 parent 3d6a38a commit b8e83b7

File tree

2 files changed

+53
-23
lines changed

2 files changed

+53
-23
lines changed

src/knowledge/implementations/lightrag.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -200,6 +200,7 @@ async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwar
200200
def _get_embedding_func(self, embed_info: dict):
201201
"""获取 embedding 函数"""
202202
config_dict = get_embedding_config(embed_info)
203+
logger.debug(f"Embedding config dict: {config_dict}")
203204

204205
if config_dict.get("model_id") and config_dict["model_id"].startswith("ollama"):
205206
from lightrag.llm.ollama import ollama_embed
@@ -219,15 +220,22 @@ def _get_embedding_func(self, embed_info: dict):
219220
),
220221
)
221222

223+
# 尝试获取模型名称,支持多种键名以保持兼容性
224+
if "name" in config_dict and config_dict["name"]:
225+
model_name = config_dict["name"]
226+
elif "model" in config_dict and config_dict["model"]:
227+
model_name = config_dict["model"]
228+
else:
229+
raise ValueError(f"Neither 'name' nor 'model' found in config_dict or both are empty: {config_dict}")
222230
return EmbeddingFunc(
223231
embedding_dim=config_dict["dimension"],
224232
max_token_size=8192,
225233
func=lambda texts: openai_embed(
226234
texts=texts,
227-
model=config_dict["model"],
235+
model=model_name,
228236
api_key=config_dict["api_key"],
229237
base_url=config_dict["base_url"].replace("/embeddings", ""),
230-
),
238+
)
231239
)
232240

233241
async def add_content(self, db_id: str, items: list[str], params: dict | None = None) -> list[dict]:

src/knowledge/utils/kb_utils.py

Lines changed: 43 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -289,30 +289,52 @@ def get_embedding_config(embed_info: dict) -> dict:
289289
Returns:
290290
dict: 标准化的嵌入配置
291291
"""
292-
config_dict = {}
293-
294292
try:
295-
if embed_info:
296-
# 优先检查是否有 model_id 字段
297-
if "model_id" in embed_info:
298-
return config.embed_model_names[embed_info["model_id"]].model_dump()
299-
elif hasattr(embed_info, "name") and isinstance(embed_info, EmbedModelInfo):
300-
return embed_info.model_dump()
301-
else:
302-
# 字典形式(保持向后兼容)
303-
config_dict["model"] = embed_info["name"]
304-
config_dict["api_key"] = os.getenv(embed_info["api_key"]) or embed_info["api_key"]
305-
config_dict["base_url"] = embed_info["base_url"]
306-
config_dict["dimension"] = embed_info.get("dimension", 1024)
307-
else:
308-
return config.embed_model_names[config.embed_model].model_dump()
293+
# 检查 embed_info 是否有效
294+
if not embed_info or ("model" not in embed_info and "name" not in embed_info):
295+
logger.error(f"Invalid embed_info: {embed_info}, using default embedding model config")
296+
raise ValueError("Invalid embed_info: must be a non-empty dictionary")
297+
298+
# 优先检查是否有 model_id 字段
299+
if "model_id" in embed_info and embed_info["model_id"]:
300+
logger.warning(f"Using model_id: {embed_info['model_id']}")
301+
config_dict = config.embed_model_names[embed_info["model_id"]].model_dump()
302+
config_dict["api_key"] = os.getenv(config_dict["api_key"]) or config_dict["api_key"]
303+
return config_dict
304+
305+
# 检查是否是 EmbedModelInfo 对象(在某些情况下可能直接传入对象)
306+
if hasattr(embed_info, "name") and isinstance(embed_info, EmbedModelInfo):
307+
logger.debug(f"Using EmbedModelInfo object: {embed_info.name}")
308+
config_dict = embed_info.model_dump()
309+
config_dict["api_key"] = os.getenv(config_dict["api_key"]) or config_dict["api_key"]
310+
return config_dict
311+
312+
# 字典形式(保持向后兼容)
313+
# 检查必需字段是否存在
314+
if not embed_info.get("name") or not embed_info.get("base_url"):
315+
logger.warning(f"embed_info missing required 'name' or 'base_url' field: {embed_info}, using default")
316+
raise ValueError("embed_info missing required 'name' or 'base_url' field")
317+
318+
config_dict = {
319+
"model": embed_info["name"],
320+
"api_key": os.getenv(embed_info["api_key"]) or embed_info["api_key"],
321+
"base_url": embed_info["base_url"],
322+
"dimension": embed_info.get("dimension", 1024)
323+
}
324+
logger.debug(f"Embedding config from dict: {config_dict}")
325+
return config_dict
309326

310327
except Exception as e:
311-
logger.error(f"Error in get_embedding_config: {e}, {embed_info}")
312-
raise ValueError(f"Error in get_embedding_config: {e}")
313-
314-
logger.debug(f"Embedding config: {config_dict}")
315-
return config_dict
328+
logger.error(f"Error in get_embedding_config: {e}, embed_info={embed_info}")
329+
# 返回默认配置作为fallback
330+
logger.warning("Falling back to default embedding model config")
331+
try:
332+
config_dict = config.embed_model_names[config.embed_model].model_dump()
333+
config_dict["api_key"] = os.getenv(config_dict["api_key"]) or config_dict["api_key"]
334+
return config_dict
335+
except Exception as fallback_error:
336+
logger.error(f"Failed to get default embedding config: {fallback_error}")
337+
raise ValueError(f"Failed to get embedding config and fallback failed: {e}")
316338

317339

318340
def is_minio_url(file_path: str) -> bool:

0 commit comments

Comments
 (0)