22from logging import getLogger
33from typing import Any , AsyncGenerator , Callable , Optional , TypeVar
44
5- from aio_pika import Channel , ExchangeType , Message , connect_robust
5+ from aio_pika import ExchangeType , Message , connect_robust
66from aio_pika .abc import AbstractChannel , AbstractRobustConnection
7- from aio_pika .pool import Pool
87from taskiq .abc .broker import AsyncBroker
98from taskiq .abc .result_backend import AsyncResultBackend
109from taskiq .message import BrokerMessage
@@ -57,48 +56,49 @@ def __init__( # noqa: WPS211
5756 """
5857 super ().__init__ (result_backend , task_id_generator )
5958
60- async def _get_rmq_connection () -> AbstractRobustConnection :
61- return await connect_robust (
62- url ,
63- loop = loop ,
64- ** connection_kwargs ,
65- )
66-
67- self ._connection_pool : Pool [AbstractRobustConnection ] = Pool (
68- _get_rmq_connection ,
69- max_size = max_connection_pool_size ,
70- loop = loop ,
71- )
72-
73- async def get_channel () -> AbstractChannel :
74- async with self ._connection_pool .acquire () as connection :
75- return await connection .channel ()
76-
77- self ._channel_pool : Pool [Channel ] = Pool (
78- get_channel ,
79- max_size = max_channel_pool_size ,
80- loop = loop ,
81- )
82-
59+ self .url = url
60+ self ._loop = loop
61+ self ._conn_kwargs = connection_kwargs
8362 self ._exchange_name = exchange_name
8463 self ._exchange_type = exchange_type
8564 self ._qos = qos
8665 self ._declare_exchange = declare_exchange
8766 self ._queue_name = queue_name
8867 self ._routing_key = routing_key
68+ self .read_conn : Optional [AbstractRobustConnection ] = None
69+ self .write_conn : Optional [AbstractRobustConnection ] = None
70+ self .write_channel : Optional [AbstractChannel ] = None
71+ self .read_channel : Optional [AbstractChannel ] = None
8972
90- async def startup (self ) -> None :
73+ async def startup (self ) -> None : # noqa: WPS217
9174 """Create exchange and queue on startup."""
92- async with self ._channel_pool .acquire () as channel :
93- if self ._declare_exchange :
94- exchange = await channel .declare_exchange (
95- self ._exchange_name ,
96- type = self ._exchange_type ,
97- )
98- else :
99- exchange = await channel .get_exchange (self ._exchange_name , ensure = False )
100- queue = await channel .declare_queue (self ._queue_name )
101- await queue .bind (exchange = exchange , routing_key = self ._routing_key )
75+ self .write_conn = await connect_robust (
76+ self .url ,
77+ loop = self ._loop ,
78+ ** self ._conn_kwargs ,
79+ )
80+ self .write_channel = await self .write_conn .channel ()
81+
82+ if self .is_worker_process :
83+ self .read_conn = await connect_robust (
84+ self .url ,
85+ loop = self ._loop ,
86+ ** self ._conn_kwargs ,
87+ )
88+ self .read_channel = await self .read_conn .channel ()
89+
90+ if self ._declare_exchange :
91+ exchange = await self .write_channel .declare_exchange (
92+ self ._exchange_name ,
93+ type = self ._exchange_type ,
94+ )
95+ else :
96+ exchange = await self .write_channel .get_exchange (
97+ self ._exchange_name ,
98+ ensure = False ,
99+ )
100+ queue = await self .write_channel .declare_queue (self ._queue_name )
101+ await queue .bind (exchange = exchange , routing_key = self ._routing_key )
102102
103103 async def kick (self , message : BrokerMessage ) -> None :
104104 """
@@ -111,6 +111,8 @@ async def kick(self, message: BrokerMessage) -> None:
111111 in headers. And message's routing key is the same
112112 as the task_name.
113113
114+
115+ :raises ValueError: if startup wasn't awaited.
114116 :param message: message to send.
115117 """
116118 rmq_msg = Message (
@@ -121,9 +123,13 @@ async def kick(self, message: BrokerMessage) -> None:
121123 ** message .labels ,
122124 },
123125 )
124- async with self ._channel_pool .acquire () as channel :
125- exchange = await channel .get_exchange (self ._exchange_name , ensure = False )
126- await exchange .publish (rmq_msg , routing_key = message .task_name )
126+ if self .write_channel is None :
127+ raise ValueError ("Please run startup before kicking." )
128+ exchange = await self .write_channel .get_exchange (
129+ self ._exchange_name ,
130+ ensure = False ,
131+ )
132+ await exchange .publish (rmq_msg , routing_key = message .task_name )
127133
128134 async def listen (self ) -> AsyncGenerator [BrokerMessage , None ]:
129135 """
@@ -132,29 +138,38 @@ async def listen(self) -> AsyncGenerator[BrokerMessage, None]:
132138 This function listens to queue and yields
133139 new messages.
134140
141+ :raises ValueError: if startup wasn't called.
135142 :yield: parsed broker messages.
136143 """
137- async with self ._channel_pool .acquire () as channel :
138- await channel .set_qos (prefetch_count = self ._qos )
139- queue = await channel .get_queue (self ._queue_name , ensure = False )
140- async with queue .iterator () as queue_iter :
141- async for rmq_message in queue_iter :
142- async with rmq_message .process ():
143- try :
144- yield BrokerMessage (
145- task_id = rmq_message .headers .pop ("task_id" ),
146- task_name = rmq_message .headers .pop ("task_name" ),
147- message = rmq_message .body ,
148- labels = rmq_message .headers ,
149- )
150- except (ValueError , LookupError ) as exc :
151- logger .debug (
152- "Cannot read broker message %s" ,
153- exc ,
154- exc_info = True ,
155- )
144+ if self .read_channel is None :
145+ raise ValueError ("Call startup before starting listening." )
146+ await self .read_channel .set_qos (prefetch_count = 0 )
147+ queue = await self .read_channel .get_queue (self ._queue_name , ensure = False )
148+ async with queue .iterator () as queue_iter :
149+ async for rmq_message in queue_iter :
150+ async with rmq_message .process ():
151+ try :
152+ yield BrokerMessage (
153+ task_id = rmq_message .headers .pop ("task_id" ),
154+ task_name = rmq_message .headers .pop ("task_name" ),
155+ message = rmq_message .body ,
156+ labels = rmq_message .headers ,
157+ )
158+ except (ValueError , LookupError ) as exc :
159+ logger .debug (
160+ "Cannot read broker message %s" ,
161+ exc ,
162+ exc_info = True ,
163+ )
156164
157165 async def shutdown (self ) -> None :
158166 """Close all connections on shutdown."""
159167 await super ().shutdown ()
160- await self ._connection_pool .close ()
168+ if self .write_channel :
169+ await self .write_channel .close ()
170+ if self .read_channel :
171+ await self .read_channel .close ()
172+ if self .write_conn :
173+ await self .write_conn .close ()
174+ if self .read_conn :
175+ await self .read_conn .close ()
0 commit comments