Skip to content

Commit 4c008df

Browse files
committed
Added middlewares.
Signed-off-by: Pavel Kirilin <[email protected]>
1 parent b124c06 commit 4c008df

File tree

8 files changed

+175
-60
lines changed

8 files changed

+175
-60
lines changed

taskiq/abc/broker.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from abc import ABC, abstractmethod
22
from functools import wraps
33
from logging import getLogger
4-
from typing import (
4+
from typing import ( # noqa: WPS235
5+
TYPE_CHECKING,
56
Any,
67
AsyncGenerator,
78
Callable,
89
Dict,
10+
List,
911
Optional,
1012
TypeVar,
1113
Union,
@@ -14,13 +16,16 @@
1416

1517
from typing_extensions import ParamSpec
1618

17-
from taskiq.abc.plugins.formatter import TaskiqFormatter
18-
from taskiq.abc.result_backend import AsyncResultBackend
1919
from taskiq.decor import AsyncTaskiqDecoratedTask
2020
from taskiq.message import BrokerMessage
2121
from taskiq.plugins.json_formatter import JSONFormatter
2222
from taskiq.result_backends.dummy import DummyResultBackend
2323

24+
if TYPE_CHECKING:
25+
from taskiq.abc.formatter import TaskiqFormatter
26+
from taskiq.abc.middleware import TaskiqMiddleware
27+
from taskiq.abc.result_backend import AsyncResultBackend
28+
2429
_T = TypeVar("_T") # noqa: WPS111
2530
_FuncParams = ParamSpec("_FuncParams")
2631
_ReturnType = TypeVar("_ReturnType")
@@ -41,14 +46,28 @@ class AsyncBroker(ABC):
4146

4247
def __init__(
4348
self,
44-
result_backend: Optional[AsyncResultBackend[_T]] = None,
49+
result_backend: "Optional[AsyncResultBackend[_T]]" = None,
4550
) -> None:
4651
if result_backend is None:
4752
result_backend = DummyResultBackend()
53+
self.middlewares: "List[TaskiqMiddleware]" = []
4854
self.result_backend = result_backend
4955
self.is_worker_process = False
5056
self.decorator_class = AsyncTaskiqDecoratedTask
51-
self.formatter: TaskiqFormatter = JSONFormatter()
57+
self.formatter: "TaskiqFormatter" = JSONFormatter()
58+
59+
def add_middlewares(self, middlewares: "List[TaskiqMiddleware]") -> None:
60+
"""
61+
Add a list of middlewares.
62+
63+
You should call this method to set middlewares,
64+
since it saves current broker in all middlewares.
65+
66+
:param middlewares: list of middlewares.
67+
"""
68+
for middleware in middlewares:
69+
middleware.set_broker(self)
70+
self.middlewares.append(middleware)
5271

5372
async def startup(self) -> None:
5473
"""Do something when starting broker."""
File renamed without changes.

taskiq/abc/middleware.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Union
2+
3+
if TYPE_CHECKING:
4+
from taskiq.abc.broker import AsyncBroker
5+
from taskiq.message import TaskiqMessage
6+
from taskiq.result import TaskiqResult
7+
8+
9+
class TaskiqMiddleware:
10+
"""Base class for middlewares."""
11+
12+
def __init__(self) -> None:
13+
self.broker: "AsyncBroker" = None # type: ignore
14+
15+
def set_broker(self, broker: "AsyncBroker") -> None:
16+
"""
17+
Sets broker to middleware.
18+
19+
:param broker: broker to set.
20+
"""
21+
self.broker = broker
22+
23+
def pre_send(
24+
self,
25+
message: "TaskiqMessage",
26+
labels: Dict[str, Any],
27+
) -> "Union[TaskiqMessage, Coroutine[Any, Any, TaskiqMessage]]":
28+
"""
29+
Hook that executes before sending the task to worker.
30+
31+
This is a client-side hook, that executes right before
32+
the message is sent to broker.
33+
34+
:param message: message to send.
35+
:param labels: task's labels.
36+
:return: modified message.
37+
"""
38+
return message
39+
40+
def post_send(
41+
self,
42+
message: "TaskiqMessage",
43+
labels: Dict[str, Any],
44+
) -> "Union[None, Coroutine[Any, Any, None]]":
45+
"""
46+
This hook is executed right after the task is sent.
47+
48+
This is a client-side hook. It executes right
49+
after the messages is kicked in broker.
50+
51+
:param message: kicked message.
52+
:param labels: labels for a message.
53+
"""
54+
55+
def pre_execute(
56+
self,
57+
message: "TaskiqMessage",
58+
labels: Dict[str, Any],
59+
) -> "Union[TaskiqMessage, Coroutine[Any, Any, TaskiqMessage]]":
60+
"""
61+
This hook is called before executing task.
62+
63+
This is a worker-side hook, wich means it
64+
executes in the worker process.
65+
66+
:param message: incoming parsed taskiq message.
67+
:param labels: task's labels without user-supplied lables.
68+
:return: modified message.
69+
"""
70+
return message
71+
72+
def post_execute(
73+
self,
74+
result: "TaskiqResult[Any]",
75+
labels: Dict[str, Any],
76+
) -> "Union[None, Coroutine[Any, Any, None]]":
77+
"""
78+
This hook executes after task is complete.
79+
80+
This is a worker-side hook. It's called
81+
in worker process.
82+
83+
:param result: result of execution for current task.
84+
:param labels: task's labels. Without user-supplied labels.
85+
"""

taskiq/abc/result_backend.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import Generic, TypeVar
33

44
from taskiq.result import TaskiqResult
5-
from taskiq.task import AsyncTaskiqTask
65

76
_ReturnType = TypeVar("_ReturnType")
87

@@ -16,19 +15,6 @@ async def startup(self) -> None:
1615
async def shutdown(self) -> None:
1716
"""Do something on shutdown."""
1817

19-
def generate_task(self, task_id: str) -> "AsyncTaskiqTask[_ReturnType]":
20-
"""
21-
Generates new task.
22-
23-
This function creates new AsyncTaskiqTask
24-
that returned to client after calling kiq
25-
method.
26-
27-
:param task_id: id of a task to save.
28-
:return: task object.
29-
"""
30-
return AsyncTaskiqTask(task_id=task_id, result_backend=self)
31-
3218
@abstractmethod
3319
async def set_result(self, task_id: str, result: TaskiqResult[_ReturnType]) -> None:
3420
"""

taskiq/brokers/shared_broker.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,9 @@ def kicker(self) -> AsyncKicker[_Params, _ReturnType]:
2121
our shared broker and send task to it, instead
2222
of shared_broker.
2323
24-
:raises TaskiqError: if _default_broker is not set.
2524
:return: new kicker.
2625
"""
27-
broker = getattr(self.broker, "_default_broker", None)
28-
if broker is None:
29-
raise TaskiqError(
30-
"You cannot use kiq directly on shared task "
31-
"without setting the default_broker.",
32-
)
26+
broker = getattr(self.broker, "_default_broker", None) or self.broker
3327
return AsyncKicker(
3428
task_name=self.task_name,
3529
broker=broker,
@@ -45,15 +39,6 @@ def __init__(self) -> None:
4539
self._default_broker: Optional[AsyncBroker] = None
4640
self.decorator_class = SharedDecoratedTask
4741

48-
async def kick(self, message: BrokerMessage) -> None:
49-
"""
50-
Shared broker cannot kick tasks.
51-
52-
:param message: message to send.
53-
:raises TaskiqError: if called.
54-
"""
55-
raise TaskiqError("Shared broker cannot kick tasks.")
56-
5742
def default_broker(self, new_broker: AsyncBroker) -> None:
5843
"""
5944
Updates default broker.
@@ -62,6 +47,18 @@ def default_broker(self, new_broker: AsyncBroker) -> None:
6247
"""
6348
self._default_broker = new_broker
6449

50+
async def kick(self, message: BrokerMessage) -> None:
51+
"""
52+
Shared broker cannot kick tasks.
53+
54+
:param message: message to send.
55+
:raises TaskiqError: if called.
56+
"""
57+
raise TaskiqError(
58+
"You cannot use kiq directly on shared task "
59+
"without setting the default_broker.",
60+
)
61+
6562
async def listen(self) -> AsyncGenerator[BrokerMessage, None]: # type: ignore
6663
"""
6764
Shared broker cannot listen to tasks.

taskiq/cli/async_task_runner.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,13 +276,29 @@ async def async_listen_messages( # noqa: C901, WPS210, WPS213
276276
exc_info=True,
277277
)
278278
continue
279+
for middleware in broker.middlewares:
280+
pre_ex_res = middleware.pre_execute(
281+
taskiq_msg,
282+
broker.available_tasks[message.task_name].labels,
283+
)
284+
if inspect.isawaitable(pre_ex_res):
285+
taskiq_msg = await pre_ex_res
286+
else:
287+
taskiq_msg = pre_ex_res # type: ignore
279288
result = await run_task(
280289
broker.available_tasks[message.task_name].original_func,
281290
task_signatures.get(message.task_name),
282291
taskiq_msg,
283292
cli_args.log_collector_format,
284293
executor,
285294
)
295+
for middleware in broker.middlewares:
296+
post_ex_res = middleware.post_execute(
297+
result,
298+
broker.available_tasks[message.task_name].labels,
299+
)
300+
if inspect.isawaitable(post_ex_res):
301+
await post_ex_res
286302
try:
287303
await broker.result_backend.set_result(message.task_id, result)
288304
except Exception as exc:

taskiq/kicker.py

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
from dataclasses import asdict, is_dataclass
2+
from inspect import isawaitable
23
from logging import getLogger
3-
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Generic, TypeVar, overload
4+
from typing import (
5+
TYPE_CHECKING,
6+
Any,
7+
Coroutine,
8+
Dict,
9+
Generic,
10+
TypeVar,
11+
Union,
12+
overload,
13+
)
414
from uuid import uuid4
515

616
from pydantic import BaseModel
@@ -27,38 +37,27 @@ def __init__(
2737
self,
2838
task_name: str,
2939
broker: "AsyncBroker",
30-
labels: Dict[str, Any],
40+
labels: Dict[
41+
str,
42+
Union[
43+
str,
44+
int,
45+
float,
46+
],
47+
],
3148
) -> None:
3249
self.task_name = task_name
3350
self.broker = broker
3451
self.labels = labels
3552

36-
def with_label(
37-
self,
38-
label_name: str,
39-
value: Any,
40-
) -> "AsyncKicker[_FuncParams, _ReturnType]":
41-
"""
42-
Update one single label.
43-
44-
This method is used to update
45-
task's labels before sending.
46-
47-
:param label_name: name of the label to update.
48-
:param value: label's value.
49-
:return: kicker object with new labels.
50-
"""
51-
self.labels[label_name] = value
52-
return self
53-
5453
def with_labels(
5554
self,
56-
labels: Dict[str, Any],
55+
**labels: Union[str, int, float],
5756
) -> "AsyncKicker[_FuncParams, _ReturnType]":
5857
"""
5958
Update function's labels before sending.
6059
61-
:param labels: dict with new labels.
60+
:param labels: new labels.
6261
:return: kicker with new labels.
6362
"""
6463
self.labels.update(labels)
@@ -96,7 +95,7 @@ async def kiq( # noqa: D102
9695
) -> AsyncTaskiqTask[_ReturnType]:
9796
...
9897

99-
async def kiq(
98+
async def kiq( # noqa: C901
10099
self,
101100
*args: _FuncParams.args,
102101
**kwargs: _FuncParams.kwargs,
@@ -118,11 +117,24 @@ async def kiq(
118117
f"Kicking {self.task_name} with args={args} and kwargs={kwargs}.",
119118
)
120119
message = self._prepare_message(*args, **kwargs)
120+
for middleware in self.broker.middlewares:
121+
pre_send_res = middleware.pre_send(message, self.labels)
122+
if isawaitable(pre_send_res):
123+
message = await pre_send_res
124+
else:
125+
message = pre_send_res # type: ignore
121126
try:
122127
await self.broker.kick(self.broker.formatter.dumps(message, self.labels))
123128
except Exception as exc:
124129
raise SendTaskError() from exc
125-
return self.broker.result_backend.generate_task(message.task_id)
130+
for middleware in self.broker.middlewares:
131+
post_send_res = middleware.post_send(message, self.labels)
132+
if isawaitable(post_send_res):
133+
await post_send_res
134+
return AsyncTaskiqTask(
135+
task_id=message.task_id,
136+
result_backend=self.broker.result_backend,
137+
)
126138

127139
@classmethod
128140
def _prepare_arg(cls, arg: Any) -> Any:

taskiq/plugins/json_formatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Dict
22

3-
from taskiq.abc.plugins.formatter import TaskiqFormatter
3+
from taskiq.abc.formatter import TaskiqFormatter
44
from taskiq.message import BrokerMessage, TaskiqMessage
55

66

0 commit comments

Comments
 (0)