1
- from asyncio import AbstractEventLoop
1
+ import asyncio
2
2
from logging import getLogger
3
- from typing import Any , AsyncGenerator , Callable , Optional , TypeVar
3
+ from typing import Any , Callable , Coroutine , Optional , TypeVar
4
4
5
5
from aio_pika import ExchangeType , Message , connect_robust
6
- from aio_pika .abc import AbstractChannel , AbstractRobustConnection
6
+ from aio_pika .abc import (
7
+ AbstractChannel ,
8
+ AbstractIncomingMessage ,
9
+ AbstractRobustConnection ,
10
+ )
7
11
from taskiq .abc .broker import AsyncBroker
8
12
from taskiq .abc .result_backend import AsyncResultBackend
9
13
from taskiq .message import BrokerMessage
@@ -22,9 +26,7 @@ def __init__( # noqa: WPS211
22
26
result_backend : Optional [AsyncResultBackend [_T ]] = None ,
23
27
task_id_generator : Optional [Callable [[], str ]] = None ,
24
28
qos : int = 10 ,
25
- loop : Optional [AbstractEventLoop ] = None ,
26
- max_channel_pool_size : int = 2 ,
27
- max_connection_pool_size : int = 10 ,
29
+ loop : Optional [asyncio .AbstractEventLoop ] = None ,
28
30
exchange_name : str = "taskiq" ,
29
31
queue_name : str = "taskiq" ,
30
32
declare_exchange : bool = True ,
@@ -42,8 +44,6 @@ def __init__( # noqa: WPS211
42
44
:param task_id_generator: custom task_id genertaor.
43
45
:param qos: number of messages that worker can prefetch.
44
46
:param loop: specific even loop.
45
- :param max_channel_pool_size: maximum number of channels for each connection.
46
- :param max_connection_pool_size: maximum number of connections in pool.
47
47
:param exchange_name: name of exchange that used to send messages.
48
48
:param queue_name: queue that used to get incoming messages.
49
49
:param declare_exchange: whether you want to declare new exchange
@@ -112,9 +112,11 @@ async def kick(self, message: BrokerMessage) -> None:
112
112
as the task_name.
113
113
114
114
115
- :raises ValueError: if startup wasn't awaited .
115
+ :raises ValueError: if startup wasn't called .
116
116
:param message: message to send.
117
117
"""
118
+ if self .write_channel is None :
119
+ raise ValueError ("Please run startup before kicking." )
118
120
rmq_msg = Message (
119
121
body = message .message .encode (),
120
122
headers = {
@@ -123,44 +125,36 @@ async def kick(self, message: BrokerMessage) -> None:
123
125
** message .labels ,
124
126
},
125
127
)
126
- if self .write_channel is None :
127
- raise ValueError ("Please run startup before kicking." )
128
128
exchange = await self .write_channel .get_exchange (
129
129
self ._exchange_name ,
130
130
ensure = False ,
131
131
)
132
132
await exchange .publish (rmq_msg , routing_key = message .task_name )
133
133
134
- async def listen (self ) -> AsyncGenerator [BrokerMessage , None ]:
134
+ async def listen (
135
+ self ,
136
+ callback : Callable [[BrokerMessage ], Coroutine [Any , Any , None ]],
137
+ ) -> None :
135
138
"""
136
139
Listen to queue.
137
140
138
- This function listens to queue and yields
139
- new messages .
141
+ This function listens to queue and calls
142
+ callback on every new message .
140
143
144
+ :param callback: function to call on new message.
141
145
:raises ValueError: if startup wasn't called.
142
- :yield: parsed broker messages.
143
146
"""
147
+ self .callback = callback
144
148
if self .read_channel is None :
145
149
raise ValueError ("Call startup before starting listening." )
146
- await self .read_channel .set_qos (prefetch_count = 0 )
150
+ await self .read_channel .set_qos (prefetch_count = self . _qos )
147
151
queue = await self .read_channel .get_queue (self ._queue_name , ensure = False )
148
- async with queue .iterator () as queue_iter :
149
- async for rmq_message in queue_iter :
150
- async with rmq_message .process ():
151
- try :
152
- yield BrokerMessage (
153
- task_id = rmq_message .headers .pop ("task_id" ),
154
- task_name = rmq_message .headers .pop ("task_name" ),
155
- message = rmq_message .body ,
156
- labels = rmq_message .headers ,
157
- )
158
- except (ValueError , LookupError ) as exc :
159
- logger .debug (
160
- "Cannot read broker message %s" ,
161
- exc ,
162
- exc_info = True ,
163
- )
152
+ await queue .consume (self .process_message )
153
+ try : # noqa: WPS501
154
+ # Wait until terminate
155
+ await asyncio .Future ()
156
+ finally :
157
+ await self .shutdown ()
164
158
165
159
async def shutdown (self ) -> None :
166
160
"""Close all connections on shutdown."""
@@ -173,3 +167,29 @@ async def shutdown(self) -> None:
173
167
await self .write_conn .close ()
174
168
if self .read_conn :
175
169
await self .read_conn .close ()
170
+
171
+ async def process_message (self , message : AbstractIncomingMessage ) -> None :
172
+ """
173
+ Process received message.
174
+
175
+ This function parses broker message and
176
+ calls callback.
177
+
178
+ :param message: received message.
179
+ """
180
+ async with message .process ():
181
+ try :
182
+ broker_message = BrokerMessage (
183
+ task_id = message .headers .pop ("task_id" ),
184
+ task_name = message .headers .pop ("task_name" ),
185
+ message = message .body ,
186
+ labels = message .headers ,
187
+ )
188
+ except (ValueError , LookupError ) as exc :
189
+ logger .debug (
190
+ "Cannot read broker message %s" ,
191
+ exc ,
192
+ exc_info = True ,
193
+ )
194
+ return
195
+ await self .callback (broker_message )
0 commit comments