
Commit 0852a27

keenanpepper authored and shinxi committed
core: Make abatch_as_completed respect max_concurrency (#29426)
- **Description:** Add tests for respecting max_concurrency and implement it for abatch_as_completed so that the test passes
- **Issue:** #29425
- **Dependencies:** none
- **Twitter handle:** keenanpepper
1 parent f5fbad2 commit 0852a27
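For context, here is a minimal usage sketch of the behavior this change enables. The toy `RunnableLambda`, the eight inputs, and `max_concurrency=2` below are illustrative choices, not values taken from the diff:

```python
import asyncio

from langchain_core.runnables import RunnableLambda


async def main() -> None:
    async def slow_double(x: int) -> int:
        # Stand-in for a slow async call (e.g. a model or API request).
        await asyncio.sleep(0.1)
        return x * 2

    runnable = RunnableLambda(slow_double)

    # With this commit, at most 2 inputs are processed at a time, while
    # results are still yielded as they finish, tagged with the input index.
    async for idx, value in runnable.abatch_as_completed(
        list(range(8)), config={"max_concurrency": 2}
    ):
        print(idx, value)


asyncio.run(main())
```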

File tree

2 files changed: +155, -3 lines changed


libs/core/langchain_core/runnables/base.py

Lines changed: 11 additions & 3 deletions
@@ -71,6 +71,7 @@
     accepts_config,
     accepts_run_manager,
     asyncio_accepts_context,
+    gated_coro,
     gather_with_concurrency,
     get_function_first_arg_dict_keys,
     get_function_nonlocals,
@@ -952,8 +953,11 @@ async def abatch_as_completed(
             return
 
         configs = get_config_list(config, len(inputs))
+        # Get max_concurrency from first config, defaulting to None (unlimited)
+        max_concurrency = configs[0].get("max_concurrency") if configs else None
+        semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
 
-        async def ainvoke(
+        async def ainvoke_task(
             i: int, input: Input, config: RunnableConfig
         ) -> tuple[int, Union[Output, Exception]]:
             if return_exceptions:
@@ -965,10 +969,14 @@ async def ainvoke(
                 out = e
             else:
                 out = await self.ainvoke(input, config, **kwargs)
-
             return (i, out)
 
-        coros = map(ainvoke, range(len(inputs)), inputs, configs)
+        coros = [
+            gated_coro(semaphore, ainvoke_task(i, input, config))
+            if semaphore
+            else ainvoke_task(i, input, config)
+            for i, (input, config) in enumerate(zip(inputs, configs))
+        ]
 
         for coro in asyncio.as_completed(coros):
             yield await coro
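The fix leans on the `gated_coro` helper imported above: each per-input coroutine is wrapped so that its body runs only while a semaphore slot is held, which caps how many `ainvoke_task` bodies execute at once even though `asyncio.as_completed` schedules every wrapper up front. When `max_concurrency` is unset, the unwrapped coroutines are passed straight through, preserving the previous unlimited behavior. The standalone sketch below illustrates that gating pattern; it mirrors how the helper is used here but is not claimed to be the exact implementation in `langchain_core`:

```python
import asyncio
from collections.abc import Coroutine
from typing import Any


async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
    """Await `coro` only while holding a semaphore slot (sketch of the pattern)."""
    async with semaphore:
        return await coro


async def demo() -> None:
    sem = asyncio.Semaphore(3)

    async def work(i: int) -> int:
        await asyncio.sleep(0.1)
        return i

    # All ten wrappers are handed to as_completed, but at most three
    # `work` bodies are awaited at any one time.
    gated = [gated_coro(sem, work(i)) for i in range(10)]
    for fut in asyncio.as_completed(gated):
        print(await fut)


asyncio.run(demo())
```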
Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
+"""Test concurrency behavior of batch and async batch operations."""
+
+import asyncio
+import time
+from typing import Any
+
+import pytest
+
+from langchain_core.runnables import RunnableConfig, RunnableLambda
+from langchain_core.runnables.base import Runnable
+
+
+@pytest.mark.asyncio
+async def test_abatch_concurrency() -> None:
+    """Test that abatch respects max_concurrency."""
+    running_tasks = 0
+    max_running_tasks = 0
+    lock = asyncio.Lock()
+
+    async def tracked_function(x: Any) -> str:
+        nonlocal running_tasks, max_running_tasks
+        async with lock:
+            running_tasks += 1
+            max_running_tasks = max(max_running_tasks, running_tasks)
+
+        await asyncio.sleep(0.1)  # Simulate work
+
+        async with lock:
+            running_tasks -= 1
+
+        return f"Completed {x}"
+
+    runnable: Runnable = RunnableLambda(tracked_function)
+    num_tasks = 10
+    max_concurrency = 3
+
+    config = RunnableConfig(max_concurrency=max_concurrency)
+    results = await runnable.abatch(list(range(num_tasks)), config=config)
+
+    assert len(results) == num_tasks
+    assert max_running_tasks <= max_concurrency
+
+
+@pytest.mark.asyncio
+async def test_abatch_as_completed_concurrency() -> None:
+    """Test that abatch_as_completed respects max_concurrency."""
+    running_tasks = 0
+    max_running_tasks = 0
+    lock = asyncio.Lock()
+
+    async def tracked_function(x: Any) -> str:
+        nonlocal running_tasks, max_running_tasks
+        async with lock:
+            running_tasks += 1
+            max_running_tasks = max(max_running_tasks, running_tasks)
+
+        await asyncio.sleep(0.1)  # Simulate work
+
+        async with lock:
+            running_tasks -= 1
+
+        return f"Completed {x}"
+
+    runnable: Runnable = RunnableLambda(tracked_function)
+    num_tasks = 10
+    max_concurrency = 3
+
+    config = RunnableConfig(max_concurrency=max_concurrency)
+    results = []
+    async for _idx, result in runnable.abatch_as_completed(
+        list(range(num_tasks)), config=config
+    ):
+        results.append(result)
+
+    assert len(results) == num_tasks
+    assert max_running_tasks <= max_concurrency
+
+
+def test_batch_concurrency() -> None:
+    """Test that batch respects max_concurrency."""
+    running_tasks = 0
+    max_running_tasks = 0
+    from threading import Lock
+
+    lock = Lock()
+
+    def tracked_function(x: Any) -> str:
+        nonlocal running_tasks, max_running_tasks
+        with lock:
+            running_tasks += 1
+            max_running_tasks = max(max_running_tasks, running_tasks)
+
+        time.sleep(0.1)  # Simulate work
+
+        with lock:
+            running_tasks -= 1
+
+        return f"Completed {x}"
+
+    runnable: Runnable = RunnableLambda(tracked_function)
+    num_tasks = 10
+    max_concurrency = 3
+
+    config = RunnableConfig(max_concurrency=max_concurrency)
+    results = runnable.batch(list(range(num_tasks)), config=config)
+
+    assert len(results) == num_tasks
+    assert max_running_tasks <= max_concurrency
+
+
+def test_batch_as_completed_concurrency() -> None:
+    """Test that batch_as_completed respects max_concurrency."""
+    running_tasks = 0
+    max_running_tasks = 0
+    from threading import Lock
+
+    lock = Lock()
+
+    def tracked_function(x: Any) -> str:
+        nonlocal running_tasks, max_running_tasks
+        with lock:
+            running_tasks += 1
+            max_running_tasks = max(max_running_tasks, running_tasks)
+
+        time.sleep(0.1)  # Simulate work
+
+        with lock:
+            running_tasks -= 1
+
+        return f"Completed {x}"
+
+    runnable: Runnable = RunnableLambda(tracked_function)
+    num_tasks = 10
+    max_concurrency = 3
+
+    config = RunnableConfig(max_concurrency=max_concurrency)
+    results = []
+    for _idx, result in runnable.batch_as_completed(
+        list(range(num_tasks)), config=config
+    ):
+        results.append(result)
+
+    assert len(results) == num_tasks
+    assert max_running_tasks <= max_concurrency
