1
1
import asyncio
2
+ from datetime import timedelta
2
3
from logging import getLogger
3
- from typing import Any , Callable , Coroutine , Optional , TypeVar
4
+ from typing import Any , Callable , Coroutine , Dict , Optional , TypeVar
4
5
5
- from aio_pika import ExchangeType , Message , connect_robust
6
+ from aio_pika import DeliveryMode , ExchangeType , Message , connect_robust
6
7
from aio_pika .abc import (
7
8
AbstractChannel ,
8
9
AbstractIncomingMessage ,
17
18
logger = getLogger ("taskiq.aio_pika_broker" )
18
19
19
20
21
+ def parse_val (
22
+ parse_func : Callable [[str ], _T ],
23
+ target : Optional [str ] = None ,
24
+ ) -> Optional [_T ]:
25
+ """
26
+ Parse string to some value.
27
+
28
+ :param parse_func: function to use if value is present.
29
+ :param target: value to parse, defaults to None
30
+ :return: Optional value.
31
+ """
32
+ if target is None :
33
+ return None
34
+
35
+ try :
36
+ return parse_func (target )
37
+ except ValueError :
38
+ return None
39
+
40
+
20
41
class AioPikaBroker (AsyncBroker ):
21
42
"""Broker that works with RabbitMQ."""
22
43
@@ -29,9 +50,12 @@ def __init__( # noqa: WPS211
29
50
loop : Optional [asyncio .AbstractEventLoop ] = None ,
30
51
exchange_name : str = "taskiq" ,
31
52
queue_name : str = "taskiq" ,
53
+ dead_letter_queue_name : Optional [str ] = None ,
54
+ delay_queue_name : Optional [str ] = None ,
32
55
declare_exchange : bool = True ,
33
56
routing_key : str = "#" ,
34
57
exchange_type : ExchangeType = ExchangeType .TOPIC ,
58
+ max_priority : Optional [int ] = None ,
35
59
** connection_kwargs : Any ,
36
60
) -> None :
37
61
"""
@@ -46,11 +70,16 @@ def __init__( # noqa: WPS211
46
70
:param loop: specific even loop.
47
71
:param exchange_name: name of exchange that used to send messages.
48
72
:param queue_name: queue that used to get incoming messages.
73
+ :param dead_letter_queue_name: custom name for dead-letter queue.
74
+ by default it set to {queue_name}.dead_letter.
75
+ :param delay_queue_name: custom name for queue that used to
76
+ deliver messages with delays.
49
77
:param declare_exchange: whether you want to declare new exchange
50
78
if it doesn't exist.
51
79
:param routing_key: that used to bind that queue to the exchange.
52
80
:param exchange_type: type of the exchange.
53
81
Used only if `declare_exchange` is True.
82
+ :param max_priority: maximum priority value for messages.
54
83
:param connection_kwargs: additional keyword arguments,
55
84
for connect_robust method of aio-pika.
56
85
"""
@@ -65,6 +94,16 @@ def __init__( # noqa: WPS211
65
94
self ._declare_exchange = declare_exchange
66
95
self ._queue_name = queue_name
67
96
self ._routing_key = routing_key
97
+ self ._max_priority = max_priority
98
+
99
+ self ._dead_letter_queue_name = f"{ queue_name } .dead_letter"
100
+ if dead_letter_queue_name :
101
+ self ._dead_letter_queue_name = dead_letter_queue_name
102
+
103
+ self ._delay_queue_name = f"{ queue_name } .delay"
104
+ if delay_queue_name :
105
+ self ._delay_queue_name = delay_queue_name
106
+
68
107
self .read_conn : Optional [AbstractRobustConnection ] = None
69
108
self .write_conn : Optional [AbstractRobustConnection ] = None
70
109
self .write_channel : Optional [AbstractChannel ] = None
@@ -97,7 +136,26 @@ async def startup(self) -> None: # noqa: WPS217
97
136
self ._exchange_name ,
98
137
ensure = False ,
99
138
)
100
- queue = await self .write_channel .declare_queue (self ._queue_name )
139
+ await self .write_channel .declare_queue (
140
+ self ._dead_letter_queue_name ,
141
+ )
142
+ args : "Dict[str, Any]" = {
143
+ "x-dead-letter-exchange" : "" ,
144
+ "x-dead-letter-routing-key" : self ._dead_letter_queue_name ,
145
+ }
146
+ if self ._max_priority is not None :
147
+ args ["x-max-priority" ] = self ._max_priority
148
+ queue = await self .write_channel .declare_queue (
149
+ self ._queue_name ,
150
+ arguments = args ,
151
+ )
152
+ await self .write_channel .declare_queue (
153
+ self ._delay_queue_name ,
154
+ arguments = {
155
+ "x-dead-letter-exchange" : "" ,
156
+ "x-dead-letter-routing-key" : self ._queue_name ,
157
+ },
158
+ )
101
159
await queue .bind (exchange = exchange , routing_key = self ._routing_key )
102
160
103
161
async def kick (self , message : BrokerMessage ) -> None :
@@ -111,25 +169,35 @@ async def kick(self, message: BrokerMessage) -> None:
111
169
in headers. And message's routing key is the same
112
170
as the task_name.
113
171
114
-
115
172
:raises ValueError: if startup wasn't called.
116
173
:param message: message to send.
117
174
"""
118
175
if self .write_channel is None :
119
176
raise ValueError ("Please run startup before kicking." )
177
+ priority = parse_val (int , message .labels .get ("priority" ))
120
178
rmq_msg = Message (
121
179
body = message .message .encode (),
122
180
headers = {
123
181
"task_id" : message .task_id ,
124
182
"task_name" : message .task_name ,
125
183
** message .labels ,
126
184
},
185
+ delivery_mode = DeliveryMode .PERSISTENT ,
186
+ priority = priority ,
127
187
)
128
- exchange = await self .write_channel .get_exchange (
129
- self ._exchange_name ,
130
- ensure = False ,
131
- )
132
- await exchange .publish (rmq_msg , routing_key = message .task_name )
188
+ delay = parse_val (int , message .labels .get ("delay" ))
189
+ if delay is None :
190
+ exchange = await self .write_channel .get_exchange (
191
+ self ._exchange_name ,
192
+ ensure = False ,
193
+ )
194
+ await exchange .publish (rmq_msg , routing_key = message .task_name )
195
+ else :
196
+ rmq_msg .expiration = timedelta (seconds = delay )
197
+ await self .write_channel .default_exchange .publish (
198
+ rmq_msg ,
199
+ routing_key = self ._delay_queue_name ,
200
+ )
133
201
134
202
async def listen (
135
203
self ,
@@ -178,15 +246,18 @@ async def process_message(self, message: AbstractIncomingMessage) -> None:
178
246
:param message: received message.
179
247
"""
180
248
async with message .process ():
249
+ headers = {}
250
+ for header_name , header_value in message .headers .items ():
251
+ headers [header_name ] = str (header_value )
181
252
try :
182
253
broker_message = BrokerMessage (
183
- task_id = message . headers .pop ("task_id" ),
184
- task_name = message . headers .pop ("task_name" ),
254
+ task_id = headers .pop ("task_id" ),
255
+ task_name = headers .pop ("task_name" ),
185
256
message = message .body ,
186
- labels = message . headers ,
257
+ labels = headers ,
187
258
)
188
259
except (ValueError , LookupError ) as exc :
189
- logger .debug (
260
+ logger .warning (
190
261
"Cannot read broker message %s" ,
191
262
exc ,
192
263
exc_info = True ,
0 commit comments