Skip to content

Commit c33b758

Browse files
author
Anton
committed
feat: task's idle execution
1 parent e67dab3 commit c33b758

File tree

7 files changed

+302
-12
lines changed

7 files changed

+302
-12
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import asyncio
2+
from typing import AsyncGenerator
3+
4+
from taskiq import BrokerMessage
5+
from taskiq.brokers.inmemory_broker import InMemoryBroker
6+
7+
8+
class InMemoryQueueBroker(InMemoryBroker):
9+
"""In memory Broker based on asyncio.Queue."""
10+
11+
def __init__(
12+
self,
13+
sync_tasks_pool_size: int = 4,
14+
max_stored_results: int = 100,
15+
cast_types: bool = True,
16+
max_async_tasks: int = 30,
17+
) -> None:
18+
super().__init__(
19+
sync_tasks_pool_size,
20+
max_stored_results,
21+
cast_types,
22+
max_async_tasks,
23+
)
24+
self.queue: asyncio.Queue[BrokerMessage] = asyncio.Queue()
25+
26+
async def kick(self, message: BrokerMessage) -> None:
27+
"""
28+
Kicking task.
29+
30+
:param message: incoming message
31+
"""
32+
await self.queue.put(message)
33+
34+
async def listen(self) -> AsyncGenerator[bytes, None]:
35+
"""
36+
Listening for messages.
37+
38+
:yields: message's raw data
39+
"""
40+
running = asyncio.create_task(self.running.wait())
41+
42+
while not self.running.is_set():
43+
message = asyncio.create_task(self.queue.get())
44+
await asyncio.wait(
45+
[running, message],
46+
return_when=asyncio.FIRST_COMPLETED,
47+
)
48+
49+
if message.done():
50+
yield (await message).message
51+
continue
52+
53+
message.cancel()
54+
await running
55+
56+
async def startup(self) -> None:
57+
"""Runs startup events for client and worker side."""
58+
await super().startup()
59+
self.running = asyncio.Event()
60+
61+
async def shutdown(self) -> None:
62+
"""Runs shutdown events for client and worker side."""
63+
await super().shutdown()
64+
self.running.set()

taskiq/cli/worker/args.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class WorkerArgs:
4141
receiver: str = "taskiq.receiver:Receiver"
4242
receiver_arg: List[Tuple[str, str]] = field(default_factory=list)
4343
max_prefetch: int = 0
44+
max_idle_tasks: Optional[int] = None
4445

4546
@classmethod
4647
def from_cli( # noqa: WPS213
@@ -176,6 +177,13 @@ def from_cli( # noqa: WPS213
176177
default=0,
177178
help="Maximum prefetched tasks per worker process. ",
178179
)
180+
parser.add_argument(
181+
"--max-idle-tasks",
182+
type=int,
183+
dest="max_idle_tasks",
184+
default=None,
185+
help="Maximum idle tasks per worker process. ",
186+
)
179187

180188
namespace = parser.parse_args(args)
181189
return WorkerArgs(**namespace.__dict__)

taskiq/cli/worker/run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
133133
validate_params=not args.no_parse,
134134
max_async_tasks=args.max_async_tasks,
135135
max_prefetch=args.max_prefetch,
136+
max_idle_tasks=args.max_idle_tasks,
136137
**receiver_args,
137138
)
138139
loop.run_until_complete(receiver.listen())

taskiq/context.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING
1+
from typing import TYPE_CHECKING, Awaitable, Callable
22

33
from taskiq.abc.broker import AsyncBroker
44
from taskiq.message import TaskiqMessage
@@ -10,8 +10,14 @@
1010
class Context:
1111
"""Context class."""
1212

13-
def __init__(self, message: TaskiqMessage, broker: AsyncBroker) -> None:
13+
def __init__(
14+
self,
15+
message: TaskiqMessage,
16+
broker: AsyncBroker,
17+
task_idler: Callable[[float], Awaitable[None]],
18+
) -> None:
1419
self.message = message
1520
self.broker = broker
1621
self.state: "TaskiqState" = None # type: ignore
1722
self.state = broker.state
23+
self.task_idler = task_idler

taskiq/receiver/receiver.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
from taskiq.receiver.params_parser import parse_params
1616
from taskiq.result import TaskiqResult
1717
from taskiq.state import TaskiqState
18-
from taskiq.utils import maybe_awaitable
18+
from taskiq.utils import DequeQueue, DequeSemaphore, maybe_awaitable
1919

