22 changes: 20 additions & 2 deletions api/core/rag/datasource/retrieval_service.py
@@ -1,4 +1,5 @@
import concurrent.futures
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any

@@ -36,6 +37,8 @@
"score_threshold_enabled": False,
}

logger = logging.getLogger(__name__)


class RetrievalService:
# Cache precompiled regular expressions to avoid repeated compilation
@@ -106,7 +109,12 @@ def retrieve(
)
)

concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED)
if futures:
for future in concurrent.futures.as_completed(futures, timeout=3600):
if exceptions:
for f in futures:
f.cancel()
break

if exceptions:
raise ValueError(";\n".join(exceptions))
@@ -210,6 +218,7 @@ def keyword_search(
)
all_documents.extend(documents)
except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e))

@classmethod
@@ -303,6 +312,7 @@ def embedding_search(
else:
all_documents.extend(documents)
except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e))

@classmethod
@@ -351,6 +361,7 @@ def full_text_index_search(
else:
all_documents.extend(documents)
except Exception as e:
logger.error(e, exc_info=True)
exceptions.append(str(e))

@staticmethod
@@ -662,7 +673,14 @@ def _retrieve(
document_ids_filter=document_ids_filter,
)
)
concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED)
# Use as_completed for early error propagation - cancel remaining futures on first error
if futures:
for future in concurrent.futures.as_completed(futures, timeout=300):
if future.exception():
# Cancel remaining futures to avoid unnecessary waiting
for f in futures:
f.cancel()
break

if exceptions:
raise ValueError(";\n".join(exceptions))
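The change above replaces a blanket concurrent.futures.wait(...) with iteration over concurrent.futures.as_completed(...), cancelling the remaining futures as soon as a failure is observed. Below is a minimal, standalone sketch of that fail-fast pattern; the task function, sleep times, and worker count are made up for illustration, and the real code checks either a shared exceptions list or future.exception(), depending on how each worker reports failures.

import concurrent.futures
import time


def task(n: int) -> int:
    # Hypothetical workload: task 2 fails almost immediately.
    if n == 2:
        raise ValueError(f"task {n} failed")
    time.sleep(n)
    return n


errors: list[str] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(task, n) for n in range(4)]
    # as_completed yields futures in completion order, so the failure in
    # task 2 surfaces right away instead of after the slowest task.
    for future in concurrent.futures.as_completed(futures, timeout=60):
        if future.exception():
            errors.append(str(future.exception()))
            # Cancel whatever has not started yet; already-running futures
            # cannot be interrupted, but queued ones are dropped.
            for f in futures:
                f.cancel()
            break

if errors:
    raise ValueError(";\n".join(errors))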
104 changes: 70 additions & 34 deletions api/core/rag/retrieval/dataset_retrieval.py
@@ -516,6 +516,9 @@ def multiple_retrieve(
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
with measure_time() as timer:
cancel_event = threading.Event()
thread_exceptions: list[Exception] = []

if query:
query_thread = threading.Thread(
target=self._multiple_retrieve_thread,
@@ -534,6 +537,8 @@
"score_threshold": score_threshold,
"query": query,
"attachment_id": None,
"cancel_event": cancel_event,
"thread_exceptions": thread_exceptions,
},
)
all_threads.append(query_thread)
@@ -557,12 +562,25 @@
"score_threshold": score_threshold,
"query": None,
"attachment_id": attachment_id,
"cancel_event": cancel_event,
"thread_exceptions": thread_exceptions,
},
)
all_threads.append(attachment_thread)
attachment_thread.start()
for thread in all_threads:
thread.join()

# Poll threads with short timeout to detect errors quickly (fail-fast)
while any(t.is_alive() for t in all_threads):
for thread in all_threads:
thread.join(timeout=0.1)
if thread_exceptions:
cancel_event.set()
break
if thread_exceptions:
break

if thread_exceptions:
raise thread_exceptions[0]
self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id)

if all_documents:
@@ -1402,40 +1420,53 @@ def _multiple_retrieve_thread(
score_threshold: float,
query: str | None,
attachment_id: str | None,
cancel_event: threading.Event | None = None,
thread_exceptions: list[Exception] | None = None,
):
with flask_app.app_context():
threads = []
all_documents_item: list[Document] = []
index_type = None
for dataset in available_datasets:
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
try:
with flask_app.app_context():
threads = []
all_documents_item: list[Document] = []
index_type = None
for dataset in available_datasets:
# Check for cancellation signal
if cancel_event and cancel_event.is_set():
break
index_type = dataset.indexing_technique
document_ids_filter = None
if dataset.provider != "external":
if metadata_condition and not metadata_filter_document_ids:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": flask_app,
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents_item,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
"attachment_ids": [attachment_id] if attachment_id else None,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
if metadata_filter_document_ids:
document_ids = metadata_filter_document_ids.get(dataset.id, [])
if document_ids:
document_ids_filter = document_ids
else:
continue
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": flask_app,
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents_item,
"document_ids_filter": document_ids_filter,
"metadata_condition": metadata_condition,
"attachment_ids": [attachment_id] if attachment_id else None,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()

# Poll threads with short timeout to respond quickly to cancellation
while any(t.is_alive() for t in threads):
for thread in threads:
thread.join(timeout=0.1)
if cancel_event and cancel_event.is_set():
break
if cancel_event and cancel_event.is_set():
break

if reranking_enable:
# do rerank for searched documents
@@ -1468,3 +1499,8 @@ def _multiple_retrieve_thread(
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
except Exception as e:
if cancel_event:
cancel_event.set()
if thread_exceptions is not None:
thread_exceptions.append(e)
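In dataset_retrieval.py, the unconditional thread.join() loop gives way to cooperative cancellation: each worker receives a shared threading.Event and exception list, sets both on failure, and the caller polls join(timeout=0.1) so it can signal cancellation and re-raise without waiting for the slower threads. A stripped-down sketch of that hand-off follows, with a placeholder worker body and made-up timings.

import threading
import time


def worker(n: int, cancel_event: threading.Event, thread_exceptions: list[Exception]) -> None:
    try:
        if n == 1:
            time.sleep(0.2)
            raise RuntimeError(f"worker {n} failed")  # hypothetical failure
        for _ in range(100):
            # Check the shared event between units of work so a failure
            # elsewhere stops this thread early instead of letting it run on.
            if cancel_event.is_set():
                return
            time.sleep(0.05)
    except Exception as e:
        cancel_event.set()
        thread_exceptions.append(e)


cancel_event = threading.Event()
thread_exceptions: list[Exception] = []
threads = [
    threading.Thread(target=worker, args=(n, cancel_event, thread_exceptions))
    for n in range(3)
]
for t in threads:
    t.start()

# Poll with a short join timeout instead of blocking on each thread in turn,
# so the first reported exception propagates almost immediately.
while any(t.is_alive() for t in threads):
    for t in threads:
        t.join(timeout=0.1)
        if thread_exceptions:
            cancel_event.set()
            break
    if thread_exceptions:
        break

if thread_exceptions:
    raise thread_exceptions[0]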
@@ -421,7 +421,18 @@ def sync_submit(fn, *args, **kwargs):
# In real code, this waits for all futures to complete
# In tests, futures complete immediately, so wait is a no-op
with patch("core.rag.datasource.retrieval_service.concurrent.futures.wait"):
yield mock_executor
# Mock concurrent.futures.as_completed for early error propagation
# In real code, this yields futures as they complete
# In tests, we yield all futures immediately since they're already done
def mock_as_completed(futures_list, timeout=None):
"""Mock as_completed that yields futures immediately."""
yield from futures_list

with patch(
"core.rag.datasource.retrieval_service.concurrent.futures.as_completed",
side_effect=mock_as_completed,
):
yield mock_executor

# ==================== Vector Search Tests ====================

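The fixture change above patches concurrent.futures.as_completed alongside the existing patch of concurrent.futures.wait, because the production code now iterates completed futures rather than waiting on all of them. The same patching idea, sketched against a hypothetical module (myapp.runner and run_all are stand-ins, not names from this repository):

from unittest.mock import patch


def mock_as_completed(futures_list, timeout=None):
    """Mock as_completed that yields the futures immediately; in tests they are already resolved."""
    yield from futures_list


def test_run_all_fails_fast():
    # Hypothetical: myapp/runner.py does `import concurrent.futures` and calls
    # concurrent.futures.as_completed(futures, timeout=...) inside run_all().
    with patch(
        "myapp.runner.concurrent.futures.as_completed",
        side_effect=mock_as_completed,
    ):
        ...  # exercise run_all() with an executor whose submit() runs tasks inline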