|
7 | 7 | from langchain_text_splitters import MarkdownTextSplitter |
8 | 8 |
|
9 | 9 | from src import config |
| 10 | +from src.config.static.models import EmbedModelInfo |
10 | 11 | from src.utils import hashstr, logger |
11 | 12 | from src.utils.datetime_utils import utc_isoformat |
12 | 13 |
|
@@ -249,33 +250,17 @@ def get_embedding_config(embed_info: dict) -> dict: |
249 | 250 | if embed_info: |
250 | 251 | # 优先检查是否有 model_id 字段 |
251 | 252 | if "model_id" in embed_info: |
252 | | - from src.models.embed import select_embedding_model |
253 | | - |
254 | | - model = select_embedding_model(embed_info["model_id"]) |
255 | | - config_dict["model"] = model.model |
256 | | - config_dict["api_key"] = model.api_key |
257 | | - config_dict["base_url"] = model.base_url |
258 | | - config_dict["dimension"] = getattr(model, "dimension", 1024) |
259 | | - elif hasattr(embed_info, "name"): |
260 | | - # EmbedModelInfo 对象 |
261 | | - config_dict["model"] = embed_info.name |
262 | | - config_dict["api_key"] = os.getenv(embed_info.api_key) or embed_info.api_key |
263 | | - config_dict["base_url"] = embed_info.base_url |
264 | | - config_dict["dimension"] = embed_info.dimension |
| 253 | + return config.embed_model_names[embed_info["model_id"]].model_dump() |
| 254 | + elif hasattr(embed_info, "name") and isinstance(embed_info, EmbedModelInfo): |
| 255 | + return embed_info.model_dump() |
265 | 256 | else: |
266 | 257 | # 字典形式(保持向后兼容) |
267 | 258 | config_dict["model"] = embed_info["name"] |
268 | 259 | config_dict["api_key"] = os.getenv(embed_info["api_key"]) or embed_info["api_key"] |
269 | 260 | config_dict["base_url"] = embed_info["base_url"] |
270 | 261 | config_dict["dimension"] = embed_info.get("dimension", 1024) |
271 | 262 | else: |
272 | | - from src.models import select_embedding_model |
273 | | - |
274 | | - default_model = select_embedding_model(config.embed_model) |
275 | | - config_dict["model"] = default_model.model |
276 | | - config_dict["api_key"] = default_model.api_key |
277 | | - config_dict["base_url"] = default_model.base_url |
278 | | - config_dict["dimension"] = getattr(default_model, "dimension", 1024) |
| 263 | + return config.embed_model_names[config.embed_model].model_dump() |
279 | 264 |
|
280 | 265 | except Exception as e: |
281 | 266 | logger.error(f"Error in get_embedding_config: {e}, {embed_info}") |
|
0 commit comments