Skip to content

Commit 49e9a59

Browse files
committed
Updated broker, so it uses different channels for reading and writing.
Signed-off-by: Pavel Kirilin <[email protected]>
1 parent 9c03ef6 commit 49e9a59

File tree

1 file changed

+74
-59
lines changed

1 file changed

+74
-59
lines changed

taskiq_aio_pika/broker.py

Lines changed: 74 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
from logging import getLogger
33
from typing import Any, AsyncGenerator, Callable, Optional, TypeVar
44

5-
from aio_pika import Channel, ExchangeType, Message, connect_robust
5+
from aio_pika import ExchangeType, Message, connect_robust
66
from aio_pika.abc import AbstractChannel, AbstractRobustConnection
7-
from aio_pika.pool import Pool
87
from taskiq.abc.broker import AsyncBroker
98
from taskiq.abc.result_backend import AsyncResultBackend
109
from taskiq.message import BrokerMessage
@@ -57,48 +56,49 @@ def __init__( # noqa: WPS211
5756
"""
5857
super().__init__(result_backend, task_id_generator)
5958

60-
async def _get_rmq_connection() -> AbstractRobustConnection:
61-
return await connect_robust(
62-
url,
63-
loop=loop,
64-
**connection_kwargs,
65-
)
66-
67-
self._connection_pool: Pool[AbstractRobustConnection] = Pool(
68-
_get_rmq_connection,
69-
max_size=max_connection_pool_size,
70-
loop=loop,
71-
)
72-
73-
async def get_channel() -> AbstractChannel:
74-
async with self._connection_pool.acquire() as connection:
75-
return await connection.channel()
76-
77-
self._channel_pool: Pool[Channel] = Pool(
78-
get_channel,
79-
max_size=max_channel_pool_size,
80-
loop=loop,
81-
)
82-
59+
self.url = url
60+
self._loop = loop
61+
self._conn_kwargs = connection_kwargs
8362
self._exchange_name = exchange_name
8463
self._exchange_type = exchange_type
8564
self._qos = qos
8665
self._declare_exchange = declare_exchange
8766
self._queue_name = queue_name
8867
self._routing_key = routing_key
68+
self.read_conn: Optional[AbstractRobustConnection] = None
69+
self.write_conn: Optional[AbstractRobustConnection] = None
70+
self.write_channel: Optional[AbstractChannel] = None
71+
self.read_channel: Optional[AbstractChannel] = None
8972

90-
async def startup(self) -> None:
73+
async def startup(self) -> None: # noqa: WPS217
9174
"""Create exchange and queue on startup."""
92-
async with self._channel_pool.acquire() as channel:
93-
if self._declare_exchange:
94-
exchange = await channel.declare_exchange(
95-
self._exchange_name,
96-
type=self._exchange_type,
97-
)
98-
else:
99-
exchange = await channel.get_exchange(self._exchange_name, ensure=False)
100-
queue = await channel.declare_queue(self._queue_name)
101-
await queue.bind(exchange=exchange, routing_key=self._routing_key)
75+
self.write_conn = await connect_robust(
76+
self.url,
77+
loop=self._loop,
78+
**self._conn_kwargs,
79+
)
80+
self.write_channel = await self.write_conn.channel()
81+
82+
if self.is_worker_process:
83+
self.read_conn = await connect_robust(
84+
self.url,
85+
loop=self._loop,
86+
**self._conn_kwargs,
87+
)
88+
self.read_channel = await self.read_conn.channel()
89+
90+
if self._declare_exchange:
91+
exchange = await self.write_channel.declare_exchange(
92+
self._exchange_name,
93+
type=self._exchange_type,
94+
)
95+
else:
96+
exchange = await self.write_channel.get_exchange(
97+
self._exchange_name,
98+
ensure=False,
99+
)
100+
queue = await self.write_channel.declare_queue(self._queue_name)
101+
await queue.bind(exchange=exchange, routing_key=self._routing_key)
102102

103103
async def kick(self, message: BrokerMessage) -> None:
104104
"""
@@ -111,6 +111,8 @@ async def kick(self, message: BrokerMessage) -> None:
111111
in headers. And message's routing key is the same
112112
as the task_name.
113113
114+
115+
:raises ValueError: if startup wasn't awaited.
114116
:param message: message to send.
115117
"""
116118
rmq_msg = Message(
@@ -121,9 +123,13 @@ async def kick(self, message: BrokerMessage) -> None:
121123
**message.labels,
122124
},
123125
)
124-
async with self._channel_pool.acquire() as channel:
125-
exchange = await channel.get_exchange(self._exchange_name, ensure=False)
126-
await exchange.publish(rmq_msg, routing_key=message.task_name)
126+
if self.write_channel is None:
127+
raise ValueError("Please run startup before kicking.")
128+
exchange = await self.write_channel.get_exchange(
129+
self._exchange_name,
130+
ensure=False,
131+
)
132+
await exchange.publish(rmq_msg, routing_key=message.task_name)
127133

128134
async def listen(self) -> AsyncGenerator[BrokerMessage, None]:
129135
"""
@@ -132,29 +138,38 @@ async def listen(self) -> AsyncGenerator[BrokerMessage, None]:
132138
This function listens to queue and yields
133139
new messages.
134140
141+
:raises ValueError: if startup wasn't called.
135142
:yield: parsed broker messages.
136143
"""
137-
async with self._channel_pool.acquire() as channel:
138-
await channel.set_qos(prefetch_count=self._qos)
139-
queue = await channel.get_queue(self._queue_name, ensure=False)
140-
async with queue.iterator() as queue_iter:
141-
async for rmq_message in queue_iter:
142-
async with rmq_message.process():
143-
try:
144-
yield BrokerMessage(
145-
task_id=rmq_message.headers.pop("task_id"),
146-
task_name=rmq_message.headers.pop("task_name"),
147-
message=rmq_message.body,
148-
labels=rmq_message.headers,
149-
)
150-
except (ValueError, LookupError) as exc:
151-
logger.debug(
152-
"Cannot read broker message %s",
153-
exc,
154-
exc_info=True,
155-
)
144+
if self.read_channel is None:
145+
raise ValueError("Call startup before starting listening.")
146+
await self.read_channel.set_qos(prefetch_count=0)
147+
queue = await self.read_channel.get_queue(self._queue_name, ensure=False)
148+
async with queue.iterator() as queue_iter:
149+
async for rmq_message in queue_iter:
150+
async with rmq_message.process():
151+
try:
152+
yield BrokerMessage(
153+
task_id=rmq_message.headers.pop("task_id"),
154+
task_name=rmq_message.headers.pop("task_name"),
155+
message=rmq_message.body,
156+
labels=rmq_message.headers,
157+
)
158+
except (ValueError, LookupError) as exc:
159+
logger.debug(
160+
"Cannot read broker message %s",
161+
exc,
162+
exc_info=True,
163+
)
156164

157165
async def shutdown(self) -> None:
158166
"""Close all connections on shutdown."""
159167
await super().shutdown()
160-
await self._connection_pool.close()
168+
if self.write_channel:
169+
await self.write_channel.close()
170+
if self.read_channel:
171+
await self.read_channel.close()
172+
if self.write_conn:
173+
await self.write_conn.close()
174+
if self.read_conn:
175+
await self.read_conn.close()

0 commit comments

Comments
 (0)