Skip to content

Commit 626a6d7

Browse files
authored
Fixed semaphore logic (#97)
1 parent 091e703 commit 626a6d7

File tree

2 files changed

+114
-65
lines changed

2 files changed

+114
-65
lines changed

taskiq/receiver/receiver.py

Lines changed: 77 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(
4141
broker: AsyncBroker,
4242
executor: Optional[Executor] = None,
4343
validate_params: bool = True,
44-
max_async_tasks: int = 20,
44+
max_async_tasks: "Optional[int]" = None,
4545
) -> None:
4646
self.broker = broker
4747
self.executor = executor
@@ -53,7 +53,14 @@ def __init__(
5353
self.task_signatures[task.task_name] = inspect.signature(task.original_func)
5454
self.task_hints[task.task_name] = get_type_hints(task.original_func)
5555
self.dependency_graphs[task.task_name] = DependencyGraph(task.original_func)
56-
self.sem = asyncio.Semaphore(max_async_tasks)
56+
self.sem: "Optional[asyncio.Semaphore]" = None
57+
if max_async_tasks is not None and max_async_tasks > 0:
58+
self.sem = asyncio.Semaphore(max_async_tasks)
59+
else:
60+
logger.warning(
61+
"Setting unlimited number of async tasks "
62+
+ "can result in undefined behavior",
63+
)
5764

5865
async def callback( # noqa: C901, WPS213
5966
self,
@@ -72,62 +79,61 @@ async def callback( # noqa: C901, WPS213
7279
:param raise_err: raise an error if cannot save result in
7380
result_backend.
7481
"""
75-
async with self.sem:
76-
logger.debug(f"Received message: {message}")
77-
if message.task_name not in self.broker.available_tasks:
78-
logger.warning(
79-
'task "%s" is not found. Maybe you forgot to import it?',
80-
message.task_name,
81-
)
82-
return
83-
logger.debug(
84-
"Function for task %s is resolved. Executing...",
82+
logger.debug(f"Received message: {message}")
83+
if message.task_name not in self.broker.available_tasks:
84+
logger.warning(
85+
'task "%s" is not found. Maybe you forgot to import it?',
8586
message.task_name,
8687
)
87-
try:
88-
taskiq_msg = self.broker.formatter.loads(message=message)
89-
except Exception as exc:
90-
logger.warning(
91-
"Cannot parse message: %s. Skipping execution.\n %s",
92-
message,
93-
exc,
94-
exc_info=True,
88+
return
89+
logger.debug(
90+
"Function for task %s is resolved. Executing...",
91+
message.task_name,
92+
)
93+
try:
94+
taskiq_msg = self.broker.formatter.loads(message=message)
95+
except Exception as exc:
96+
logger.warning(
97+
"Cannot parse message: %s. Skipping execution.\n %s",
98+
message,
99+
exc,
100+
exc_info=True,
101+
)
102+
return
103+
for middleware in self.broker.middlewares:
104+
if middleware.__class__.pre_execute != TaskiqMiddleware.pre_execute:
105+
taskiq_msg = await maybe_awaitable(
106+
middleware.pre_execute(
107+
taskiq_msg,
108+
),
95109
)
96-
return
97-
for middleware in self.broker.middlewares:
98-
if middleware.__class__.pre_execute != TaskiqMiddleware.pre_execute:
99-
taskiq_msg = await maybe_awaitable(
100-
middleware.pre_execute(
101-
taskiq_msg,
102-
),
103-
)
104110

105-
logger.info(
106-
"Executing task %s with ID: %s",
107-
taskiq_msg.task_name,
108-
taskiq_msg.task_id,
109-
)
110-
result = await self.run_task(
111-
target=self.broker.available_tasks[message.task_name].original_func,
112-
message=taskiq_msg,
111+
logger.info(
112+
"Executing task %s with ID: %s",
113+
taskiq_msg.task_name,
114+
taskiq_msg.task_id,
115+
)
116+
result = await self.run_task(
117+
target=self.broker.available_tasks[message.task_name].original_func,
118+
message=taskiq_msg,
119+
)
120+
for middleware in self.broker.middlewares:
121+
if middleware.__class__.post_execute != TaskiqMiddleware.post_execute:
122+
await maybe_awaitable(middleware.post_execute(taskiq_msg, result))
123+
try:
124+
await self.broker.result_backend.set_result(message.task_id, result)
125+
except Exception as exc:
126+
logger.exception(
127+
"Can't set result in result backend. Cause: %s",
128+
exc,
129+
exc_info=True,
113130
)
114-
for middleware in self.broker.middlewares:
115-
if middleware.__class__.post_execute != TaskiqMiddleware.post_execute:
116-
await maybe_awaitable(middleware.post_execute(taskiq_msg, result))
117-
try:
118-
await self.broker.result_backend.set_result(message.task_id, result)
119-
except Exception as exc:
120-
logger.exception(
121-
"Can't set result in result backend. Cause: %s",
122-
exc,
123-
exc_info=True,
124-
)
125-
if raise_err:
126-
raise exc
131+
if raise_err:
132+
raise exc
127133

128-
for middleware in self.broker.middlewares:
129-
if middleware.__class__.post_save != TaskiqMiddleware.post_save:
130-
await maybe_awaitable(middleware.post_save(taskiq_msg, result))
134+
for middleware in self.broker.middlewares:
135+
if middleware.__class__.post_save != TaskiqMiddleware.post_save:
136+
await maybe_awaitable(middleware.post_save(taskiq_msg, result))
131137

132138
async def run_task( # noqa: C901, WPS210
133139
self,
@@ -232,11 +238,28 @@ async def listen(self) -> None: # pragma: no cover
232238
It uses listen() method of an AsyncBroker
233239
to get new messages from queues.
234240
"""
235-
logger.debug("Runing startup event.")
236241
await self.broker.startup()
237242
logger.info("Listening started.")
238243
tasks = set()
244+
245+
def task_cb(task: "asyncio.Task[Any]") -> None:
246+
"""
247+
Callback for tasks.
248+
249+
This function used to remove task
250+
from the list of active tasks and release
251+
the semaphore, so other tasks can use it.
252+
253+
:param task: finished task
254+
"""
255+
tasks.discard(task)
256+
if self.sem is not None:
257+
self.sem.release()
258+
239259
async for message in self.broker.listen():
260+
# Waits for semaphore to be released.
261+
if self.sem is not None:
262+
await self.sem.acquire()
240263
task = asyncio.create_task(self.callback(message=message, raise_err=False))
241264
tasks.add(task)
242265

@@ -245,4 +268,4 @@ async def listen(self) -> None: # pragma: no cover
245268
# Because python's GC can silently cancel task
246269
# and it considered to be Hisenbug.
247270
# https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/
248-
task.add_done_callback(tasks.discard)
271+
task.add_done_callback(task_cb)

tests/cli/worker/test_receiver.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,42 @@
11
import asyncio
22
from concurrent.futures import ThreadPoolExecutor
3-
from typing import Any, Optional
3+
from typing import Any, AsyncGenerator, Callable, List, Optional, TypeVar
44

55
import pytest
66
from taskiq_dependencies import Depends
77

88
from taskiq.abc.broker import AsyncBroker
99
from taskiq.abc.middleware import TaskiqMiddleware
10+
from taskiq.abc.result_backend import AsyncResultBackend
1011
from taskiq.brokers.inmemory_broker import InMemoryBroker
1112
from taskiq.message import BrokerMessage, TaskiqMessage
1213
from taskiq.receiver import Receiver
1314
from taskiq.result import TaskiqResult
1415

16+
_T = TypeVar("_T")
17+
18+
19+
class BrokerForTests(InMemoryBroker):
20+
def __init__(
21+
self,
22+
result_backend: "Optional[AsyncResultBackend[_T]]" = None,
23+
task_id_generator: Optional[Callable[[], str]] = None,
24+
) -> None:
25+
super().__init__(
26+
result_backend=result_backend,
27+
task_id_generator=task_id_generator,
28+
)
29+
self.to_send: "List[TaskiqMessage]" = []
30+
31+
async def listen(self) -> AsyncGenerator[BrokerMessage, None]:
32+
for message in self.to_send:
33+
yield self.formatter.dumps(message)
34+
1535

1636
def get_receiver(
1737
broker: Optional[AsyncBroker] = None,
1838
no_parse: bool = False,
19-
max_async_tasks: int = 10,
39+
max_async_tasks: Optional[int] = None,
2040
) -> Receiver:
2141
"""
2242
Returns receiver with custom broker and args.
@@ -247,7 +267,8 @@ def test_func(tes_val: MyTestClass = Depends()) -> int:
247267
@pytest.mark.anyio
248268
async def test_callback_semaphore() -> None:
249269
"""Test that callback funcion semaphore works well."""
250-
broker = InMemoryBroker()
270+
max_async_tasks = 3
271+
broker = BrokerForTests()
251272
sem_num = 0
252273

253274
@broker.task
@@ -257,18 +278,23 @@ async def task_sem() -> int:
257278
await asyncio.sleep(1)
258279
return 1
259280

260-
receiver = get_receiver(broker, max_async_tasks=3)
261-
262-
broker_message = broker.formatter.dumps(
281+
broker.to_send = [
263282
TaskiqMessage(
264283
task_id="test_sem",
265284
task_name=task_sem.task_name,
266285
labels={},
267286
args=[],
268287
kwargs=[],
269-
),
270-
)
271-
tasks = [asyncio.create_task(receiver.callback(broker_message)) for _ in range(5)]
288+
)
289+
for _ in range(max_async_tasks + 2)
290+
]
291+
292+
# broker_message = broker.formatter.dumps(
293+
# )
294+
receiver = get_receiver(broker, max_async_tasks=3)
295+
296+
listen_task = asyncio.create_task(receiver.listen())
272297
await asyncio.sleep(0.3)
273-
assert sem_num == 3
274-
await asyncio.gather(*tasks)
298+
assert sem_num == max_async_tasks
299+
await listen_task
300+
assert sem_num == max_async_tasks + 2

0 commit comments

Comments
 (0)