|
@@ -11,7 +11,9 @@
 from typing import Any
 
 from langchain_core.embeddings import Embeddings
-from litellm import aembedding, embedding, get_model_info
+from litellm import aembedding, embedding
+from litellm.exceptions import NotFoundError as LiteLLMNotFoundError
+from litellm.utils import get_model_info
 
 
 logger = logging.getLogger(__name__)
@@ -71,27 +73,67 @@ def __init__( |
         self.model = model
         self.api_base = api_base
         self.api_key = api_key
-        self._dimensions = dimensions
         self._extra_kwargs = kwargs
 
-    @property
-    def dimensions(self) -> int | None:
-        """Get embedding dimensions, auto-detecting if not set."""
-        if self._dimensions is not None:
-            return self._dimensions
+        # Detect dimensions at construction time
+        self._dimensions = dimensions or self._detect_dimensions()
+
+        if self._dimensions is None:
+            logger.warning(
+                "Could not determine embedding dimensions for model %r. "
+                "Vector stores requiring explicit dimensions may fail. "
+                "Consider specifying dimensions explicitly.",
+                self.model,
+            )
+
+    def _detect_dimensions(self) -> int | None:
+        """
+        Detect embedding dimensions using a fallback chain.
+
+        Order of precedence:
+        1. Our MODEL_CONFIGS (most reliable, we control it)
+        2. LiteLLM's model registry (fallback for unknown models)
 
-        # Try to get from LiteLLM model info
+        Returns:
+            Detected dimensions or None if unknown.
+        """
+        # Import here to avoid a circular dependency
+        from agent_memory_server.config import MODEL_CONFIGS
+
+        # 1. Check our known config first (most reliable)
+        if self.model in MODEL_CONFIGS:
+            dims = MODEL_CONFIGS[self.model].embedding_dimensions
+            logger.debug(
+                "Detected dimensions=%d for model %r from MODEL_CONFIGS",
+                dims,
+                self.model,
+            )
+            return dims
+
+        # 2. Try LiteLLM's registry as a fallback
         try:
             info = get_model_info(self.model)
             dims = info.get("output_vector_size")
             if dims is not None:
-                self._dimensions = dims
+                logger.debug(
+                    "Detected dimensions=%d for model %r from LiteLLM registry",
+                    dims,
+                    self.model,
+                )
                 return dims
-        except Exception:
-            pass
+        except LiteLLMNotFoundError:
+            logger.debug(
+                "Model %r not found in LiteLLM registry",
+                self.model,
+            )
 
         return None
 
+    @property
+    def dimensions(self) -> int | None:
+        """Get embedding dimensions."""
+        return self._dimensions
+
     def _build_call_kwargs(self, input_texts: list[str]) -> dict[str, Any]:
         """Build kwargs for LiteLLM embedding call."""
         kwargs: dict[str, Any] = {
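
For context, a minimal usage sketch of the construction-time detection added above. The class name `LiteLLMEmbeddings`, the import path, and the example model names are assumptions for illustration only; the constructor parameters are taken from the fields assigned in the hunk (`model`, `api_base`, `api_key`, `dimensions`).

# Illustrative sketch only: class name and import path are assumed,
# not confirmed by this diff.
from agent_memory_server.embeddings import LiteLLMEmbeddings  # assumed path

# Explicit dimensions always win; no detection is attempted.
explicit = LiteLLMEmbeddings(model="text-embedding-3-small", dimensions=1536)
assert explicit.dimensions == 1536

# Known model: dimensions resolved from MODEL_CONFIGS at construction time.
known = LiteLLMEmbeddings(model="text-embedding-3-small")
print(known.dimensions)  # e.g. 1536 if MODEL_CONFIGS defines it

# Unknown model: falls back to LiteLLM's registry; if that also fails,
# dimensions is None and a warning is logged during construction.
unknown = LiteLLMEmbeddings(model="my-org/custom-embedding")
print(unknown.dimensions)  # may be None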
|