1
1
import asyncio
2
2
from datetime import timedelta
3
3
from logging import getLogger
4
- from typing import Any , Callable , Coroutine , Dict , Optional , TypeVar
4
+ from typing import Any , AsyncGenerator , Callable , Dict , Optional , TypeVar
5
5
6
6
from aio_pika import DeliveryMode , ExchangeType , Message , connect_robust
7
- from aio_pika .abc import (
8
- AbstractChannel ,
9
- AbstractIncomingMessage ,
10
- AbstractQueue ,
11
- AbstractRobustConnection ,
12
- )
7
+ from aio_pika .abc import AbstractChannel , AbstractQueue , AbstractRobustConnection
13
8
from taskiq .abc .broker import AsyncBroker
14
9
from taskiq .abc .result_backend import AsyncResultBackend
15
10
from taskiq .message import BrokerMessage
@@ -219,30 +214,42 @@ async def kick(self, message: BrokerMessage) -> None:
219
214
routing_key = self ._delay_queue_name ,
220
215
)
221
216
222
- async def listen (
223
- self ,
224
- callback : Callable [[BrokerMessage ], Coroutine [Any , Any , None ]],
225
- ) -> None :
217
+ async def listen (self ) -> AsyncGenerator [BrokerMessage , None ]: # noqa: WPS210
226
218
"""
227
219
Listen to queue.
228
220
229
- This function listens to queue and calls
230
- callback on every new message.
221
+ This function listens to queue and
222
+ yields every new message.
231
223
232
- :param callback: function to call on new message.
224
+ :yields: parsed broker message.
233
225
:raises ValueError: if startup wasn't called.
234
226
"""
235
- self .callback = callback
236
227
if self .read_channel is None :
237
228
raise ValueError ("Call startup before starting listening." )
238
229
await self .read_channel .set_qos (prefetch_count = self ._qos )
239
230
queue = await self .declare_queues (self .read_channel )
240
- await queue .consume (self .process_message )
241
- try : # noqa: WPS501
242
- # Wait until terminate
243
- await asyncio .Future ()
244
- finally :
245
- await self .shutdown ()
231
+ async with queue .iterator () as iterator :
232
+ async for message in iterator :
233
+ async with message .process ():
234
+ headers = {}
235
+ for header_name , header_value in message .headers .items ():
236
+ headers [header_name ] = str (header_value )
237
+ try :
238
+ broker_message = BrokerMessage (
239
+ task_id = headers .pop ("task_id" ),
240
+ task_name = headers .pop ("task_name" ),
241
+ message = message .body ,
242
+ labels = headers ,
243
+ )
244
+ except (ValueError , LookupError ) as exc :
245
+ logger .warning (
246
+ "Cannot read broker message %s" ,
247
+ exc ,
248
+ exc_info = True ,
249
+ )
250
+ continue
251
+
252
+ yield broker_message
246
253
247
254
async def shutdown (self ) -> None :
248
255
"""Close all connections on shutdown."""
@@ -255,32 +262,3 @@ async def shutdown(self) -> None:
255
262
await self .write_conn .close ()
256
263
if self .read_conn :
257
264
await self .read_conn .close ()
258
-
259
- async def process_message (self , message : AbstractIncomingMessage ) -> None :
260
- """
261
- Process received message.
262
-
263
- This function parses broker message and
264
- calls callback.
265
-
266
- :param message: received message.
267
- """
268
- async with message .process ():
269
- headers = {}
270
- for header_name , header_value in message .headers .items ():
271
- headers [header_name ] = str (header_value )
272
- try :
273
- broker_message = BrokerMessage (
274
- task_id = headers .pop ("task_id" ),
275
- task_name = headers .pop ("task_name" ),
276
- message = message .body ,
277
- labels = headers ,
278
- )
279
- except (ValueError , LookupError ) as exc :
280
- logger .warning (
281
- "Cannot read broker message %s" ,
282
- exc ,
283
- exc_info = True ,
284
- )
285
- return
286
- await self .callback (broker_message )
0 commit comments