# Patch summary (commentary only; everything from the "diff --git" line onward is unchanged).
#
# Target file: astrbot/core/provider/sources/gemini_embedding_source.py
#
# NOTE(review): as stored here the patch's original line breaks appear to have been
# collapsed (hunk headers and +/- lines run together on a few long lines). It will most
# likely need its line structure restored before `git apply` / `patch` can consume it.
#
# What the visible hunks change:
#
# * __init__ (hunk -20,16 +20,26)
#   - "embedding_api_key" is now read with .get(..., "") and a ValueError is raised when
#     it is missing or empty (previously it was indexed directly).
#   - "embedding_api_base" is now read with .get(..., "") instead of direct indexing, so a
#     missing key no longer raises; base_url is only set when the value is non-empty.
#   - Adds comments noting that the timeout passed to types.HttpOptions is timeout * 1000
#     (the added comment says the SDK unit is milliseconds) and that the proxy value
#     should include a scheme.
#
# * get_embedding (hunks -42,6 +52,9 and -50,18 +63,32)
#   - Raises ValueError when `text` is empty or whitespace-only.
#   - Replaces the two `assert ... is not None` checks with an explicit check that raises
#     ValueError when embeddings or the first embedding's values are falsy.
#   - Re-raises APIError as Exception with `from e` so the original exception is chained.
#
# * get_embeddings (hunks -50,18 +63,32 and -69,20 +96,30)
#   - Raises ValueError when the input list is falsy or contains an empty /
#     whitespace-only string.
#   - Raises ValueError when the number of returned embeddings differs from len(text),
#     or when any returned embedding has falsy values.
#   - Re-raises APIError as Exception with `from e`.
#
# * terminate (hunk -69,20 +96,30)
#   - Adds a docstring and switches the guard to getattr(self, "client", None), so the
#     method no longer depends on the `client` attribute existing.
#
# Review notes (to confirm with the author; not changes made by this commentary):
#
# * NOTE(review): the unchanged logger.info line logs the proxy value verbatim. If proxy
#   URLs may embed credentials (user:password@host), those would end up in logs --
#   confirm, and consider redacting the userinfo part before logging.
# * NOTE(review): the added comment about the proxy needing a scheme is not backed by any
#   check in the visible hunk; a proxy value without "http://" / "https://" is still
#   passed through unchanged.
# * NOTE(review): api_base.removesuffix("/") removes a single trailing slash only; a value
#   ending in "//" keeps one slash. rstrip("/") may be what is intended -- confirm.
# * NOTE(review): `not result.embeddings[0].values` / `not embedding.values` reject empty
#   lists as well as None, which is stricter than the removed `is not None` asserts.
#   Presumably intended; confirm.
# * NOTE(review): get_embeddings now raises ValueError for an empty input list. Callers
#   that may pass [] (e.g. batching code) should be checked; returning [] could be an
#   alternative contract.
# * NOTE(review): the new ValueError raises sit inside `try` blocks whose only handler is
#   `except APIError`, so validation failures surface as ValueError while API failures
#   surface as a bare Exception. Callers need to handle both types.
#
diff --git a/astrbot/core/provider/sources/gemini_embedding_source.py b/astrbot/core/provider/sources/gemini_embedding_source.py index 71e9dadc9d..5cd6faeb1e 100644 --- a/astrbot/core/provider/sources/gemini_embedding_source.py +++ b/astrbot/core/provider/sources/gemini_embedding_source.py @@ -20,16 +20,26 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: self.provider_config = provider_config self.provider_settings = provider_settings - api_key: str = provider_config["embedding_api_key"] - api_base: str = provider_config["embedding_api_base"] + # 校验必需的配置,避免静默使用空字符串导致后续难以排查的问题 + api_key: str = provider_config.get("embedding_api_key", "") + if not api_key: + raise ValueError( + "Gemini embedding provider 配置错误: 缺少必需的 'embedding_api_key'" + ) + + api_base: str = provider_config.get("embedding_api_base", "") timeout: int = int(provider_config.get("timeout", 20)) + # GenAI SDK 的 timeout 单位是毫秒 http_options = types.HttpOptions(timeout=timeout * 1000) + if api_base: api_base = api_base.removesuffix("/") http_options.base_url = api_base + proxy = provider_config.get("proxy", "") if proxy: + # 确保 proxy 配置包含协议头 (如 http://...) 
http_options.async_client_args = {"proxy": proxy} logger.info(f"[Gemini Embedding] 使用代理: {proxy}") @@ -42,6 +52,9 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: async def get_embedding(self, text: str) -> list[float]: """获取文本的嵌入""" + if not text or not text.strip(): + raise ValueError("输入文本不能为空") + try: result = await self.client.models.embed_content( model=self.model, @@ -50,18 +63,32 @@ async def get_embedding(self, text: str) -> list[float]: output_dimensionality=self.get_dim(), ), ) - assert result.embeddings is not None - assert result.embeddings[0].values is not None + + # 使用显式检查替代 assert,防止生产环境下 -O 优化跳过 assert 校验 + if not result.embeddings or not result.embeddings[0].values: + raise ValueError("API 响应异常:未返回有效的 embedding 数据") + return result.embeddings[0].values except APIError as e: - raise Exception(f"Gemini Embedding API请求失败: {e.message}") + # 使用 from e 保留原始调用栈 + raise Exception(f"Gemini Embedding API请求失败: {e.message}") from e async def get_embeddings(self, text: list[str]) -> list[list[float]]: """批量获取文本的嵌入""" + # 即使是列表,也要确保不是 None 或空列表,与 get_embedding 保持逻辑严谨 + if not text: + raise ValueError("批量输入列表不能为空") + + # 显式校验输入列表中的元素,防止传入空/空格文本导致后续接口或维度计算异常 + if any(not s or not s.strip() for s in text): + raise ValueError("批量输入文本列表中不能包含空/空格文本") + try: + # 构造 Content 列表以规避 gemini-embedding-2 批处理单返回 bug contents = [ types.Content(parts=[types.Part.from_text(text=s)]) for s in text ] + result = await self.client.models.embed_content( model=self.model, contents=contents, @@ -69,20 +96,30 @@ async def get_embeddings(self, text: list[str]) -> list[list[float]]: output_dimensionality=self.get_dim(), ), ) - assert result.embeddings is not None + + # 校验返回的数量是否和请求数量匹配 + if not result.embeddings or len(result.embeddings) != len(text): + actual_len = len(result.embeddings) if result.embeddings else 0 + raise ValueError( + f"API 响应异常:向量数量不匹配 (期望 {len(text)}, 实际 {actual_len})" + ) embeddings: list[list[float]] = [] for embedding in 
result.embeddings: - assert embedding.values is not None + if not embedding.values: + raise ValueError("API 响应异常:返回的部分 embedding 缺失 values") embeddings.append(embedding.values) + return embeddings except APIError as e: - raise Exception(f"Gemini Embedding API批量请求失败: {e.message}") + # 使用 from e 保留原始调用栈 + raise Exception(f"Gemini Embedding API批量请求失败: {e.message}") from e def get_dim(self) -> int: """获取向量的维度""" return int(self.provider_config.get("embedding_dimensions", 768)) async def terminate(self): - if self.client: + """释放资源""" + if getattr(self, "client", None): await self.client.aclose()