Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 46 additions & 13 deletions astrbot/core/provider/sources/gemini_embedding_source.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from google import genai
from google.genai import types
from google.genai.errors import APIError
Expand All @@ -21,16 +20,26 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
self.provider_config = provider_config
self.provider_settings = provider_settings

api_key: str = provider_config["embedding_api_key"]
api_base: str = provider_config["embedding_api_base"]
# 校验必需的配置,避免静默使用空字符串导致后续难以排查的问题
api_key: str = provider_config.get("embedding_api_key", "")
if not api_key:
raise ValueError(
"Gemini embedding provider 配置错误: 缺少必需的 'embedding_api_key'"
)

api_base: str = provider_config.get("embedding_api_base", "")
Comment on lines +24 to +30
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): Silently defaulting missing API config to empty strings can hide misconfiguration and lead to harder-to-debug failures.

Using get(..., "") avoids KeyError, but it also hides missing embedding_api_key / embedding_api_base and defers the failure to a less obvious point. Instead, either validate these fields and raise a clear configuration error when missing, or log a warning/error if they are empty so misconfigurations are immediately visible.

Suggested implementation:

        # 显式校验必需的配置,避免静默地使用空字符串导致后续难以排查的问题
        api_key = provider_config.get("embedding_api_key")
        api_base = provider_config.get("embedding_api_base")

        if not api_key or not api_base:
            raise ValueError(
                "Gemini embedding provider misconfigured: "
                "'embedding_api_key' and 'embedding_api_base' are required and must be non-empty."
            )

        timeout: int = int(provider_config.get("timeout", 20))
  1. If the project defines a custom configuration/validation exception (for example ConfigurationError or similar), replace the ValueError with that project-specific error type to be consistent with the rest of the codebase.
  2. If there is a standardized logging mechanism for configuration issues, you may also want to log this misconfiguration before raising, e.g. using a module-level logger.

timeout: int = int(provider_config.get("timeout", 20))

# GenAI SDK 的 timeout 单位是毫秒
http_options = types.HttpOptions(timeout=timeout * 1000)

if api_base:
api_base = api_base.removesuffix("/")
http_options.base_url = api_base

proxy = provider_config.get("proxy", "")
if proxy:
# 确保 proxy 配置包含协议头 (如 http://...)
http_options.async_client_args = {"proxy": proxy}
logger.info(f"[Gemini Embedding] 使用代理: {proxy}")

Expand All @@ -42,7 +51,10 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
)

async def get_embedding(self, text: str) -> list[float]:
"""获取文本的嵌入"""
# 获取文本的嵌入
if not text or not text.strip():
raise ValueError("输入文本不能为空")

try:
result = await self.client.models.embed_content(
model=self.model,
Expand All @@ -51,39 +63,60 @@ async def get_embedding(self, text: str) -> list[float]:
output_dimensionality=self.get_dim(),
),
)
assert result.embeddings is not None
assert result.embeddings[0].values is not None

# 使用显式检查替代 assert,防止生产环境下 -O 优化跳过 assert 校验
if not result.embeddings or not result.embeddings[0].values:
raise ValueError("API 响应异常:未返回有效的 embedding 数据")

return result.embeddings[0].values
except APIError as e:
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
raise Exception(f"Gemini Embedding API请求失败: {e.message}") from e

async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入"""
# 批量获取文本的嵌入
if not text:
return []

# 显式校验输入列表中的元素,防止传入空/空格文本导致后续接口或维度计算异常
if any(not s or not s.strip() for s in text):
raise ValueError("批量输入文本列表中不能包含空文本")

try:
# 构造 Content 列表以规避 gemini-embedding-2 批处理单返回 bug
contents = [
types.Content(parts=[types.Part.from_text(text=s)]) for s in text
]

result = await self.client.models.embed_content(
model=self.model,
contents=contents,
config=types.EmbedContentConfig(
output_dimensionality=self.get_dim(),
),
)
assert result.embeddings is not None

# 校验返回的数量是否和请求数量匹配
if not result.embeddings or len(result.embeddings) != len(text):
actual_len = len(result.embeddings) if result.embeddings else 0
raise ValueError(
f"API 响应异常:向量数量不匹配 (期望 {len(text)}, 实际 {actual_len})"
)

embeddings: list[list[float]] = []
for embedding in result.embeddings:
assert embedding.values is not None
if not embedding.values:
raise ValueError("API 响应异常:返回的部分 embedding 缺失 values")
embeddings.append(embedding.values)

return embeddings
except APIError as e:
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}") from e

def get_dim(self) -> int:
"""获取向量的维度"""
# 获取向量的维度
return int(self.provider_config.get("embedding_dimensions", 768))

async def terminate(self):
if self.client:
# 释放资源
if getattr(self, "client", None):
await self.client.aclose()
Loading