2020
logger = getLogger(__name__)
2121
QUEUE_DONE = b"-1"
22+
QUEUE_SKIP = b"-2"
2223

2324

2425
def _run_sync(target: Callable[..., Any], message: TaskiqMessage) -> Any:
@@ -45,6 +46,7 @@ def __init__( # noqa: WPS211
4546
validate_params: bool = True,
4647
max_async_tasks: "Optional[int]" = None,
4748
max_prefetch: int = 0,
49+
max_idle_tasks: Optional[int] = None,
4850
) -> None:
4951
self.broker = broker
5052
self.executor = executor
@@ -56,15 +58,20 @@ def __init__( # noqa: WPS211
5658
self.task_signatures[task.task_name] = inspect.signature(task.original_func)
5759
self.task_hints[task.task_name] = get_type_hints(task.original_func)
5860
self.dependency_graphs[task.task_name] = DependencyGraph(task.original_func)
59-
self.sem: "Optional[asyncio.Semaphore]" = None
61+
self.sem: "Optional[DequeSemaphore]" = None
6062
if max_async_tasks is not None and max_async_tasks > 0:
61-
self.sem = asyncio.Semaphore(max_async_tasks)
63+
self.sem = DequeSemaphore(max_async_tasks)
6264
else:
6365
logger.warning(
6466
"Setting unlimited number of async tasks "
6567
+ "can result in undefined behavior",
6668
)
67-
self.sem_prefetch = asyncio.Semaphore(max_prefetch)
69+
self.sem_prefetch = DequeSemaphore(max_prefetch)
70+
self.queue: DequeQueue[bytes] = DequeQueue()
71+
72+
self.sem_idle: Optional[asyncio.Semaphore] = None
73+
if max_idle_tasks and max_idle_tasks > 0:
74+
self.sem_idle = asyncio.Semaphore(max_idle_tasks)
6875

6976
async def callback( # noqa: C901, WPS213
7077
self,
@@ -176,7 +183,7 @@ async def run_task( # noqa: C901, WPS210
176183
broker_ctx = self.broker.custom_dependency_context
177184
broker_ctx.update(
178185
{
179-
Context: Context(message, self.broker),
186+
Context: Context(message, self.broker, self.task_idler),
180187
TaskiqState: self.broker.state,
181188
},
182189
)
@@ -243,11 +250,10 @@ async def listen(self) -> None: # pragma: no cover
243250
"""
244251
await self.broker.startup()
245252
logger.info("Listening started.")
246-
queue: asyncio.Queue[bytes] = asyncio.Queue()
247253

248254
async with anyio.create_task_group() as gr:
249-
gr.start_soon(self.prefetcher, queue)
250-
gr.start_soon(self.runner, queue)
255+
gr.start_soon(self.prefetcher, self.queue)
256+
gr.start_soon(self.runner, self.queue)
251257

252258
async def prefetcher(self, queue: "asyncio.Queue[Any]") -> None:
253259
"""
@@ -268,7 +274,7 @@ async def prefetcher(self, queue: "asyncio.Queue[Any]") -> None:
268274

269275
await queue.put(QUEUE_DONE)
270276

271-
async def runner(self, queue: "asyncio.Queue[bytes]") -> None:
277+
async def runner(self, queue: "asyncio.Queue[bytes]") -> None: # noqa: C901, WPS213
272278
"""
273279
Run tasks.
274280
@@ -299,6 +305,15 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
299305
message = await queue.get()
300306
if message is QUEUE_DONE:
301307
break
308+
if message is QUEUE_SKIP:
309+
# Decrease max_prefetch
310+
prefetch_dec = asyncio.create_task(self.sem_prefetch.acquire_first())
311+
prefetch_dec.add_done_callback(tasks.discard)
312+
tasks.add(prefetch_dec)
313+
314+
if self.sem is not None:
315+
self.sem.release()
316+
continue
302317

303318
task = asyncio.create_task(
304319
self.callback(message=message, raise_err=False),
@@ -311,3 +326,32 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
311326
# and it considered to be Hisenbug.
312327
# https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/
313328
task.add_done_callback(task_cb)
329+
330+
async def task_idler(self, wait: float) -> None:
331+
"""
332+
Temporary increasing `max_async_tasks` for at least `wait` amount of time.
333+
334+
:param wait: time
335+
"""
336+
if not self.sem:
337+
await asyncio.sleep(wait)
338+
return
339+
340+
if not self.sem_idle:
341+
logger.warning("`max_idle_tasks` is undefined. Idle is unavailable.")
342+
await asyncio.sleep(wait)
343+
return
344+
345+
start_time = time()
346+
async with self.sem_idle:
347+
# Increase max_tasks
348+
# Increase max_prefetch in runner
349+
self.sem.release()
350+
351+
# Wait
352+
await asyncio.sleep(wait - (time() - start_time))
353+
354+
# Decrease max_prefetch in runner
355+
await self.queue.put_first(QUEUE_SKIP)
356+
# Decrease max_tasks
357+
await self.sem.acquire_first()

taskiq/utils.py

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import asyncio
12
import inspect
2-
from typing import Any, Coroutine, TypeVar, Union
3+
from typing import TYPE_CHECKING, Any, Coroutine, Deque, Generic, TypeVar, Union
4+
5+
from typing_extensions import Literal
36

47
_T = TypeVar("_T") # noqa: WPS111
58

@@ -35,3 +38,103 @@ def remove_suffix(text: str, suffix: str) -> str:
3538
if text.endswith(suffix):
3639
return text[: -len(suffix)]
3740
return text
41+
42+
43+
class DequeSemaphore(asyncio.Semaphore):
44+
"""Deque based Semaphore."""
45+
46+
if TYPE_CHECKING: # noqa: WPS604
47+
_loop: asyncio.BaseEventLoop
48+
49+
async def acquire_first(self) -> Literal[True]:
50+
"""
51+
Acquire as soon as possible. LIFO style.
52+
53+
:raises BaseException: exception
54+
:return: true
55+
"""
56+
self._value -= 1
57+
58+
while self._value < 0:
59+
fut: asyncio.Future[Any] = self._loop.create_future()
60+
self._waiters.appendleft(fut)
61+
try:
62+
await fut
63+
except BaseException: # noqa: WPS424
64+
self._value += 1
65+
66+
fut.cancel()
67+
if not self.locked() and not fut.cancelled():
68+
self._wake_up_next()
69+
raise
70+
71+
return True
72+
73+
74+
class DequeQueue(
75+
asyncio.Queue, # type: ignore
76+
Generic[_T],
77+
):
78+
"""Deque based Queue."""
79+
80+
if TYPE_CHECKING: # noqa: WPS604
81+
_loop: asyncio.BaseEventLoop
82+
_queue: Deque[_T]
83+
_putters: Deque[Any]
84+
_getters: Deque[Any]
85+
_unfinished_tasks: int
86+
_finished: asyncio.Event
87+
_wakeup_next: Any
88+
89+
async def put_first(self, item: _T) -> None:
90+
"""
91+
Wait till queue is not full. Put item in Queue as soon as possible. LIFO style.
92+
93+
:param item: value to prepend
94+
:raises BaseException: something goes wrong
95+
:returns: nothing
96+
"""
97+
while self.full():
98+
putter = self._loop.create_future()
99+
self._putters.appendleft(putter)
100+
try:
101+
await putter
102+
except BaseException: # noqa: WPS424
103+
putter.cancel() # Just in case putter is not done yet.
104+
try: # noqa: WPS505
105+
# Clean self._putters from canceled putters.
106+
self._putters.remove(putter)
107+
except ValueError:
108+
# The putter could be removed from self._putters by a
109+
# previous get_nowait call.
110+
pass # noqa: WPS420
111+
if not self.full() and not putter.cancelled():
112+
# We were woken up by get_nowait(), but can't take
113+
# the call. Wake up the next in line.
114+
self._wakeup_next(self._putters)
115+
raise
116+
117+
return self.put_first_nowait(item)
118+
119+
def put_first_nowait(self, item: _T) -> None:
120+
"""
121+
Put item in Queue as soon as possible. LIFO style.
122+
123+
:param item: value to prepend
124+
:raises QueueFull: queue is full
125+
"""
126+
if self.full():
127+
raise asyncio.QueueFull()
128+
129+
self._put_first(item)
130+
self._unfinished_tasks += 1
131+
self._finished.clear()
132+
self._wakeup_next(self._getters)
133+
134+
def _put_first(self, item: _T) -> None:
135+
"""
136+
Prepend item.
137+
138+
:param item: value to prepend
139+
"""
140+
self._queue.appendleft(item)

0 commit comments

Comments
 (0)