Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions src/google/adk/tools/discovery_engine_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import json
import logging
import re
import threading
from typing import Any
from typing import Optional

Expand Down Expand Up @@ -172,6 +173,7 @@ def __init__(
self._filter = filter
self._max_results = max_results
self._search_result_mode = search_result_mode
self._search_result_mode_lock = threading.Lock()
self._location = location

credentials, _ = google.auth.default()
Expand Down Expand Up @@ -204,19 +206,29 @@ def discovery_engine_search(
if mode is not None:
return self._do_search(query, mode)

# Auto-detect: try CHUNKS first, fall back to DOCUMENTS
# if the datastore requires it.
try:
return self._do_search(query, SearchResultMode.CHUNKS)
except GoogleAPICallError as e:
if _STRUCTURED_STORE_ERROR_PATTERN.search(str(e)):
logger.info(
'CHUNKS mode failed for structured datastore,'
' retrying with DOCUMENTS mode.'
)
self._search_result_mode = SearchResultMode.DOCUMENTS
return self._do_search(query, SearchResultMode.DOCUMENTS)
raise
# Auto-detect is per datastore, not per query. Keep the probe
# single-flight so concurrent first calls do not all spend a CHUNKS
# request before learning the same DOCUMENTS fallback.
with self._search_result_mode_lock:
mode = self._search_result_mode
if mode is None:
try:
result = self._do_search(query, SearchResultMode.CHUNKS)
except GoogleAPICallError as e:
if _STRUCTURED_STORE_ERROR_PATTERN.search(str(e)):
logger.info(
'CHUNKS mode failed for structured datastore,'
' retrying with DOCUMENTS mode.'
)
self._search_result_mode = SearchResultMode.DOCUMENTS
mode = SearchResultMode.DOCUMENTS
else:
raise
else:
self._search_result_mode = SearchResultMode.CHUNKS
return result

return self._do_search(query, mode)
except GoogleAPICallError as e:
return {'status': 'error', 'error_message': str(e)}

Expand Down
63 changes: 63 additions & 0 deletions tests/unittests/tools/test_discovery_engine_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import concurrent.futures
import threading
import time
from unittest import mock

from google.adk.tools import discovery_engine_search_tool
Expand Down Expand Up @@ -489,6 +492,66 @@ def test_auto_detect_falls_back_to_documents(self, mock_search_client):
# Mode should be persisted so subsequent calls skip the retry.
assert tool._search_result_mode == SearchResultMode.DOCUMENTS

@mock.patch.object(
discoveryengine,
"SearchServiceClient",
)
def test_auto_detect_singleflights_structured_fallback(
self, mock_search_client
):
"""Concurrent cold calls should share one CHUNKS probe."""
spec_cls = discoveryengine.SearchRequest.ContentSearchSpec
worker_count = 8
start_barrier = threading.Barrier(worker_count)
search_lock = threading.Lock()
search_modes = []
structured_error = exceptions.InvalidArgument(
"`content_search_spec.search_result_mode` must be set to"
" SearchRequest.ContentSearchSpec.SearchResultMode.DOCUMENTS"
" when the engine contains structured data store."
)
mock_doc = discoveryengine.Document(
name="projects/p/locations/l/doc1",
id="doc1",
struct_data={
"title": "Jira Issue",
"uri": "https://jira.example.com/123",
"summary": "Bug fix",
},
)
mock_doc_response = discoveryengine.SearchResponse()
mock_doc_response.results = [
discoveryengine.SearchResponse.SearchResult(document=mock_doc)
]

def search(request):
mode = request.content_search_spec.search_result_mode
with search_lock:
search_modes.append(mode)
if mode == spec_cls.SearchResultMode.CHUNKS:
time.sleep(0.05)
raise structured_error
return mock_doc_response

mock_search_client.return_value.search.side_effect = search
tool = DiscoveryEngineSearchTool(data_store_id="test_data_store")

def run_search(index):
start_barrier.wait(timeout=5)
return tool.discovery_engine_search(f"test query {index}")

with concurrent.futures.ThreadPoolExecutor(
max_workers=worker_count
) as executor:
results = list(executor.map(run_search, range(worker_count)))

assert all(result["status"] == "success" for result in results)
assert search_modes.count(spec_cls.SearchResultMode.CHUNKS) == 1
assert (
search_modes.count(spec_cls.SearchResultMode.DOCUMENTS) == worker_count
)
assert tool._search_result_mode == SearchResultMode.DOCUMENTS

@mock.patch.object(
discoveryengine,
"SearchServiceClient",
Expand Down