From b1b28f9daf083132ea7f58ba7538ae555bf4eb31 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Sat, 13 Jun 2026 12:52:15 +0800 Subject: [PATCH] fix: single-flight discovery engine mode detection --- .../adk/tools/discovery_engine_search_tool.py | 38 +++++++---- .../test_discovery_engine_search_tool.py | 63 +++++++++++++++++++ 2 files changed, 88 insertions(+), 13 deletions(-) diff --git a/src/google/adk/tools/discovery_engine_search_tool.py b/src/google/adk/tools/discovery_engine_search_tool.py index eea843c35f..1d31b42824 100644 --- a/src/google/adk/tools/discovery_engine_search_tool.py +++ b/src/google/adk/tools/discovery_engine_search_tool.py @@ -19,6 +19,7 @@ import json import logging import re +import threading from typing import Any from typing import Optional @@ -172,6 +173,7 @@ def __init__( self._filter = filter self._max_results = max_results self._search_result_mode = search_result_mode + self._search_result_mode_lock = threading.Lock() self._location = location credentials, _ = google.auth.default() @@ -204,19 +206,29 @@ def discovery_engine_search( if mode is not None: return self._do_search(query, mode) - # Auto-detect: try CHUNKS first, fall back to DOCUMENTS - # if the datastore requires it. - try: - return self._do_search(query, SearchResultMode.CHUNKS) - except GoogleAPICallError as e: - if _STRUCTURED_STORE_ERROR_PATTERN.search(str(e)): - logger.info( - 'CHUNKS mode failed for structured datastore,' - ' retrying with DOCUMENTS mode.' - ) - self._search_result_mode = SearchResultMode.DOCUMENTS - return self._do_search(query, SearchResultMode.DOCUMENTS) - raise + # Auto-detect is per datastore, not per query. Keep the probe + # single-flight so concurrent first calls do not all spend a CHUNKS + # request before learning the same DOCUMENTS fallback. + with self._search_result_mode_lock: + mode = self._search_result_mode + if mode is None: + try: + result = self._do_search(query, SearchResultMode.CHUNKS) + except GoogleAPICallError as e: + if _STRUCTURED_STORE_ERROR_PATTERN.search(str(e)): + logger.info( + 'CHUNKS mode failed for structured datastore,' + ' retrying with DOCUMENTS mode.' + ) + self._search_result_mode = SearchResultMode.DOCUMENTS + mode = SearchResultMode.DOCUMENTS + else: + raise + else: + self._search_result_mode = SearchResultMode.CHUNKS + return result + + return self._do_search(query, mode) except GoogleAPICallError as e: return {'status': 'error', 'error_message': str(e)} diff --git a/tests/unittests/tools/test_discovery_engine_search_tool.py b/tests/unittests/tools/test_discovery_engine_search_tool.py index a744be7c39..f603d82fa1 100644 --- a/tests/unittests/tools/test_discovery_engine_search_tool.py +++ b/tests/unittests/tools/test_discovery_engine_search_tool.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import concurrent.futures +import threading +import time from unittest import mock from google.adk.tools import discovery_engine_search_tool @@ -489,6 +492,66 @@ def test_auto_detect_falls_back_to_documents(self, mock_search_client): # Mode should be persisted so subsequent calls skip the retry. assert tool._search_result_mode == SearchResultMode.DOCUMENTS + @mock.patch.object( + discoveryengine, + "SearchServiceClient", + ) + def test_auto_detect_singleflights_structured_fallback( + self, mock_search_client + ): + """Concurrent cold calls should share one CHUNKS probe.""" + spec_cls = discoveryengine.SearchRequest.ContentSearchSpec + worker_count = 8 + start_barrier = threading.Barrier(worker_count) + search_lock = threading.Lock() + search_modes = [] + structured_error = exceptions.InvalidArgument( + "`content_search_spec.search_result_mode` must be set to" + " SearchRequest.ContentSearchSpec.SearchResultMode.DOCUMENTS" + " when the engine contains structured data store." + ) + mock_doc = discoveryengine.Document( + name="projects/p/locations/l/doc1", + id="doc1", + struct_data={ + "title": "Jira Issue", + "uri": "https://jira.example.com/123", + "summary": "Bug fix", + }, + ) + mock_doc_response = discoveryengine.SearchResponse() + mock_doc_response.results = [ + discoveryengine.SearchResponse.SearchResult(document=mock_doc) + ] + + def search(request): + mode = request.content_search_spec.search_result_mode + with search_lock: + search_modes.append(mode) + if mode == spec_cls.SearchResultMode.CHUNKS: + time.sleep(0.05) + raise structured_error + return mock_doc_response + + mock_search_client.return_value.search.side_effect = search + tool = DiscoveryEngineSearchTool(data_store_id="test_data_store") + + def run_search(index): + start_barrier.wait(timeout=5) + return tool.discovery_engine_search(f"test query {index}") + + with concurrent.futures.ThreadPoolExecutor( + max_workers=worker_count + ) as executor: + results = list(executor.map(run_search, range(worker_count))) + + assert all(result["status"] == "success" for result in results) + assert search_modes.count(spec_cls.SearchResultMode.CHUNKS) == 1 + assert ( + search_modes.count(spec_cls.SearchResultMode.DOCUMENTS) == worker_count + ) + assert tool._search_result_mode == SearchResultMode.DOCUMENTS + @mock.patch.object( discoveryengine, "SearchServiceClient",