Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/google/adk/cli/fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,14 +913,21 @@ async def stream_query(request: Request):
output = await _invoke_callable_or_raise(method, parsed.input or {})

if inspect.isgenerator(output):
# Sentinel-based exhaustion check. We cannot rely on catching
# StopIteration here: when ``next(iterator)`` is called inside the
# threadpool worker, the StopIteration propagates out of the
# ``run_in_threadpool`` coroutine frame, and Python (PEP 479) converts
# it to ``RuntimeError("coroutine raised StopIteration")`` before the
# ``except StopIteration`` clause can ever see it. Passing a default to
# ``next`` avoids raising at the boundary entirely.
_SENTINEL = object()

async def _aiter_from_iter(iterator):
while True:
try:
chunk = await run_in_threadpool(next, iterator)
yield chunk
except StopIteration:
chunk = await run_in_threadpool(next, iterator, _SENTINEL)
if chunk is _SENTINEL:
break
yield chunk

content_iter = _aiter_from_iter(output)
else:
Expand Down
78 changes: 78 additions & 0 deletions tests/unittests/cli/test_fast_api.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,63 @@ async def stream_query_impl(**kwargs):
yield client


@pytest.fixture
def test_app_with_gemini_enterprise_sync_stream(
mock_session_service,
mock_artifact_service,
mock_memory_service,
mock_agent_loader,
mock_eval_sets_manager,
mock_eval_set_results_manager,
monkeypatch,
):
"""Like test_app_with_gemini_enterprise but stream_query is a sync generator.

This exercises the inspect.isgenerator() branch in stream_reasoning_engine,
where the sync iterator is adapted to an async iterator via a threadpool.
"""
monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "test-project")
mock_agent_loader.list_agents = MagicMock(
return_value=["test_app", "gemini_app"]
)

mock_adk_app_instance = MagicMock()
mock_adk_app_instance._tmpl_attrs = {}

def stream_query_impl(**kwargs):
yield {"chunk": 1, "kwargs": kwargs}
yield {"chunk": 2, "kwargs": kwargs}

mock_adk_app_instance.stream_query = stream_query_impl

with (
patch("google.auth.default", return_value=(MagicMock(), "test-project")),
patch("vertexai.init", new_callable=MagicMock),
patch(
"vertexai.agent_engines.AdkApp", return_value=mock_adk_app_instance
),
patch("google.adk.agents.Agent", new_callable=MagicMock),
patch(
"google.adk.telemetry._agent_engine.TopSpanProcessor",
new_callable=MagicMock,
),
patch(
"google.adk.telemetry._agent_engine.get_propagated_context",
new_callable=MagicMock,
),
):
client = _create_test_client(
mock_session_service,
mock_artifact_service,
mock_memory_service,
mock_agent_loader,
mock_eval_sets_manager,
mock_eval_set_results_manager,
gemini_enterprise_app_name="gemini_app",
)
yield client


#################################################
# Test Cases
#################################################
Expand Down Expand Up @@ -3331,5 +3388,26 @@ def test_gemini_stream_reasoning_engine_missing_class_method(
assert response.status_code == 400


def test_gemini_stream_reasoning_engine_sync_generator(
test_app_with_gemini_enterprise_sync_stream,
):
"""Regression test: a synchronous streaming class_method must not raise.

A sync generator is adapted to an async iterator via run_in_threadpool. The
adapter must not rely on catching StopIteration across the await boundary,
since Python (PEP 479) converts an escaping StopIteration into
RuntimeError("coroutine raised StopIteration") after the final chunk.
"""
response = test_app_with_gemini_enterprise_sync_stream.post(
"/api/stream_reasoning_engine",
json={"class_method": "stream_query", "input": {"arg1": 1}},
)
assert response.status_code == 200
lines = response.text.strip().split("\n")
assert len(lines) == 2
assert json.loads(lines[0]) == {"chunk": 1, "kwargs": {"arg1": 1}}
assert json.loads(lines[1]) == {"chunk": 2, "kwargs": {"arg1": 1}}


if __name__ == "__main__":
pytest.main(["-xvs", __file__])