11import asyncio
2+ from datetime import timedelta
23from logging import getLogger
3- from typing import Any , Callable , Coroutine , Optional , TypeVar
4+ from typing import Any , Callable , Coroutine , Dict , Optional , TypeVar
45
5- from aio_pika import ExchangeType , Message , connect_robust
6+ from aio_pika import DeliveryMode , ExchangeType , Message , connect_robust
67from aio_pika .abc import (
78 AbstractChannel ,
89 AbstractIncomingMessage ,
1718logger = getLogger ("taskiq.aio_pika_broker" )
1819
1920
21+ def parse_val (
22+ parse_func : Callable [[str ], _T ],
23+ target : Optional [str ] = None ,
24+ ) -> Optional [_T ]:
25+ """
26+ Parse string to some value.
27+
28+ :param parse_func: function to use if value is present.
29+ :param target: value to parse, defaults to None
30+ :return: Optional value.
31+ """
32+ if target is None :
33+ return None
34+
35+ try :
36+ return parse_func (target )
37+ except ValueError :
38+ return None
39+
40+
2041class AioPikaBroker (AsyncBroker ):
2142 """Broker that works with RabbitMQ."""
2243
@@ -29,9 +50,12 @@ def __init__( # noqa: WPS211
2950 loop : Optional [asyncio .AbstractEventLoop ] = None ,
3051 exchange_name : str = "taskiq" ,
3152 queue_name : str = "taskiq" ,
53+ dead_letter_queue_name : Optional [str ] = None ,
54+ delay_queue_name : Optional [str ] = None ,
3255 declare_exchange : bool = True ,
3356 routing_key : str = "#" ,
3457 exchange_type : ExchangeType = ExchangeType .TOPIC ,
58+ max_priority : Optional [int ] = None ,
3559 ** connection_kwargs : Any ,
3660 ) -> None :
3761 """
@@ -46,11 +70,16 @@ def __init__( # noqa: WPS211
4670 :param loop: specific even loop.
4771 :param exchange_name: name of exchange that used to send messages.
4872 :param queue_name: queue that used to get incoming messages.
73+ :param dead_letter_queue_name: custom name for dead-letter queue.
74+ by default it set to {queue_name}.dead_letter.
75+ :param delay_queue_name: custom name for queue that used to
76+ deliver messages with delays.
4977 :param declare_exchange: whether you want to declare new exchange
5078 if it doesn't exist.
5179 :param routing_key: that used to bind that queue to the exchange.
5280 :param exchange_type: type of the exchange.
5381 Used only if `declare_exchange` is True.
82+ :param max_priority: maximum priority value for messages.
5483 :param connection_kwargs: additional keyword arguments,
5584 for connect_robust method of aio-pika.
5685 """
@@ -65,6 +94,16 @@ def __init__( # noqa: WPS211
6594 self ._declare_exchange = declare_exchange
6695 self ._queue_name = queue_name
6796 self ._routing_key = routing_key
97+ self ._max_priority = max_priority
98+
99+ self ._dead_letter_queue_name = f"{ queue_name } .dead_letter"
100+ if dead_letter_queue_name :
101+ self ._dead_letter_queue_name = dead_letter_queue_name
102+
103+ self ._delay_queue_name = f"{ queue_name } .delay"
104+ if delay_queue_name :
105+ self ._delay_queue_name = delay_queue_name
106+
68107 self .read_conn : Optional [AbstractRobustConnection ] = None
69108 self .write_conn : Optional [AbstractRobustConnection ] = None
70109 self .write_channel : Optional [AbstractChannel ] = None
@@ -97,7 +136,26 @@ async def startup(self) -> None: # noqa: WPS217
97136 self ._exchange_name ,
98137 ensure = False ,
99138 )
100- queue = await self .write_channel .declare_queue (self ._queue_name )
139+ await self .write_channel .declare_queue (
140+ self ._dead_letter_queue_name ,
141+ )
142+ args : "Dict[str, Any]" = {
143+ "x-dead-letter-exchange" : "" ,
144+ "x-dead-letter-routing-key" : self ._dead_letter_queue_name ,
145+ }
146+ if self ._max_priority is not None :
147+ args ["x-max-priority" ] = self ._max_priority
148+ queue = await self .write_channel .declare_queue (
149+ self ._queue_name ,
150+ arguments = args ,
151+ )
152+ await self .write_channel .declare_queue (
153+ self ._delay_queue_name ,
154+ arguments = {
155+ "x-dead-letter-exchange" : "" ,
156+ "x-dead-letter-routing-key" : self ._queue_name ,
157+ },
158+ )
101159 await queue .bind (exchange = exchange , routing_key = self ._routing_key )
102160
103161 async def kick (self , message : BrokerMessage ) -> None :
@@ -111,25 +169,35 @@ async def kick(self, message: BrokerMessage) -> None:
111169 in headers. And message's routing key is the same
112170 as the task_name.
113171
114-
115172 :raises ValueError: if startup wasn't called.
116173 :param message: message to send.
117174 """
118175 if self .write_channel is None :
119176 raise ValueError ("Please run startup before kicking." )
177+ priority = parse_val (int , message .labels .get ("priority" ))
120178 rmq_msg = Message (
121179 body = message .message .encode (),
122180 headers = {
123181 "task_id" : message .task_id ,
124182 "task_name" : message .task_name ,
125183 ** message .labels ,
126184 },
185+ delivery_mode = DeliveryMode .PERSISTENT ,
186+ priority = priority ,
127187 )
128- exchange = await self .write_channel .get_exchange (
129- self ._exchange_name ,
130- ensure = False ,
131- )
132- await exchange .publish (rmq_msg , routing_key = message .task_name )
188+ delay = parse_val (int , message .labels .get ("delay" ))
189+ if delay is None :
190+ exchange = await self .write_channel .get_exchange (
191+ self ._exchange_name ,
192+ ensure = False ,
193+ )
194+ await exchange .publish (rmq_msg , routing_key = message .task_name )
195+ else :
196+ rmq_msg .expiration = timedelta (seconds = delay )
197+ await self .write_channel .default_exchange .publish (
198+ rmq_msg ,
199+ routing_key = self ._delay_queue_name ,
200+ )
133201
134202 async def listen (
135203 self ,
@@ -178,15 +246,18 @@ async def process_message(self, message: AbstractIncomingMessage) -> None:
178246 :param message: received message.
179247 """
180248 async with message .process ():
249+ headers = {}
250+ for header_name , header_value in message .headers .items ():
251+ headers [header_name ] = str (header_value )
181252 try :
182253 broker_message = BrokerMessage (
183- task_id = message . headers .pop ("task_id" ),
184- task_name = message . headers .pop ("task_name" ),
254+ task_id = headers .pop ("task_id" ),
255+ task_name = headers .pop ("task_name" ),
185256 message = message .body ,
186- labels = message . headers ,
257+ labels = headers ,
187258 )
188259 except (ValueError , LookupError ) as exc :
189- logger .debug (
260+ logger .warning (
190261 "Cannot read broker message %s" ,
191262 exc ,
192263 exc_info = True ,
0 commit comments