
Commit fd5fca0

Merge branch 'main' of https://github.com/xerrors/Yuxi-Know
2 parents: 7750a5f + ebaccea


6 files changed: +54 additions, −30 deletions


src/agents/chatbot/tools.py

Lines changed: 1 addition & 4 deletions
@@ -42,10 +42,7 @@ async def text_to_img_qwen(text: str) -> str:
 
     file_name = f"{uuid.uuid4()}.jpg"
     image_url = await aupload_file_to_minio(
-        bucket_name="generated-images",
-        file_name=file_name,
-        data=file_data,
-        file_extension="jpg"
+        bucket_name="generated-images", file_name=file_name, data=file_data, file_extension="jpg"
     )
     logger.info(f"Image uploaded. URL: {image_url}")
     return image_url
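For reference, a minimal sketch of how this helper is called from an async context inside the project (the image bytes and the asyncio.run driver below are illustrative, not taken from the repository):

import asyncio
import uuid

from src.storage.minio import aupload_file_to_minio

async def upload_generated_image(file_data: bytes) -> str:
    # Name the object with a random UUID, as tools.py does
    file_name = f"{uuid.uuid4()}.jpg"
    # Upload the raw bytes and return the resource URL
    return await aupload_file_to_minio(
        bucket_name="generated-images", file_name=file_name, data=file_data, file_extension="jpg"
    )

# asyncio.run(upload_generated_image(b"<jpeg bytes>"))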

src/knowledge/base.py

Lines changed: 2 additions & 1 deletion
@@ -143,11 +143,12 @@ def create_database(
 
         # Create the database record
         # Make sure the Pydantic model is converted to a dict so it can be JSON-serialized
+        embed_info_dump = embed_info.model_dump() if hasattr(embed_info, "model_dump") else embed_info
         self.databases_meta[db_id] = {
             "name": database_name,
             "description": description,
             "kb_type": self.kb_type,
-            "embed_info": embed_info.model_dump() if hasattr(embed_info, "model_dump") else embed_info,
+            "embed_info": embed_info_dump,
             "llm_info": llm_info.model_dump() if hasattr(llm_info, "model_dump") else llm_info,
             "metadata": kwargs,
             "created_at": utc_isoformat(),

src/knowledge/implementations/lightrag.py

Lines changed: 45 additions & 2 deletions
@@ -201,9 +201,27 @@ def _get_embedding_func(self, embed_info: dict):
         """Get the embedding function"""
         config_dict = get_embedding_config(embed_info)
 
+        if config_dict["model_id"].startswith("ollama"):
+            from lightrag.llm.ollama import ollama_embed
+
+            from src.utils import get_docker_safe_url
+
+            host = get_docker_safe_url(config_dict["base_url"].replace("/api/embed", ""))
+            logger.debug(f"Ollama host: {host}")
+            return EmbeddingFunc(
+                embedding_dim=config_dict["dimension"],
+                max_token_size=8192,
+                func=lambda texts: ollama_embed(
+                    texts=texts,
+                    embed_model=config_dict["name"],
+                    api_key=config_dict["api_key"],
+                    host=host,
+                ),
+            )
+
         return EmbeddingFunc(
             embedding_dim=config_dict["dimension"],
-            max_token_size=4096,
+            max_token_size=8192,
             func=lambda texts: openai_embed(
                 texts=texts,
                 model=config_dict["model"],
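The host handed to ollama_embed is the configured embedding endpoint with its REST path stripped off, passed through get_docker_safe_url so it stays reachable from inside a container. A small sketch of that derivation with an assumed local Ollama endpoint (the URL is an example, and the docker-host rewrite described in the comment is an assumption about what get_docker_safe_url does):

# Assumed endpoint for a local Ollama embedding server
base_url = "http://localhost:11434/api/embed"

# Drop the embedding REST path so only the server root remains
host = base_url.replace("/api/embed", "")
print(host)  # http://localhost:11434

# get_docker_safe_url would then rewrite localhost to a Docker-reachable address
# (e.g. host.docker.internal) when the service runs inside a container.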
@@ -365,12 +383,37 @@ async def aquery(self, query_text: str, db_id: str, **kwargs) -> str:
             raise ValueError(f"Database {db_id} not found")
 
         try:
+            # Parameters supported by QueryParam
+            valid_params = {
+                "mode",
+                "only_need_context",
+                "only_need_prompt",
+                "response_type",
+                "stream",
+                "top_k",
+                "chunk_top_k",
+                "max_entity_tokens",
+                "max_relation_tokens",
+                "max_total_tokens",
+                "hl_keywords",
+                "ll_keywords",
+                "conversation_history",
+                "history_turns",
+                "model_func",
+                "user_prompt",
+                "enable_rerank",
+                "include_references",
+            }
+
+            # Filter kwargs, keeping only the parameters QueryParam supports
+            filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
+
             # Set up the query parameters
             params_dict = {
                 "mode": "mix",
                 "only_need_context": True,
                 "top_k": 10,
-            } | kwargs
+            } | filtered_kwargs
             param = QueryParam(**params_dict)
 
             # Execute the query
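The filtering step matters because QueryParam, like most dataclass-style parameter objects, rejects unknown keyword arguments, so passing the caller's kwargs straight through could raise a TypeError. A self-contained sketch of the same pattern, using a stand-in dataclass rather than LightRAG's actual QueryParam:

from dataclasses import dataclass

@dataclass
class FakeQueryParam:  # stand-in for lightrag's QueryParam
    mode: str = "mix"
    only_need_context: bool = True
    top_k: int = 10

valid_params = {"mode", "only_need_context", "top_k"}
kwargs = {"top_k": 5, "db_id": "kb_123"}  # "db_id" is not a QueryParam field

# FakeQueryParam(**kwargs)  # TypeError: unexpected keyword argument 'db_id'
filtered = {k: v for k, v in kwargs.items() if k in valid_params}
param = FakeQueryParam(**({"mode": "mix", "only_need_context": True, "top_k": 10} | filtered))
print(param)  # FakeQueryParam(mode='mix', only_need_context=True, top_k=5)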

src/knowledge/utils/kb_utils.py

Lines changed: 5 additions & 20 deletions
@@ -7,6 +7,7 @@
 from langchain_text_splitters import MarkdownTextSplitter
 
 from src import config
+from src.config.static.models import EmbedModelInfo
 from src.utils import hashstr, logger
 from src.utils.datetime_utils import utc_isoformat
 
@@ -293,33 +294,17 @@ def get_embedding_config(embed_info: dict) -> dict:
         if embed_info:
             # Check for a model_id field first
             if "model_id" in embed_info:
-                from src.models.embed import select_embedding_model
-
-                model = select_embedding_model(embed_info["model_id"])
-                config_dict["model"] = model.model
-                config_dict["api_key"] = model.api_key
-                config_dict["base_url"] = model.base_url
-                config_dict["dimension"] = getattr(model, "dimension", 1024)
-            elif hasattr(embed_info, "name"):
-                # EmbedModelInfo object
-                config_dict["model"] = embed_info.name
-                config_dict["api_key"] = os.getenv(embed_info.api_key) or embed_info.api_key
-                config_dict["base_url"] = embed_info.base_url
-                config_dict["dimension"] = embed_info.dimension
+                return config.embed_model_names[embed_info["model_id"]].model_dump()
+            elif hasattr(embed_info, "name") and isinstance(embed_info, EmbedModelInfo):
+                return embed_info.model_dump()
             else:
                 # Dict form (kept for backward compatibility)
                 config_dict["model"] = embed_info["name"]
                 config_dict["api_key"] = os.getenv(embed_info["api_key"]) or embed_info["api_key"]
                 config_dict["base_url"] = embed_info["base_url"]
                 config_dict["dimension"] = embed_info.get("dimension", 1024)
         else:
-            from src.models import select_embedding_model
-
-            default_model = select_embedding_model(config.embed_model)
-            config_dict["model"] = default_model.model
-            config_dict["api_key"] = default_model.api_key
-            config_dict["base_url"] = default_model.base_url
-            config_dict["dimension"] = getattr(default_model, "dimension", 1024)
+            return config.embed_model_names[config.embed_model].model_dump()
 
     except Exception as e:
         logger.error(f"Error in get_embedding_config: {e}, {embed_info}")

src/storage/minio/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 """
 
 # Export core functionality
-from .client import MinIOClient, StorageError, UploadResult, get_minio_client, aupload_file_to_minio
+from .client import MinIOClient, StorageError, UploadResult, aupload_file_to_minio, get_minio_client
 from .utils import generate_unique_filename, get_file_size
 
 # Export commonly used functions for backward compatibility

src/storage/minio/client.py

Lines changed: 0 additions & 2 deletions
@@ -288,8 +288,6 @@ def get_minio_client() -> MinIOClient:
     return _default_client
 
 
-
-
 async def aupload_file_to_minio(bucket_name: str, file_name: str, data: bytes, file_extension: str) -> str:
     """
     Async interface for uploading raw bytes to MinIO; the given file_extension determines the file format, and the resource URL is returned
