Skip to content

Commit ebaccea

Browse files
committed
refactor(knowledge): 优化 LightRAG 的 Ollama embedding 配置处理并添加查询参数过滤
重构 get_embedding_config 函数,直接使用 model_dump 返回配置信息;在 LightRagKB 中增加 Ollama embedding 支持并调整 token 大小限制;添加查询参数过滤逻辑,只保留有效参数。
1 parent 6003563 commit ebaccea

File tree

6 files changed

+54
-30
lines changed

6 files changed

+54
-30
lines changed

src/agents/chatbot/tools.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,7 @@ async def text_to_img_qwen(text: str) -> str:
4242

4343
file_name = f"{uuid.uuid4()}.jpg"
4444
image_url = await aupload_file_to_minio(
45-
bucket_name="generated-images",
46-
file_name=file_name,
47-
data=file_data,
48-
file_extension="jpg"
45+
bucket_name="generated-images", file_name=file_name, data=file_data, file_extension="jpg"
4946
)
5047
logger.info(f"Image uploaded. URL: {image_url}")
5148
return image_url

src/knowledge/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,12 @@ def create_database(
143143

144144
# 创建数据库记录
145145
# 确保 Pydantic 模型被转换为字典,以便 JSON 序列化
146+
embed_info_dump = embed_info.model_dump() if hasattr(embed_info, "model_dump") else embed_info
146147
self.databases_meta[db_id] = {
147148
"name": database_name,
148149
"description": description,
149150
"kb_type": self.kb_type,
150-
"embed_info": embed_info.model_dump() if hasattr(embed_info, "model_dump") else embed_info,
151+
"embed_info": embed_info_dump,
151152
"llm_info": llm_info.model_dump() if hasattr(llm_info, "model_dump") else llm_info,
152153
"metadata": kwargs,
153154
"created_at": utc_isoformat(),

src/knowledge/implementations/lightrag.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,27 @@ def _get_embedding_func(self, embed_info: dict):
201201
"""获取 embedding 函数"""
202202
config_dict = get_embedding_config(embed_info)
203203

204+
if config_dict["model_id"].startswith("ollama"):
205+
from lightrag.llm.ollama import ollama_embed
206+
207+
from src.utils import get_docker_safe_url
208+
209+
host = get_docker_safe_url(config_dict["base_url"].replace("/api/embed", ""))
210+
logger.debug(f"Ollama host: {host}")
211+
return EmbeddingFunc(
212+
embedding_dim=config_dict["dimension"],
213+
max_token_size=8192,
214+
func=lambda texts: ollama_embed(
215+
texts=texts,
216+
embed_model=config_dict["name"],
217+
api_key=config_dict["api_key"],
218+
host=host,
219+
),
220+
)
221+
204222
return EmbeddingFunc(
205223
embedding_dim=config_dict["dimension"],
206-
max_token_size=4096,
224+
max_token_size=8192,
207225
func=lambda texts: openai_embed(
208226
texts=texts,
209227
model=config_dict["model"],
@@ -365,12 +383,37 @@ async def aquery(self, query_text: str, db_id: str, **kwargs) -> str:
365383
raise ValueError(f"Database {db_id} not found")
366384

367385
try:
386+
# QueryParam 支持的参数列表
387+
valid_params = {
388+
"mode",
389+
"only_need_context",
390+
"only_need_prompt",
391+
"response_type",
392+
"stream",
393+
"top_k",
394+
"chunk_top_k",
395+
"max_entity_tokens",
396+
"max_relation_tokens",
397+
"max_total_tokens",
398+
"hl_keywords",
399+
"ll_keywords",
400+
"conversation_history",
401+
"history_turns",
402+
"model_func",
403+
"user_prompt",
404+
"enable_rerank",
405+
"include_references",
406+
}
407+
408+
# 过滤 kwargs,只保留 QueryParam 支持的参数
409+
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
410+
368411
# 设置查询参数
369412
params_dict = {
370413
"mode": "mix",
371414
"only_need_context": True,
372415
"top_k": 10,
373-
} | kwargs
416+
} | filtered_kwargs
374417
param = QueryParam(**params_dict)
375418

376419
# 执行查询

src/knowledge/utils/kb_utils.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from langchain_text_splitters import MarkdownTextSplitter
88

99
from src import config
10+
from src.config.static.models import EmbedModelInfo
1011
from src.utils import hashstr, logger
1112
from src.utils.datetime_utils import utc_isoformat
1213

@@ -249,33 +250,17 @@ def get_embedding_config(embed_info: dict) -> dict:
249250
if embed_info:
250251
# 优先检查是否有 model_id 字段
251252
if "model_id" in embed_info:
252-
from src.models.embed import select_embedding_model
253-
254-
model = select_embedding_model(embed_info["model_id"])
255-
config_dict["model"] = model.model
256-
config_dict["api_key"] = model.api_key
257-
config_dict["base_url"] = model.base_url
258-
config_dict["dimension"] = getattr(model, "dimension", 1024)
259-
elif hasattr(embed_info, "name"):
260-
# EmbedModelInfo 对象
261-
config_dict["model"] = embed_info.name
262-
config_dict["api_key"] = os.getenv(embed_info.api_key) or embed_info.api_key
263-
config_dict["base_url"] = embed_info.base_url
264-
config_dict["dimension"] = embed_info.dimension
253+
return config.embed_model_names[embed_info["model_id"]].model_dump()
254+
elif hasattr(embed_info, "name") and isinstance(embed_info, EmbedModelInfo):
255+
return embed_info.model_dump()
265256
else:
266257
# 字典形式(保持向后兼容)
267258
config_dict["model"] = embed_info["name"]
268259
config_dict["api_key"] = os.getenv(embed_info["api_key"]) or embed_info["api_key"]
269260
config_dict["base_url"] = embed_info["base_url"]
270261
config_dict["dimension"] = embed_info.get("dimension", 1024)
271262
else:
272-
from src.models import select_embedding_model
273-
274-
default_model = select_embedding_model(config.embed_model)
275-
config_dict["model"] = default_model.model
276-
config_dict["api_key"] = default_model.api_key
277-
config_dict["base_url"] = default_model.base_url
278-
config_dict["dimension"] = getattr(default_model, "dimension", 1024)
263+
return config.embed_model_names[config.embed_model].model_dump()
279264

280265
except Exception as e:
281266
logger.error(f"Error in get_embedding_config: {e}, {embed_info}")

src/storage/minio/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
# 导出核心功能
7-
from .client import MinIOClient, StorageError, UploadResult, get_minio_client, aupload_file_to_minio
7+
from .client import MinIOClient, StorageError, UploadResult, aupload_file_to_minio, get_minio_client
88
from .utils import generate_unique_filename, get_file_size
99

1010
# 为了向后兼容,导出常用的函数

src/storage/minio/client.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,6 @@ def get_minio_client() -> MinIOClient:
288288
return _default_client
289289

290290

291-
292-
293291
async def aupload_file_to_minio(bucket_name: str, file_name: str, data: bytes, file_extension: str) -> str:
294292
"""
295293
通过字节上传文件到 MinIO的异步接口,根据输入的file_extension确定文件格式,并返回资源url

0 commit comments

Comments (0)