1- from asyncio import AbstractEventLoop
1+ import asyncio
22from logging import getLogger
3- from typing import Any , AsyncGenerator , Callable , Optional , TypeVar
3+ from typing import Any , Callable , Coroutine , Optional , TypeVar
44
55from aio_pika import ExchangeType , Message , connect_robust
6- from aio_pika .abc import AbstractChannel , AbstractRobustConnection
6+ from aio_pika .abc import (
7+ AbstractChannel ,
8+ AbstractIncomingMessage ,
9+ AbstractRobustConnection ,
10+ )
711from taskiq .abc .broker import AsyncBroker
812from taskiq .abc .result_backend import AsyncResultBackend
913from taskiq .message import BrokerMessage
@@ -22,9 +26,7 @@ def __init__( # noqa: WPS211
2226 result_backend : Optional [AsyncResultBackend [_T ]] = None ,
2327 task_id_generator : Optional [Callable [[], str ]] = None ,
2428 qos : int = 10 ,
25- loop : Optional [AbstractEventLoop ] = None ,
26- max_channel_pool_size : int = 2 ,
27- max_connection_pool_size : int = 10 ,
29+ loop : Optional [asyncio .AbstractEventLoop ] = None ,
2830 exchange_name : str = "taskiq" ,
2931 queue_name : str = "taskiq" ,
3032 declare_exchange : bool = True ,
@@ -42,8 +44,6 @@ def __init__( # noqa: WPS211
4244 :param task_id_generator: custom task_id genertaor.
4345 :param qos: number of messages that worker can prefetch.
4446 :param loop: specific even loop.
45- :param max_channel_pool_size: maximum number of channels for each connection.
46- :param max_connection_pool_size: maximum number of connections in pool.
4747 :param exchange_name: name of exchange that used to send messages.
4848 :param queue_name: queue that used to get incoming messages.
4949 :param declare_exchange: whether you want to declare new exchange
@@ -112,9 +112,11 @@ async def kick(self, message: BrokerMessage) -> None:
112112 as the task_name.
113113
114114
115- :raises ValueError: if startup wasn't awaited .
115+ :raises ValueError: if startup wasn't called .
116116 :param message: message to send.
117117 """
118+ if self .write_channel is None :
119+ raise ValueError ("Please run startup before kicking." )
118120 rmq_msg = Message (
119121 body = message .message .encode (),
120122 headers = {
@@ -123,44 +125,36 @@ async def kick(self, message: BrokerMessage) -> None:
123125 ** message .labels ,
124126 },
125127 )
126- if self .write_channel is None :
127- raise ValueError ("Please run startup before kicking." )
128128 exchange = await self .write_channel .get_exchange (
129129 self ._exchange_name ,
130130 ensure = False ,
131131 )
132132 await exchange .publish (rmq_msg , routing_key = message .task_name )
133133
134- async def listen (self ) -> AsyncGenerator [BrokerMessage , None ]:
134+ async def listen (
135+ self ,
136+ callback : Callable [[BrokerMessage ], Coroutine [Any , Any , None ]],
137+ ) -> None :
135138 """
136139 Listen to queue.
137140
138- This function listens to queue and yields
139- new messages .
141+ This function listens to queue and calls
142+ callback on every new message .
140143
144+ :param callback: function to call on new message.
141145 :raises ValueError: if startup wasn't called.
142- :yield: parsed broker messages.
143146 """
147+ self .callback = callback
144148 if self .read_channel is None :
145149 raise ValueError ("Call startup before starting listening." )
146- await self .read_channel .set_qos (prefetch_count = 0 )
150+ await self .read_channel .set_qos (prefetch_count = self . _qos )
147151 queue = await self .read_channel .get_queue (self ._queue_name , ensure = False )
148- async with queue .iterator () as queue_iter :
149- async for rmq_message in queue_iter :
150- async with rmq_message .process ():
151- try :
152- yield BrokerMessage (
153- task_id = rmq_message .headers .pop ("task_id" ),
154- task_name = rmq_message .headers .pop ("task_name" ),
155- message = rmq_message .body ,
156- labels = rmq_message .headers ,
157- )
158- except (ValueError , LookupError ) as exc :
159- logger .debug (
160- "Cannot read broker message %s" ,
161- exc ,
162- exc_info = True ,
163- )
152+ await queue .consume (self .process_message )
153+ try : # noqa: WPS501
154+ # Wait until terminate
155+ await asyncio .Future ()
156+ finally :
157+ await self .shutdown ()
164158
165159 async def shutdown (self ) -> None :
166160 """Close all connections on shutdown."""
@@ -173,3 +167,29 @@ async def shutdown(self) -> None:
173167 await self .write_conn .close ()
174168 if self .read_conn :
175169 await self .read_conn .close ()
170+
171+ async def process_message (self , message : AbstractIncomingMessage ) -> None :
172+ """
173+ Process received message.
174+
175+ This function parses broker message and
176+ calls callback.
177+
178+ :param message: received message.
179+ """
180+ async with message .process ():
181+ try :
182+ broker_message = BrokerMessage (
183+ task_id = message .headers .pop ("task_id" ),
184+ task_name = message .headers .pop ("task_name" ),
185+ message = message .body ,
186+ labels = message .headers ,
187+ )
188+ except (ValueError , LookupError ) as exc :
189+ logger .debug (
190+ "Cannot read broker message %s" ,
191+ exc ,
192+ exc_info = True ,
193+ )
194+ return
195+ await self .callback (broker_message )
0 commit comments