Skip to content

Commit 91d811f

Browse files
fix(core): replace blocking run_async_tasks with asyncio.gather (#20795)
replace blocking run_async_tasks with asyncio.gather
1 parent 78c0a01 commit 91d811f

File tree

2 files changed

+117
-3
lines changed

2 files changed

+117
-3
lines changed

llama-index-core/llama_index/core/query_engine/router_query_engine.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
import asyncio
12
import logging
23
from typing import Callable, Generator, List, Optional, Sequence, Any
34

4-
from llama_index.core.async_utils import run_async_tasks
55
from llama_index.core.base.base_query_engine import BaseQueryEngine
66
from llama_index.core.base.base_retriever import BaseRetriever
77
from llama_index.core.base.base_selector import BaseSelector
@@ -220,7 +220,7 @@ async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
220220
selected_query_engine = self._query_engines[engine_ind]
221221
tasks.append(selected_query_engine.aquery(query_bundle))
222222

223-
responses = run_async_tasks(tasks)
223+
responses = await asyncio.gather(*tasks)
224224
if len(responses) > 1:
225225
final_response = await acombine_responses(
226226
self._summarizer, responses, query_bundle
@@ -380,7 +380,7 @@ async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
380380
for query_engine_tool in query_engine_tools:
381381
query_engine = query_engine_tool.query_engine
382382
tasks.append(query_engine.aquery(query_bundle))
383-
responses = run_async_tasks(tasks)
383+
responses = await asyncio.gather(*tasks)
384384
if len(responses) > 1:
385385
final_response = await acombine_responses(
386386
self._summarizer, responses, query_bundle
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import asyncio
2+
from unittest.mock import AsyncMock, MagicMock
3+
4+
import pytest
5+
6+
from llama_index.core.base.base_selector import (
7+
BaseSelector,
8+
SelectorResult,
9+
SingleSelection,
10+
)
11+
from llama_index.core.base.response.schema import Response
12+
from llama_index.core.llms.mock import MockLLM
13+
from llama_index.core.query_engine.router_query_engine import (
14+
RouterQueryEngine,
15+
ToolRetrieverRouterQueryEngine,
16+
)
17+
from llama_index.core.tools.types import ToolMetadata
18+
19+
20+
class _AlwaysMultiSelector(BaseSelector):
21+
def _get_prompts(self):
22+
return {}
23+
24+
def _update_prompts(self, prompts):
25+
pass
26+
27+
def _select(self, choices, query):
28+
return SelectorResult(
29+
selections=[
30+
SingleSelection(index=i, reason="") for i in range(len(choices))
31+
]
32+
)
33+
34+
async def _aselect(self, choices, query):
35+
return self._select(choices, query)
36+
37+
38+
def _make_query_engine_tool(name: str):
39+
async def _fake_aquery(_):
40+
await asyncio.sleep(0.05)
41+
return Response(response="ok")
42+
43+
engine = MagicMock()
44+
engine.aquery = AsyncMock(side_effect=_fake_aquery)
45+
46+
tool = MagicMock()
47+
tool.query_engine = engine
48+
tool.metadata = ToolMetadata(name=name, description=name)
49+
return tool
50+
51+
52+
class _MockSummarizer:
53+
async def aget_response(self, *args, **kwargs):
54+
return "combined"
55+
56+
57+
async def _assert_not_blocked(coro) -> None:
58+
loop = asyncio.get_running_loop()
59+
start = loop.time()
60+
ran_at = None
61+
62+
async def _background():
63+
nonlocal ran_at
64+
await asyncio.sleep(0.01)
65+
ran_at = loop.time()
66+
67+
bg_task = asyncio.create_task(_background())
68+
await asyncio.sleep(0)
69+
await coro
70+
await bg_task
71+
72+
assert ran_at is not None, "background task never ran"
73+
assert (ran_at - start) < 0.04, (
74+
f"background task finished {ran_at - start:.3f}s after start "
75+
f"(expected < 0.04s) — the event loop was likely blocked"
76+
)
77+
78+
79+
@pytest.mark.asyncio
80+
async def test_router_aquery_does_not_block_event_loop():
81+
tool_a = _make_query_engine_tool("a")
82+
tool_b = _make_query_engine_tool("b")
83+
84+
router = RouterQueryEngine(
85+
selector=_AlwaysMultiSelector(),
86+
query_engine_tools=[tool_a, tool_b],
87+
llm=MockLLM(),
88+
summarizer=_MockSummarizer(),
89+
)
90+
91+
await _assert_not_blocked(router.aquery("test query"))
92+
93+
assert tool_a.query_engine.aquery.call_count == 1
94+
assert tool_b.query_engine.aquery.call_count == 1
95+
96+
97+
@pytest.mark.asyncio
98+
async def test_tool_retriever_router_aquery_does_not_block_event_loop():
99+
tool_a = _make_query_engine_tool("a")
100+
tool_b = _make_query_engine_tool("b")
101+
102+
retriever = MagicMock()
103+
retriever.retrieve = MagicMock(return_value=[tool_a, tool_b])
104+
105+
router = ToolRetrieverRouterQueryEngine(
106+
retriever=retriever,
107+
llm=MockLLM(),
108+
summarizer=_MockSummarizer(),
109+
)
110+
111+
await _assert_not_blocked(router.aquery("test query"))
112+
113+
assert tool_a.query_engine.aquery.call_count == 1
114+
assert tool_b.query_engine.aquery.call_count == 1

0 commit comments

Comments
 (0)