22 | 22 | def get_embedding(text: str, timeout: int = 60, max_retries: int = 3) -> Optional[list[float]]: |
23 | 23 | """ |
24 | 24 | Generate embedding using Ollama with retry logic. |
25 | | - |
| 25 | +
26 | 26 | Args: |
27 | 27 | text: Text to embed |
28 | 28 | timeout: Request timeout in seconds |
29 | 29 | max_retries: Maximum retry attempts on failure |
30 | | - |
| 30 | +
31 | 31 | Returns: |
32 | 32 | Embedding vector or None if failed |
33 | 33 | """ |
34 | 34 | if not text or len(text.strip()) == 0: |
35 | 35 | logger.warning("Empty text provided for embedding") |
36 | 36 | return None |
37 | | - |
| 37 | + |
38 | 38 | token_count = count_tokens(text) |
39 | 39 | if token_count > MAX_TOKENS: |
40 | 40 | logger.error(f"Text too long for embedding: {token_count} > {MAX_TOKENS} tokens") |
41 | 41 | return None |
42 | | - |
| 42 | + |
43 | 43 | for attempt in range(max_retries): |
44 | 44 | try: |
45 | 45 | response = requests.post( |
@@ -88,11 +88,88 @@ def get_embedding(text: str, timeout: int = 60, max_retries: int = 3) -> Optiona |
88 | 88 | except Exception as e: |
89 | 89 | logger.error(f"Embedding generation failed: {e}") |
90 | 90 | return None |
91 | | - |
| 91 | + |
92 | 92 | logger.error(f"Failed to generate embedding after {max_retries} attempts") |
93 | 93 | return None |
94 | 94 |
95 | 95 |
| 96 | +def get_embeddings_batch( |
| 97 | + texts: list[str], |
| 98 | + timeout: int = 120, |
| 99 | + max_retries: int = 3 |
| 100 | +) -> Optional[list[list[float]]]: |
| 101 | + """ |
| 102 | + Generate embeddings for multiple texts in a single API call using Ollama's /api/embed endpoint. |
| 103 | +
| 104 | + Args: |
| 105 | + texts: List of texts to embed |
| 106 | + timeout: Request timeout in seconds |
| 107 | + max_retries: Maximum retry attempts on failure |
| 108 | +
| 109 | + Returns: |
| 110 | +        List of embedding vectors for the non-empty texts, in input order, or None if failed |
| 111 | + """ |
| 112 | + if not texts: |
| 113 | + return [] |
| 114 | + |
| 115 | + # Filter empty texts |
| 116 | +    valid_texts = [t for t in texts if t and t.strip()] |
| 117 | + if not valid_texts: |
| 118 | + logger.warning("All texts empty for batch embedding") |
| 119 | + return None |
| 120 | + |
| 121 | + for attempt in range(max_retries): |
| 122 | + try: |
| 123 | + response = requests.post( |
| 124 | + f"{OLLAMA_URL}/api/embed", |
| 125 | + json={ |
| 126 | + "model": EMBEDDING_MODEL, |
| 127 | + "input": valid_texts |
| 128 | + }, |
| 129 | + timeout=timeout |
| 130 | + ) |
| 131 | + response.raise_for_status() |
| 132 | + result = response.json() |
| 133 | + |
| 134 | + if "embeddings" not in result: |
| 135 | + logger.error(f"No embeddings in batch response: {result}") |
| 136 | + return None |
| 137 | + |
| 138 | + return result["embeddings"] |
| 139 | + |
| 140 | + except requests.exceptions.Timeout: |
| 141 | + logger.warning(f"Batch embedding timeout (attempt {attempt+1}/{max_retries}), retrying...") |
| 142 | + time.sleep(2 ** attempt) |
| 143 | + continue |
| 144 | + |
| 145 | + except requests.exceptions.ConnectionError: |
| 146 | + logger.warning(f"Connection error (attempt {attempt+1}/{max_retries}), retrying...") |
| 147 | + if attempt < max_retries - 1: |
| 148 | + time.sleep(2 ** attempt) |
| 149 | + continue |
| 150 | + |
| 151 | + except requests.exceptions.HTTPError as e: |
| 152 | + if e.response.status_code == 429: |
| 153 | + wait = int(e.response.headers.get('Retry-After', '5')) |
| 154 | + logger.warning(f"Rate limited, waiting {wait}s (attempt {attempt+1}/{max_retries})") |
| 155 | + time.sleep(wait) |
| 156 | + continue |
| 157 | + elif e.response.status_code == 500: |
| 158 | + logger.warning(f"Ollama 500 error (attempt {attempt+1}/{max_retries}), retrying...") |
| 159 | + time.sleep(2 ** attempt) |
| 160 | + continue |
| 161 | + else: |
| 162 | + logger.error(f"HTTP error from Ollama: {e.response.status_code} - {e.response.text[:200]}") |
| 163 | + return None |
| 164 | + |
| 165 | + except Exception as e: |
| 166 | + logger.error(f"Batch embedding generation failed: {e}") |
| 167 | + return None |
| 168 | + |
| 169 | + logger.error(f"Failed to generate batch embeddings after {max_retries} attempts") |
| 170 | + return None |
| 171 | + |
| 172 | + |
96 | 173 | def safe_embed_chunk( |
97 | 174 | chunk: dict, |
98 | 175 | max_tokens: int = MAX_TOKENS, |
@@ -161,39 +238,78 @@ def safe_embed_chunk( |
161 | 238 |
162 | 239 | def batch_embed_chunks( |
163 | 240 | chunks: list[dict], |
164 | | - max_tokens: int = MAX_TOKENS |
| 241 | + max_tokens: int = MAX_TOKENS, |
| 242 | + batch_size: int = 10 |
165 | 243 | ) -> list[dict]: |
166 | 244 | """ |
167 | | - Embed multiple chunks with progress tracking. |
168 | | - |
| 245 | + Embed multiple chunks using batch API for better performance. |
| 246 | +
| 247 | + Uses Ollama's /api/embed endpoint to embed multiple texts in a single request, |
| 248 | +    reducing the number of API calls from N to ceil(N / batch_size). |
| 249 | +
169 | 250 | Args: |
170 | | - chunks: List of chunk dictionaries |
| 251 | + chunks: List of chunk dictionaries with 'text' key |
171 | 252 | max_tokens: Maximum tokens per chunk |
172 | | - |
| 253 | + batch_size: Number of texts to embed in a single API call |
| 254 | +
173 | 255 | Returns: |
174 | 256 | List of successfully embedded chunks (flattened if re-chunking occurred) |
175 | 257 | """ |
| 258 | + if not chunks: |
| 259 | + return [] |
| 260 | + |
| 261 | + # First pass: validate and prepare chunks, handle oversized ones |
| 262 | + valid_chunks = [] |
| 263 | + for chunk in chunks: |
| 264 | + text = chunk.get('text', '') |
| 265 | + if not text or len(text.strip()) == 0: |
| 266 | + continue |
| 267 | + |
| 268 | + token_count = count_tokens(text) |
| 269 | + if token_count > max_tokens: |
| 270 | + # Re-chunk oversized chunk |
| 271 | + logger.info(f"Re-chunking oversized chunk ({token_count} tokens)") |
| 272 | + from .chunking import fine_chunk_text |
| 273 | + sub_chunks = fine_chunk_text([text], target_tokens=max_tokens // 2, overlap_tokens=50) |
| 274 | + valid_chunks.extend(sub_chunks) |
| 275 | + else: |
| 276 | + valid_chunks.append(chunk) |
| 277 | + |
| 278 | + if not valid_chunks: |
| 279 | + logger.warning("No valid chunks to embed") |
| 280 | + return [] |
| 281 | + |
| 282 | + # Second pass: batch embed |
176 | 283 | embedded_chunks = [] |
177 | 284 | failed_count = 0 |
178 | | - |
179 | | - for i, chunk in enumerate(chunks): |
180 | | - if i % 10 == 0: |
181 | | - logger.info(f"Embedding progress: {i}/{len(chunks)}") |
182 | | - |
183 | | - result = safe_embed_chunk(chunk, max_tokens) |
184 | | - |
185 | | - if result is None: |
186 | | - failed_count += 1 |
187 | | - continue |
188 | | - |
189 | | - # Handle both single chunk and list of re-chunked chunks |
190 | | - if isinstance(result, list): |
191 | | - embedded_chunks.extend(result) |
| 285 | + |
| 286 | + for i in range(0, len(valid_chunks), batch_size): |
| 287 | + batch = valid_chunks[i:i + batch_size] |
| 288 | + texts = [c.get('text', '') for c in batch] |
| 289 | + |
| 290 | + logger.info(f"Embedding batch {i // batch_size + 1}/{(len(valid_chunks) + batch_size - 1) // batch_size} ({len(batch)} chunks)") |
| 291 | + |
| 292 | + embeddings = get_embeddings_batch(texts) |
| 293 | + |
| 294 | + if embeddings is None: |
| 295 | + # Fallback to single embedding if batch fails |
| 296 | + logger.warning("Batch embedding failed, falling back to single embedding") |
| 297 | + for chunk in batch: |
| 298 | + embedding = get_embedding(chunk.get('text', '')) |
| 299 | +                if embedding is not None: |
| 300 | + chunk['embedding'] = embedding |
| 301 | + chunk['embedding_model'] = EMBEDDING_MODEL |
| 302 | + embedded_chunks.append(chunk) |
| 303 | + else: |
| 304 | + failed_count += 1 |
192 | 305 | else: |
193 | | - embedded_chunks.append(result) |
194 | | - |
| 306 | + for chunk, embedding in zip(batch, embeddings): |
| 307 | + chunk['embedding'] = embedding |
| 308 | + chunk['embedding_model'] = EMBEDDING_MODEL |
| 309 | + embedded_chunks.append(chunk) |
| 310 | + |
195 | 311 | if failed_count > 0: |
196 | | - logger.warning(f"Failed to embed {failed_count}/{len(chunks)} chunks") |
197 | | - |
| 312 | + logger.warning(f"Failed to embed {failed_count}/{len(valid_chunks)} chunks") |
| 313 | + |
198 | 314 | logger.info(f"Successfully embedded {len(embedded_chunks)} chunks") |
199 | 315 | return embedded_chunks |
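
For reviewers, a minimal usage sketch of the new batch path. This is illustrative only: the "embeddings" import path is hypothetical (the diff does not show the module name), and it assumes a local Ollama server reachable at the module's OLLAMA_URL with EMBEDDING_MODEL already pulled.

    # Hypothetical import path - adjust to wherever this module lives.
    from embeddings import get_embeddings_batch, batch_embed_chunks

    # One /api/embed call for several texts; returns vectors for the
    # non-empty inputs, in order, or None if all retries fail.
    vectors = get_embeddings_batch(["first passage", "second passage"])
    if vectors is not None:
        print(f"{len(vectors)} vectors, dim {len(vectors[0])}")

    # Chunk dicts need a 'text' key; other keys ride along, and each
    # successfully embedded chunk gains 'embedding' and 'embedding_model'.
    chunks = [
        {"text": "first passage", "source": "doc1.md"},
        {"text": "second passage", "source": "doc2.md"},
    ]
    embedded = batch_embed_chunks(chunks, batch_size=10)

Note that batch_size=10 trades fewer HTTP round trips against larger request payloads; if a whole batch fails, the code falls back to per-chunk get_embedding calls, so a single bad text degrades performance but not correctness.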