Skip to content

Commit 4728274

Browse files
authored
Merge pull request #15 from taskiq-python/feature/inmemory-broker
Added inmemory broker.
2 parents 030c0b5 + 5e88008 commit 4728274

File tree

2 files changed

+142
-4
lines changed

2 files changed

+142
-4
lines changed

taskiq/brokers/inmemory_broker.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import inspect
2+
from collections import OrderedDict
3+
from concurrent.futures import ThreadPoolExecutor
4+
from typing import AsyncGenerator, Optional, TypeVar
5+
6+
from taskiq.abc.broker import AsyncBroker
7+
from taskiq.abc.result_backend import AsyncResultBackend, TaskiqResult
8+
from taskiq.cli.async_task_runner import run_task
9+
from taskiq.message import TaskiqMessage
10+
11+
_ReturnType = TypeVar("_ReturnType")
12+
13+
14+
class InmemoryResultBackend(AsyncResultBackend[_ReturnType]):
15+
"""
16+
Inmemory result backend.
17+
18+
This resultbackend is intended to be used only
19+
with inmemory broker.
20+
21+
It stores all results in a dict in memory.
22+
"""
23+
24+
def __init__(self, max_stored_results: int = 100) -> None:
25+
self.max_stored_results = max_stored_results
26+
self.results: OrderedDict[str, TaskiqResult[_ReturnType]] = OrderedDict()
27+
28+
async def set_result(self, task_id: str, result: TaskiqResult[_ReturnType]) -> None:
29+
"""
30+
Sets result.
31+
32+
This method is used to store result of an execution in a
33+
results dict. But also it removes previous results
34+
to keep memory footprint as low as possible.
35+
36+
:param task_id: id of a task.
37+
:param result: result of an execution.
38+
"""
39+
if self.max_stored_results != -1:
40+
if len(self.results) >= self.max_stored_results:
41+
self.results.popitem(last=False)
42+
self.results[task_id] = result
43+
44+
async def is_result_ready(self, task_id: str) -> bool:
45+
"""
46+
Checks wether result is ready.
47+
48+
Readiness means that result with this task_id is
49+
present in results dict.
50+
51+
:param task_id: id of a task to check.
52+
:return: True if ready.
53+
"""
54+
return task_id in self.results
55+
56+
async def get_result(
57+
self,
58+
task_id: str,
59+
with_logs: bool = False,
60+
) -> TaskiqResult[_ReturnType]:
61+
"""
62+
Get result of a task.
63+
64+
This method is used to get result
65+
from result dict.
66+
67+
It throws exception in case if
68+
result dict doesn't have a value
69+
for task_id.
70+
71+
:param task_id: id of a task.
72+
:param with_logs: this option is ignored.
73+
:return: result of a task execution.
74+
"""
75+
return self.results[task_id]
76+
77+
78+
class InMemoryBroker(AsyncBroker):
79+
"""
80+
This broker is used to execute tasks without sending them elsewhere.
81+
82+
It's useful for local development, if you don't want to setup real broker.
83+
"""
84+
85+
def __init__(
86+
self,
87+
sync_tasks_pool_size: int = 4,
88+
logs_format: Optional[str] = None,
89+
max_stored_results: int = 100,
90+
) -> None:
91+
super().__init__(
92+
InmemoryResultBackend(
93+
max_stored_results=max_stored_results,
94+
),
95+
)
96+
# We mock as if it's a worker process.
97+
# So every task call will add tasks in
98+
# _related_tasks attribute.
99+
self.is_worker_process = True
100+
self.tasks_mapping = None
101+
self.executor = ThreadPoolExecutor(max_workers=sync_tasks_pool_size)
102+
if logs_format is None:
103+
logs_format = "%(levelname)s %(message)s"
104+
self.logs_format = logs_format
105+
106+
async def kick(self, message: TaskiqMessage) -> None:
107+
"""
108+
Kicking task.
109+
110+
This method just executes given task.
111+
112+
:param message: incomming message.
113+
:raises ValueError: if someone wants to kick unknown task.
114+
"""
115+
for task in self._related_tasks:
116+
if task.task_name == message.task_name:
117+
target_task = task
118+
if target_task is None:
119+
raise ValueError("Unknown task.")
120+
result = await run_task(
121+
target=target_task.original_func,
122+
signature=inspect.signature(target_task.original_func),
123+
message=message,
124+
log_collector_format=self.logs_format,
125+
executor=self.executor,
126+
)
127+
await self.result_backend.set_result(message.task_id, result)
128+
129+
async def listen(self) -> AsyncGenerator[TaskiqMessage, None]: # type: ignore
130+
"""
131+
Inmemory broker cannot listen.
132+
133+
This method throws RuntimeError if you call it.
134+
Because inmemory broker cannot really listen to any of tasks.
135+
136+
:raises RuntimeError: if this method is called.
137+
"""
138+
raise RuntimeError("Inmemory brokers cannot listen.")

taskiq/cli/async_task_runner.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ async def run_task( # noqa: WPS210
101101
target: Callable[..., Any],
102102
signature: Optional[inspect.Signature],
103103
message: TaskiqMessage,
104-
cli_args: TaskiqArgs,
104+
log_collector_format: str,
105105
executor: Optional[Executor] = None,
106106
) -> TaskiqResult[Any]:
107107
"""
@@ -121,7 +121,7 @@ async def run_task( # noqa: WPS210
121121
:param target: function to execute.
122122
:param signature: signature of an original function.
123123
:param message: received message.
124-
:param cli_args: CLI arguments for worker.
124+
:param log_collector_format: Log format in wich logs are collected.
125125
:param executor: executor to run sync tasks.
126126
:return: result of execution.
127127
"""
@@ -131,7 +131,7 @@ async def run_task( # noqa: WPS210
131131
returned = None
132132
# Captures function's logs.
133133
parse_params(signature, message)
134-
with LogsCollector(logs, cli_args.log_collector_format):
134+
with LogsCollector(logs, log_collector_format):
135135
start_time = time()
136136
try:
137137
if asyncio.iscoroutinefunction(target):
@@ -273,7 +273,7 @@ async def async_listen_messages( # noqa: C901, WPS210, WPS213
273273
func,
274274
task_signatures.get(message.task_name),
275275
message,
276-
cli_args,
276+
cli_args.log_collector_format,
277277
executor,
278278
)
279279
try:

0 commit comments

Comments
 (0)