Skip to content

Commit d80bb13

Browse files
committed
Merge branch 'main' of https://github.com/xerrors/Yuxi-Know into main
2 parents da945a5 + 49253cb commit d80bb13

File tree

10 files changed

+156
-69
lines changed

10 files changed

+156
-69
lines changed

docs/latest/changelog/roadmap.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
- 同名文件处理逻辑:遇到同名文件则在上传区域提示,是否删除旧文件
1616
- conversation 待修改为异步的版本
1717
- DBManager 需要将数据库修改为异步的aiosqlite或者异步mysql,缓存使用Redis存储
18-
- agent 状态中的文件区域,新增可以下载
1918

2019
### Bugs
2120
- 部分异常状态下,智能体的模型名称出现重叠[#279](https://github.com/xerrors/Yuxi-Know/issues/279)
2221
- DeepSeek 官方接口适配会出现问题
2322
- 目前的知识库的图片存在公开访问风险
23+
- 深度分析智能体需要考虑上下文超限的问题
2424

2525
### 新增
2626
- 优化知识库详情页面,更加简洁清晰
@@ -34,6 +34,8 @@
3434
- 新增自定义模型支持、新增 dashscope rerank/embeddings 模型的支持
3535
- 新增文档解析的图片支持,已支持 MinerU Officical、Docs、Markdown Zip格式
3636
- 新增暗色模式支持并调整整体 UI([#343](https://github.com/xerrors/Yuxi-Know/pull/343)
37+
- agent 状态中的文件区域,新增可以下载
38+
- 移除 Chroma 的支持,当前版本标记为移除
3739

3840
### 修复
3941
- 修复重排序模型实际未生效的问题

server/routers/knowledge_router.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,12 @@ async def create_database(
8787
"""创建知识库"""
8888
logger.debug(
8989
f"Create database {database_name} with kb_type {kb_type}, "
90-
f"additional_params {additional_params}, llm_info {llm_info}"
90+
f"additional_params {additional_params}, llm_info {llm_info}, "
91+
f"embed_model_name {embed_model_name}"
9192
)
9293
try:
9394
additional_params = {**(additional_params or {})}
95+
additional_params["auto_generate_questions"] = False # 默认不生成问题
9496

9597
def normalize_reranker_config(kb: str, params: dict) -> None:
9698
reranker_cfg = params.get("reranker_config")
@@ -112,12 +114,12 @@ def normalize_reranker_config(kb: str, params: dict) -> None:
112114
if not isinstance(reranker_cfg, Mapping):
113115
raise HTTPException(status_code=400, detail="reranker_config must be an object")
114116

115-
enabled = bool(reranker_cfg.get("enabled", False))
117+
reranker_enabled = bool(reranker_cfg.get("enabled", False))
116118
model = (reranker_cfg.get("model") or "").strip()
117119
recall_top_k = max(1, int(reranker_cfg.get("recall_top_k", 50)))
118120
final_top_k = max(1, int(reranker_cfg.get("final_top_k", 10)))
119121

120-
if enabled:
122+
if reranker_enabled:
121123
if not model:
122124
raise HTTPException(status_code=400, detail="reranker_config.model is required when enabled")
123125
if model not in config.reranker_names:
@@ -132,7 +134,7 @@ def normalize_reranker_config(kb: str, params: dict) -> None:
132134
model = model if model in config.reranker_names else ""
133135

134136
params["reranker_config"] = {
135-
"enabled": enabled,
137+
"enabled": reranker_enabled,
136138
"model": model,
137139
"recall_top_k": recall_top_k,
138140
"final_top_k": final_top_k,
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit 057e2000be7b56823239815b0fe7c7fc0dbced96
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit 6a0367834ea0fb5e5c94b9711e3e2756966789ea

src/config/static/models.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class EmbedModelInfo(BaseModel):
2929
dimension: int = Field(..., description="向量维度")
3030
base_url: str = Field(..., description="API 基础 URL")
3131
api_key: str = Field(..., description="API Key 或环境变量名")
32-
32+
model_id: str | None = Field(None, description="可选的模型 ID")
3333

3434
class RerankerInfo(BaseModel):
3535
"""重排序模型配置"""
@@ -158,42 +158,49 @@ class RerankerInfo(BaseModel):
158158

159159
DEFAULT_EMBED_MODELS: dict[str, EmbedModelInfo] = {
160160
"siliconflow/BAAI/bge-m3": EmbedModelInfo(
161+
model_id="siliconflow/BAAI/bge-m3",
161162
name="BAAI/bge-m3",
162163
dimension=1024,
163164
base_url="https://api.siliconflow.cn/v1/embeddings",
164165
api_key="SILICONFLOW_API_KEY",
165166
),
166167
"siliconflow/Pro/BAAI/bge-m3": EmbedModelInfo(
168+
model_id="siliconflow/Pro/BAAI/bge-m3",
167169
name="Pro/BAAI/bge-m3",
168170
dimension=1024,
169171
base_url="https://api.siliconflow.cn/v1/embeddings",
170172
api_key="SILICONFLOW_API_KEY",
171173
),
172174
"siliconflow/Qwen/Qwen3-Embedding-0.6B": EmbedModelInfo(
175+
model_id="siliconflow/Qwen/Qwen3-Embedding-0.6B",
173176
name="Qwen/Qwen3-Embedding-0.6B",
174177
dimension=1024,
175178
base_url="https://api.siliconflow.cn/v1/embeddings",
176179
api_key="SILICONFLOW_API_KEY",
177180
),
178181
"vllm/Qwen/Qwen3-Embedding-0.6B": EmbedModelInfo(
182+
model_id="vllm/Qwen/Qwen3-Embedding-0.6B",
179183
name="Qwen3-Embedding-0.6B",
180184
dimension=1024,
181185
base_url="http://localhost:8000/v1/embeddings",
182186
api_key="no_api_key",
183187
),
184188
"ollama/nomic-embed-text": EmbedModelInfo(
189+
model_id="ollama/nomic-embed-text",
185190
name="nomic-embed-text",
186191
dimension=768,
187192
base_url="http://localhost:11434/api/embed",
188193
api_key="no_api_key",
189194
),
190195
"ollama/bge-m3": EmbedModelInfo(
196+
model_id="ollama/bge-m3",
191197
name="bge-m3",
192198
dimension=1024,
193199
base_url="http://localhost:11434/api/embed",
194200
api_key="no_api_key",
195201
),
196202
"dashscope/text-embedding-v4": EmbedModelInfo(
203+
model_id="dashscope/text-embedding-v4",
197204
name="text-embedding-v4",
198205
dimension=1024,
199206
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1/embeddings",

src/knowledge/implementations/milvus.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, db, utility
99

10+
from src import config
1011
from src.knowledge.base import KnowledgeBase
1112
from src.knowledge.indexing import process_file_to_markdown
1213
from src.knowledge.utils.kb_utils import (
@@ -91,10 +92,14 @@ async def _create_kb_instance(self, db_id: str, kb_config: dict) -> Any:
9192
"""创建 Milvus 集合"""
9293
logger.info(f"Creating Milvus collection for {db_id}")
9394

94-
if db_id not in self.databases_meta:
95+
if not (metadata := self.databases_meta.get(db_id)):
9596
raise ValueError(f"Database {db_id} not found")
9697

97-
embed_info = self.databases_meta[db_id].get("embed_info", {})
98+
# embed_info = metadata.get("embed_info", {})
99+
if not (embed_info := metadata.get("embed_info")):
100+
logger.error(f"Embedding info not found for database {db_id}, using default model")
101+
embed_info = config.embed_model_names[config.embed_model]
102+
98103
collection_name = db_id
99104

100105
try:
@@ -117,8 +122,8 @@ async def _create_kb_instance(self, db_id: str, kb_config: dict) -> Any:
117122

118123
except Exception:
119124
# 创建新集合
120-
embedding_dim = getattr(embed_info, "dimension", 1024) if embed_info else 1024
121-
model_name = getattr(embed_info, "name", "default") if embed_info else "default"
125+
embedding_dim = embed_info.get("dimension", 1024)
126+
model_name = embed_info.get("name", "default")
122127

123128
# 定义集合Schema
124129
fields = [
@@ -142,7 +147,7 @@ async def _create_kb_instance(self, db_id: str, kb_config: dict) -> Any:
142147
index_params = {"metric_type": "COSINE", "index_type": "IVF_FLAT", "params": {"nlist": 1024}}
143148
collection.create_index("embedding", index_params)
144149

145-
logger.info(f"Created new Milvus collection: {collection_name}")
150+
logger.info(f"Created new Milvus collection: {collection_name}: {model_name=}, {embedding_dim=}")
146151

147152
return collection
148153

@@ -154,25 +159,29 @@ async def _initialize_kb_instance(self, instance: Any) -> None:
154159
except Exception as e:
155160
logger.warning(f"Failed to load collection into memory: {e}")
156161

157-
def _get_async_embedding_function(self, embed_info: dict):
162+
def _get_async_embedding(self, embed_info: dict):
158163
"""获取 embedding 函数"""
164+
# 检查是否有 model_id 字段,优先使用 select_embedding_model
165+
if embed_info and "model_id" in embed_info:
166+
from src.models.embed import select_embedding_model
167+
return select_embedding_model(embed_info["model_id"])
168+
169+
# 使用原有的逻辑(兼容模式))
159170
config_dict = get_embedding_config(embed_info)
160-
embedding_model = OtherEmbedding(
171+
return OtherEmbedding(
161172
model=config_dict.get("model"),
162173
base_url=config_dict.get("base_url"),
163174
api_key=config_dict.get("api_key"),
164175
)
165176

177+
def _get_async_embedding_function(self, embed_info: dict):
178+
"""获取 embedding 函数"""
179+
embedding_model = self._get_async_embedding(embed_info)
166180
return partial(embedding_model.abatch_encode, batch_size=40)
167181

168182
def _get_embedding_function(self, embed_info: dict):
169183
"""获取 embedding 函数"""
170-
config_dict = get_embedding_config(embed_info)
171-
embedding_model = OtherEmbedding(
172-
model=config_dict.get("model"),
173-
base_url=config_dict.get("base_url"),
174-
api_key=config_dict.get("api_key"),
175-
)
184+
embedding_model = self._get_async_embedding(embed_info)
176185

177186
return partial(embedding_model.batch_encode, batch_size=40)
178187

src/knowledge/manager.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,16 +246,12 @@ async def create_database(
246246
db_id = db_info["db_id"]
247247

248248
async with self._metadata_lock:
249-
# 准备 additional_params,包含 auto_generate_questions
250-
saved_params = kwargs.copy()
251-
saved_params["auto_generate_questions"] = False
252-
253249
self.global_databases_meta[db_id] = {
254250
"name": database_name,
255251
"description": description,
256252
"kb_type": kb_type,
257253
"created_at": utc_isoformat(),
258-
"additional_params": saved_params,
254+
"additional_params": kwargs.copy(),
259255
}
260256
self._save_global_metadata()
261257

src/knowledge/utils/kb_utils.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,15 +247,23 @@ def get_embedding_config(embed_info: dict) -> dict:
247247

248248
try:
249249
if embed_info:
250-
# 处理 embed_info 可能是字典或 EmbedModelInfo 对象的情况
251-
if hasattr(embed_info, "name"):
250+
# 优先检查是否有 model_id 字段
251+
if "model_id" in embed_info:
252+
from src.models.embed import select_embedding_model
253+
254+
model = select_embedding_model(embed_info["model_id"])
255+
config_dict["model"] = model.model
256+
config_dict["api_key"] = model.api_key
257+
config_dict["base_url"] = model.base_url
258+
config_dict["dimension"] = getattr(model, "dimension", 1024)
259+
elif hasattr(embed_info, "name"):
252260
# EmbedModelInfo 对象
253261
config_dict["model"] = embed_info.name
254262
config_dict["api_key"] = os.getenv(embed_info.api_key) or embed_info.api_key
255263
config_dict["base_url"] = embed_info.base_url
256264
config_dict["dimension"] = embed_info.dimension
257265
else:
258-
# 字典形式
266+
# 字典形式(保持向后兼容)
259267
config_dict["model"] = embed_info["name"]
260268
config_dict["api_key"] = os.getenv(embed_info["api_key"]) or embed_info["api_key"]
261269
config_dict["base_url"] = embed_info["base_url"]

src/models/embed.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
class BaseEmbeddingModel(ABC):
14-
def __init__(self, model=None, name=None, dimension=None, url=None, base_url=None, api_key=None):
14+
def __init__(self, model=None, name=None, dimension=None, url=None, base_url=None, api_key=None, model_id=None):
1515
"""
1616
Args:
1717
model: 模型名称,冗余设计,同name
@@ -140,6 +140,7 @@ async def aencode(self, message: list[str] | str) -> list[list[float]]:
140140
payload = {"model": self.model, "input": message}
141141
async with httpx.AsyncClient() as client:
142142
try:
143+
print(f"\n\n\nOllama Embedding request: {payload}\n\n\n")
143144
response = await client.post(self.base_url, json=payload, timeout=60)
144145
response.raise_for_status()
145146
result = response.json()

0 commit comments

Comments
 (0)