Skip to content

Commit dcc00ac

Browse files
author
Anton
committed
fix: more tests, better lock
1 parent 1774ec5 commit dcc00ac

File tree

2 files changed

+39
-3
lines changed

2 files changed

+39
-3
lines changed

taskiq/receiver/receiver.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
330330
# https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/
331331
task.add_done_callback(task_cb)
332332

333-
async def task_idler(self, wait: float) -> None:
333+
async def task_idler(self, wait: float) -> None: # noqa: WPS213, WPS217
334334
"""
335335
Temporary increasing `max_async_tasks` for at least `wait` amount of time.
336336
@@ -346,7 +346,13 @@ async def task_idler(self, wait: float) -> None:
346346
return
347347

348348
start_time = time()
349-
async with self.sem_idle:
349+
with anyio.move_on_after(wait) as scope:
350+
await self.sem_idle.acquire()
351+
352+
if scope.cancel_called: # noqa: WPS441
353+
return
354+
355+
try: # noqa: WPS501
350356
# Increase max_tasks
351357
# Increase max_prefetch in runner
352358
self.sem.release()
@@ -357,6 +363,8 @@ async def task_idler(self, wait: float) -> None:
357363
# Decrease max_prefetch in runner
358364
task = asyncio.create_task(self.queue.put_first(QUEUE_SKIP))
359365
# Decrease max_tasks
360-
361366
await self.sem.acquire_first()
362367
await task
368+
369+
finally:
370+
self.sem_idle.release()

tests/cli/worker/test_receiver.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from concurrent.futures import ThreadPoolExecutor
33
from typing import Any, AsyncGenerator, List, Optional, TypeVar
44

5+
import anyio
56
import pytest
67
from taskiq_dependencies import Depends
78

@@ -388,3 +389,30 @@ async def wait_for_task(
388389

389390
await broker.shutdown()
390391
await listen_task
392+
393+
394+
@pytest.mark.anyio
395+
async def test_tasks_sleep() -> None:
396+
""""""
397+
broker = InMemoryQueueBroker()
398+
399+
@broker.task
400+
async def task_run(ind: int, ctx: Context = Depends()) -> int:
401+
await ctx.task_idler(0.1)
402+
return ind
403+
404+
receiver = get_receiver(broker, max_async_tasks=1, max_idle_tasks=20)
405+
listen_task = asyncio.create_task(receiver.listen())
406+
407+
with anyio.fail_after(1):
408+
tasks_tasks = [asyncio.create_task(task_run.kiq(ind)) for ind in range(100)]
409+
tasks = await asyncio.gather(*tasks_tasks)
410+
resps_tasks = [
411+
asyncio.create_task(task.wait_result(timeout=1)) for task in tasks
412+
]
413+
resps = await asyncio.gather(*resps_tasks)
414+
value = [resp.return_value for resp in resps]
415+
assert value == list(range(100))
416+
417+
await broker.shutdown()
418+
await listen_task

0 commit comments

Comments
 (0)