Skip to content

Commit 9753b41

Browse files
committed
refactor(llm): improve embedding dimension detection with fallback chain
- Use MODEL_CONFIGS as primary source (we control it) - Fall back to LiteLLM registry for unknown models - Catch specific LiteLLMNotFoundError instead of broad Exception - Warn at construction time if dimensions cannot be detected - Move detection to __init__ for early feedback
1 parent 15de5b8 commit 9753b41

File tree

1 file changed

+53
-11
lines changed


agent_memory_server/llm/embeddings.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
from typing import Any
1212

1313
from langchain_core.embeddings import Embeddings
14-
from litellm import aembedding, embedding, get_model_info
14+
from litellm import aembedding, embedding
15+
from litellm.exceptions import NotFoundError as LiteLLMNotFoundError
16+
from litellm.utils import get_model_info
1517

1618

1719
logger = logging.getLogger(__name__)
@@ -71,27 +73,67 @@ def __init__(
7173
self.model = model
7274
self.api_base = api_base
7375
self.api_key = api_key
74-
self._dimensions = dimensions
7576
self._extra_kwargs = kwargs
7677

77-
@property
78-
def dimensions(self) -> int | None:
79-
"""Get embedding dimensions, auto-detecting if not set."""
80-
if self._dimensions is not None:
81-
return self._dimensions
78+
# Detect dimensions at construction time
79+
self._dimensions = dimensions or self._detect_dimensions()
80+
81+
if self._dimensions is None:
82+
logger.warning(
83+
"Could not determine embedding dimensions for model %r. "
84+
"Vector stores requiring explicit dimensions may fail. "
85+
"Consider specifying dimensions explicitly.",
86+
self.model,
87+
)
88+
89+
def _detect_dimensions(self) -> int | None:
    """
    Detect embedding dimensions using a fallback chain.

    Order of precedence:
    1. Our MODEL_CONFIGS (most reliable, we control it)
    2. LiteLLM's model registry (fallback for unknown models)

    Returns:
        Detected dimensions or None if unknown.
    """
    # Import here to avoid circular dependency
    from agent_memory_server.config import MODEL_CONFIGS

    # 1. Check our known config first (most reliable).
    # A MODEL_CONFIGS entry may exist without embedding dimensions set
    # (e.g. a non-embedding model); in that case fall through to the
    # LiteLLM registry instead of returning early with None, and avoid
    # formatting %d with None in the debug log.
    config = MODEL_CONFIGS.get(self.model)
    if config is not None:
        dims = config.embedding_dimensions
        if dims is not None:
            logger.debug(
                "Detected dimensions=%d for model %r from MODEL_CONFIGS",
                dims,
                self.model,
            )
            return dims

    # 2. Try LiteLLM's registry as fallback. Only NotFoundError is
    # swallowed (unknown model); any other failure should propagate.
    try:
        info = get_model_info(self.model)
        dims = info.get("output_vector_size")
        if dims is not None:
            logger.debug(
                "Detected dimensions=%d for model %r from LiteLLM registry",
                dims,
                self.model,
            )
            return dims
    except LiteLLMNotFoundError:
        logger.debug(
            "Model %r not found in LiteLLM registry",
            self.model,
        )

    return None
94131

132+
@property
133+
def dimensions(self) -> int | None:
134+
"""Get embedding dimensions."""
135+
return self._dimensions
136+
95137
def _build_call_kwargs(self, input_texts: list[str]) -> dict[str, Any]:
96138
"""Build kwargs for LiteLLM embedding call."""
97139
kwargs: dict[str, Any] = {

0 commit comments

Comments
 (0)