
Commit 9406d17

perf(embedding): implement batch embedding via Ollama's /api/embed endpoint to reduce API calls from N to N/10
1 parent 7892120 commit 9406d17

File tree

2 files changed: +147 −28

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,9 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [Unreleased]
 
+### Changed
+- Batch embedding using Ollama /api/embed endpoint, reducing API calls from N to N/10 for faster uploads
+
 ### Fixed
 - Silenced verbose httpx/httpcore logs that spammed 60+ lines per file upload
 
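The changelog entry boils down to one request shape: Ollama's /api/embed endpoint accepts a list of strings under "input" and returns one vector per text. A minimal sketch of that batched call follows; the URL is Ollama's default local address and the model name is illustrative, since the real values come from OLLAMA_URL and EMBEDDING_MODEL inside lib/embedding.py:

import requests

# Illustrative defaults; the app reads these from its own config
OLLAMA_URL = "http://localhost:11434"   # Ollama's default local address
EMBEDDING_MODEL = "nomic-embed-text"    # hypothetical model choice

resp = requests.post(
    f"{OLLAMA_URL}/api/embed",
    json={"model": EMBEDDING_MODEL, "input": ["chunk one", "chunk two"]},
    timeout=120,
)
resp.raise_for_status()
vectors = resp.json()["embeddings"]  # one embedding per input text, in order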

lib/embedding.py

Lines changed: 144 additions & 28 deletions
@@ -22,24 +22,24 @@
 def get_embedding(text: str, timeout: int = 60, max_retries: int = 3) -> Optional[list[float]]:
     """
     Generate embedding using Ollama with retry logic.
-
+
     Args:
         text: Text to embed
         timeout: Request timeout in seconds
         max_retries: Maximum retry attempts on failure
-
+
     Returns:
         Embedding vector or None if failed
     """
     if not text or len(text.strip()) == 0:
         logger.warning("Empty text provided for embedding")
         return None
-
+
     token_count = count_tokens(text)
     if token_count > MAX_TOKENS:
         logger.error(f"Text too long for embedding: {token_count} > {MAX_TOKENS} tokens")
         return None
-
+
     for attempt in range(max_retries):
         try:
             response = requests.post(
@@ -88,11 +88,88 @@ def get_embedding(text: str, timeout: int = 60, max_retries: int = 3) -> Optional[list[float]]:
         except Exception as e:
             logger.error(f"Embedding generation failed: {e}")
             return None
-
+
     logger.error(f"Failed to generate embedding after {max_retries} attempts")
     return None
 
 
+def get_embeddings_batch(
+    texts: list[str],
+    timeout: int = 120,
+    max_retries: int = 3
+) -> Optional[list[list[float]]]:
+    """
+    Generate embeddings for multiple texts in a single API call using Ollama's /api/embed endpoint.
+
+    Args:
+        texts: List of texts to embed
+        timeout: Request timeout in seconds
+        max_retries: Maximum retry attempts on failure
+
+    Returns:
+        List of embedding vectors (same order as input) or None if failed
+    """
+    if not texts:
+        return []
+
+    # Filter empty texts
+    valid_texts = [t for t in texts if t and len(t.strip()) > 0]
+    if not valid_texts:
+        logger.warning("All texts empty for batch embedding")
+        return None
+
+    for attempt in range(max_retries):
+        try:
+            response = requests.post(
+                f"{OLLAMA_URL}/api/embed",
+                json={
+                    "model": EMBEDDING_MODEL,
+                    "input": valid_texts
+                },
+                timeout=timeout
+            )
+            response.raise_for_status()
+            result = response.json()
+
+            if "embeddings" not in result:
+                logger.error(f"No embeddings in batch response: {result}")
+                return None
+
+            return result["embeddings"]
+
+        except requests.exceptions.Timeout:
+            logger.warning(f"Batch embedding timeout (attempt {attempt+1}/{max_retries}), retrying...")
+            time.sleep(2 ** attempt)
+            continue
+
+        except requests.exceptions.ConnectionError:
+            logger.warning(f"Connection error (attempt {attempt+1}/{max_retries}), retrying...")
+            if attempt < max_retries - 1:
+                time.sleep(2 ** attempt)
+            continue
+
+        except requests.exceptions.HTTPError as e:
+            if e.response.status_code == 429:
+                wait = int(e.response.headers.get('Retry-After', '5'))
+                logger.warning(f"Rate limited, waiting {wait}s (attempt {attempt+1}/{max_retries})")
+                time.sleep(wait)
+                continue
+            elif e.response.status_code == 500:
+                logger.warning(f"Ollama 500 error (attempt {attempt+1}/{max_retries}), retrying...")
+                time.sleep(2 ** attempt)
+                continue
+            else:
+                logger.error(f"HTTP error from Ollama: {e.response.status_code} - {e.response.text[:200]}")
+                return None
+
+        except Exception as e:
+            logger.error(f"Batch embedding generation failed: {e}")
+            return None
+
+    logger.error(f"Failed to generate batch embeddings after {max_retries} attempts")
+    return None
+
+
 def safe_embed_chunk(
     chunk: dict,
     max_tokens: int = MAX_TOKENS,
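Taken in isolation, the new helper can be exercised as below; a hedged sketch assuming the module is importable as lib.embedding and an Ollama server is running locally:

from lib.embedding import get_embeddings_batch

texts = ["first chunk of a document", "second chunk", "third chunk"]
vectors = get_embeddings_batch(texts)  # one HTTP call covers all three texts
if vectors is not None:
    # Non-empty inputs come back one vector each, in input order
    assert len(vectors) == len(texts)
    print(f"dimension: {len(vectors[0])}")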
@@ -161,39 +238,78 @@ def safe_embed_chunk(
 
 def batch_embed_chunks(
     chunks: list[dict],
-    max_tokens: int = MAX_TOKENS
+    max_tokens: int = MAX_TOKENS,
+    batch_size: int = 10
 ) -> list[dict]:
     """
-    Embed multiple chunks with progress tracking.
-
+    Embed multiple chunks using batch API for better performance.
+
+    Uses Ollama's /api/embed endpoint to embed multiple texts in a single request,
+    reducing API calls from N to N/batch_size.
+
     Args:
-        chunks: List of chunk dictionaries
+        chunks: List of chunk dictionaries with 'text' key
         max_tokens: Maximum tokens per chunk
-
+        batch_size: Number of texts to embed in a single API call
+
     Returns:
         List of successfully embedded chunks (flattened if re-chunking occurred)
     """
+    if not chunks:
+        return []
+
+    # First pass: validate and prepare chunks, handle oversized ones
+    valid_chunks = []
+    for chunk in chunks:
+        text = chunk.get('text', '')
+        if not text or len(text.strip()) == 0:
+            continue
+
+        token_count = count_tokens(text)
+        if token_count > max_tokens:
+            # Re-chunk oversized chunk
+            logger.info(f"Re-chunking oversized chunk ({token_count} tokens)")
+            from .chunking import fine_chunk_text
+            sub_chunks = fine_chunk_text([text], target_tokens=max_tokens // 2, overlap_tokens=50)
+            valid_chunks.extend(sub_chunks)
+        else:
+            valid_chunks.append(chunk)
+
+    if not valid_chunks:
+        logger.warning("No valid chunks to embed")
+        return []
+
+    # Second pass: batch embed
     embedded_chunks = []
     failed_count = 0
-
-    for i, chunk in enumerate(chunks):
-        if i % 10 == 0:
-            logger.info(f"Embedding progress: {i}/{len(chunks)}")
-
-        result = safe_embed_chunk(chunk, max_tokens)
-
-        if result is None:
-            failed_count += 1
-            continue
-
-        # Handle both single chunk and list of re-chunked chunks
-        if isinstance(result, list):
-            embedded_chunks.extend(result)
+
+    for i in range(0, len(valid_chunks), batch_size):
+        batch = valid_chunks[i:i + batch_size]
+        texts = [c.get('text', '') for c in batch]
+
+        logger.info(f"Embedding batch {i // batch_size + 1}/{(len(valid_chunks) + batch_size - 1) // batch_size} ({len(batch)} chunks)")
+
+        embeddings = get_embeddings_batch(texts)
+
+        if embeddings is None:
+            # Fallback to single embedding if batch fails
+            logger.warning("Batch embedding failed, falling back to single embedding")
+            for chunk in batch:
+                embedding = get_embedding(chunk.get('text', ''))
+                if embedding:
+                    chunk['embedding'] = embedding
+                    chunk['embedding_model'] = EMBEDDING_MODEL
+                    embedded_chunks.append(chunk)
+                else:
+                    failed_count += 1
         else:
-            embedded_chunks.append(result)
-
+            for chunk, embedding in zip(batch, embeddings):
+                chunk['embedding'] = embedding
+                chunk['embedding_model'] = EMBEDDING_MODEL
+                embedded_chunks.append(chunk)
+
     if failed_count > 0:
-        logger.warning(f"Failed to embed {failed_count}/{len(chunks)} chunks")
-
+        logger.warning(f"Failed to embed {failed_count}/{len(valid_chunks)} chunks")
+
     logger.info(f"Successfully embedded {len(embedded_chunks)} chunks")
     return embedded_chunks
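To make the N to N/10 claim concrete: with the default batch_size=10, 100 chunks now cost 10 calls to /api/embed rather than 100 single calls, and a failed batch degrades to per-chunk get_embedding calls instead of losing the whole batch. A usage sketch, with chunk contents and import path illustrative:

from lib.embedding import batch_embed_chunks

chunks = [{"text": f"paragraph {n} of the uploaded file"} for n in range(100)]

# ceil(100 / 10) = 10 batched requests instead of 100 single ones
embedded = batch_embed_chunks(chunks, batch_size=10)

for c in embedded[:1]:
    print(c["embedding_model"], len(c["embedding"]))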

0 commit comments
