2
2
from logging import getLogger
3
3
from typing import Any , AsyncGenerator , Callable , Optional , TypeVar
4
4
5
- from aio_pika import Channel , ExchangeType , Message , connect_robust
5
+ from aio_pika import ExchangeType , Message , connect_robust
6
6
from aio_pika .abc import AbstractChannel , AbstractRobustConnection
7
- from aio_pika .pool import Pool
8
7
from taskiq .abc .broker import AsyncBroker
9
8
from taskiq .abc .result_backend import AsyncResultBackend
10
9
from taskiq .message import BrokerMessage
@@ -57,48 +56,49 @@ def __init__( # noqa: WPS211
57
56
"""
58
57
super ().__init__ (result_backend , task_id_generator )
59
58
60
- async def _get_rmq_connection () -> AbstractRobustConnection :
61
- return await connect_robust (
62
- url ,
63
- loop = loop ,
64
- ** connection_kwargs ,
65
- )
66
-
67
- self ._connection_pool : Pool [AbstractRobustConnection ] = Pool (
68
- _get_rmq_connection ,
69
- max_size = max_connection_pool_size ,
70
- loop = loop ,
71
- )
72
-
73
- async def get_channel () -> AbstractChannel :
74
- async with self ._connection_pool .acquire () as connection :
75
- return await connection .channel ()
76
-
77
- self ._channel_pool : Pool [Channel ] = Pool (
78
- get_channel ,
79
- max_size = max_channel_pool_size ,
80
- loop = loop ,
81
- )
82
-
59
+ self .url = url
60
+ self ._loop = loop
61
+ self ._conn_kwargs = connection_kwargs
83
62
self ._exchange_name = exchange_name
84
63
self ._exchange_type = exchange_type
85
64
self ._qos = qos
86
65
self ._declare_exchange = declare_exchange
87
66
self ._queue_name = queue_name
88
67
self ._routing_key = routing_key
68
+ self .read_conn : Optional [AbstractRobustConnection ] = None
69
+ self .write_conn : Optional [AbstractRobustConnection ] = None
70
+ self .write_channel : Optional [AbstractChannel ] = None
71
+ self .read_channel : Optional [AbstractChannel ] = None
89
72
90
- async def startup (self ) -> None :
73
+ async def startup (self ) -> None : # noqa: WPS217
91
74
"""Create exchange and queue on startup."""
92
- async with self ._channel_pool .acquire () as channel :
93
- if self ._declare_exchange :
94
- exchange = await channel .declare_exchange (
95
- self ._exchange_name ,
96
- type = self ._exchange_type ,
97
- )
98
- else :
99
- exchange = await channel .get_exchange (self ._exchange_name , ensure = False )
100
- queue = await channel .declare_queue (self ._queue_name )
101
- await queue .bind (exchange = exchange , routing_key = self ._routing_key )
75
+ self .write_conn = await connect_robust (
76
+ self .url ,
77
+ loop = self ._loop ,
78
+ ** self ._conn_kwargs ,
79
+ )
80
+ self .write_channel = await self .write_conn .channel ()
81
+
82
+ if self .is_worker_process :
83
+ self .read_conn = await connect_robust (
84
+ self .url ,
85
+ loop = self ._loop ,
86
+ ** self ._conn_kwargs ,
87
+ )
88
+ self .read_channel = await self .read_conn .channel ()
89
+
90
+ if self ._declare_exchange :
91
+ exchange = await self .write_channel .declare_exchange (
92
+ self ._exchange_name ,
93
+ type = self ._exchange_type ,
94
+ )
95
+ else :
96
+ exchange = await self .write_channel .get_exchange (
97
+ self ._exchange_name ,
98
+ ensure = False ,
99
+ )
100
+ queue = await self .write_channel .declare_queue (self ._queue_name )
101
+ await queue .bind (exchange = exchange , routing_key = self ._routing_key )
102
102
103
103
async def kick (self , message : BrokerMessage ) -> None :
104
104
"""
@@ -111,6 +111,8 @@ async def kick(self, message: BrokerMessage) -> None:
111
111
in headers. And message's routing key is the same
112
112
as the task_name.
113
113
114
+
115
+ :raises ValueError: if startup wasn't awaited.
114
116
:param message: message to send.
115
117
"""
116
118
rmq_msg = Message (
@@ -121,9 +123,13 @@ async def kick(self, message: BrokerMessage) -> None:
121
123
** message .labels ,
122
124
},
123
125
)
124
- async with self ._channel_pool .acquire () as channel :
125
- exchange = await channel .get_exchange (self ._exchange_name , ensure = False )
126
- await exchange .publish (rmq_msg , routing_key = message .task_name )
126
+ if self .write_channel is None :
127
+ raise ValueError ("Please run startup before kicking." )
128
+ exchange = await self .write_channel .get_exchange (
129
+ self ._exchange_name ,
130
+ ensure = False ,
131
+ )
132
+ await exchange .publish (rmq_msg , routing_key = message .task_name )
127
133
128
134
async def listen (self ) -> AsyncGenerator [BrokerMessage , None ]:
129
135
"""
@@ -132,29 +138,38 @@ async def listen(self) -> AsyncGenerator[BrokerMessage, None]:
132
138
This function listens to queue and yields
133
139
new messages.
134
140
141
+ :raises ValueError: if startup wasn't called.
135
142
:yield: parsed broker messages.
136
143
"""
137
- async with self ._channel_pool .acquire () as channel :
138
- await channel .set_qos (prefetch_count = self ._qos )
139
- queue = await channel .get_queue (self ._queue_name , ensure = False )
140
- async with queue .iterator () as queue_iter :
141
- async for rmq_message in queue_iter :
142
- async with rmq_message .process ():
143
- try :
144
- yield BrokerMessage (
145
- task_id = rmq_message .headers .pop ("task_id" ),
146
- task_name = rmq_message .headers .pop ("task_name" ),
147
- message = rmq_message .body ,
148
- labels = rmq_message .headers ,
149
- )
150
- except (ValueError , LookupError ) as exc :
151
- logger .debug (
152
- "Cannot read broker message %s" ,
153
- exc ,
154
- exc_info = True ,
155
- )
144
+ if self .read_channel is None :
145
+ raise ValueError ("Call startup before starting listening." )
146
+ await self .read_channel .set_qos (prefetch_count = 0 )
147
+ queue = await self .read_channel .get_queue (self ._queue_name , ensure = False )
148
+ async with queue .iterator () as queue_iter :
149
+ async for rmq_message in queue_iter :
150
+ async with rmq_message .process ():
151
+ try :
152
+ yield BrokerMessage (
153
+ task_id = rmq_message .headers .pop ("task_id" ),
154
+ task_name = rmq_message .headers .pop ("task_name" ),
155
+ message = rmq_message .body ,
156
+ labels = rmq_message .headers ,
157
+ )
158
+ except (ValueError , LookupError ) as exc :
159
+ logger .debug (
160
+ "Cannot read broker message %s" ,
161
+ exc ,
162
+ exc_info = True ,
163
+ )
156
164
157
165
async def shutdown (self ) -> None :
158
166
"""Close all connections on shutdown."""
159
167
await super ().shutdown ()
160
- await self ._connection_pool .close ()
168
+ if self .write_channel :
169
+ await self .write_channel .close ()
170
+ if self .read_channel :
171
+ await self .read_channel .close ()
172
+ if self .write_conn :
173
+ await self .write_conn .close ()
174
+ if self .read_conn :
175
+ await self .read_conn .close ()
0 commit comments