Skip to content

Commit 433f579

Browse files
committed
Merge branch 'release/0.0.5'
2 parents 1642fcf + 1258021 commit 433f579

File tree

6 files changed

+107
-50
lines changed

6 files changed

+107
-50
lines changed

.flake8

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ ignore =
8181
per-file-ignores =
8282
; all tests
8383
test_*.py,tests.py,tests_*.py,*/tests/*:
84+
; Found magic number
85+
WPS432,
8486
; Use of assert detected
8587
S101,
8688
; Found outer scope names shadowing.

README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,3 @@ AioPikaBroker parameters:
2727
* `routing_key` - that used to bind that queue to the exchange.
2828
* `declare_exchange` - whether you want to declare new exchange if it doesn't exist.
2929
* `qos` - number of messages that worker can prefetch.
30-
* `max_connection_pool_size` - maximum number of connections in pool.
31-
* `max_channel_pool_size` - maximum number of channels for each connection.

poetry.lock

Lines changed: 36 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "taskiq-aio-pika"
3-
version = "0.0.4"
3+
version = "0.0.5"
44
description = "RabbitMQ broker for taskiq"
55
authors = ["Pavel Kirilin <[email protected]>"]
66
readme = "README.md"
@@ -35,6 +35,8 @@ wemake-python-styleguide = "^0.16.1"
3535
pytest-xdist = { version = "^2.5.0", extras = ["psutil"] }
3636
anyio = "^3.6.1"
3737
pytest-cov = "^3.0.0"
38+
mock = "^4.0.3"
39+
types-mock = "^4.0.15"
3840

3941
[tool.mypy]
4042
strict = true

taskiq_aio_pika/broker.py

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1-
from asyncio import AbstractEventLoop
1+
import asyncio
22
from logging import getLogger
3-
from typing import Any, AsyncGenerator, Callable, Optional, TypeVar
3+
from typing import Any, Callable, Coroutine, Optional, TypeVar
44

55
from aio_pika import ExchangeType, Message, connect_robust
6-
from aio_pika.abc import AbstractChannel, AbstractRobustConnection
6+
from aio_pika.abc import (
7+
AbstractChannel,
8+
AbstractIncomingMessage,
9+
AbstractRobustConnection,
10+
)
711
from taskiq.abc.broker import AsyncBroker
812
from taskiq.abc.result_backend import AsyncResultBackend
913
from taskiq.message import BrokerMessage
@@ -22,9 +26,7 @@ def __init__( # noqa: WPS211
2226
result_backend: Optional[AsyncResultBackend[_T]] = None,
2327
task_id_generator: Optional[Callable[[], str]] = None,
2428
qos: int = 10,
25-
loop: Optional[AbstractEventLoop] = None,
26-
max_channel_pool_size: int = 2,
27-
max_connection_pool_size: int = 10,
29+
loop: Optional[asyncio.AbstractEventLoop] = None,
2830
exchange_name: str = "taskiq",
2931
queue_name: str = "taskiq",
3032
declare_exchange: bool = True,
@@ -42,8 +44,6 @@ def __init__( # noqa: WPS211
4244
:param task_id_generator: custom task_id genertaor.
4345
:param qos: number of messages that worker can prefetch.
4446
:param loop: specific even loop.
45-
:param max_channel_pool_size: maximum number of channels for each connection.
46-
:param max_connection_pool_size: maximum number of connections in pool.
4747
:param exchange_name: name of exchange that used to send messages.
4848
:param queue_name: queue that used to get incoming messages.
4949
:param declare_exchange: whether you want to declare new exchange
@@ -112,9 +112,11 @@ async def kick(self, message: BrokerMessage) -> None:
112112
as the task_name.
113113
114114
115-
:raises ValueError: if startup wasn't awaited.
115+
:raises ValueError: if startup wasn't called.
116116
:param message: message to send.
117117
"""
118+
if self.write_channel is None:
119+
raise ValueError("Please run startup before kicking.")
118120
rmq_msg = Message(
119121
body=message.message.encode(),
120122
headers={
@@ -123,44 +125,36 @@ async def kick(self, message: BrokerMessage) -> None:
123125
**message.labels,
124126
},
125127
)
126-
if self.write_channel is None:
127-
raise ValueError("Please run startup before kicking.")
128128
exchange = await self.write_channel.get_exchange(
129129
self._exchange_name,
130130
ensure=False,
131131
)
132132
await exchange.publish(rmq_msg, routing_key=message.task_name)
133133

134-
async def listen(self) -> AsyncGenerator[BrokerMessage, None]:
134+
async def listen(
135+
self,
136+
callback: Callable[[BrokerMessage], Coroutine[Any, Any, None]],
137+
) -> None:
135138
"""
136139
Listen to queue.
137140
138-
This function listens to queue and yields
139-
new messages.
141+
This function listens to queue and calls
142+
callback on every new message.
140143
144+
:param callback: function to call on new message.
141145
:raises ValueError: if startup wasn't called.
142-
:yield: parsed broker messages.
143146
"""
147+
self.callback = callback
144148
if self.read_channel is None:
145149
raise ValueError("Call startup before starting listening.")
146-
await self.read_channel.set_qos(prefetch_count=0)
150+
await self.read_channel.set_qos(prefetch_count=self._qos)
147151
queue = await self.read_channel.get_queue(self._queue_name, ensure=False)
148-
async with queue.iterator() as queue_iter:
149-
async for rmq_message in queue_iter:
150-
async with rmq_message.process():
151-
try:
152-
yield BrokerMessage(
153-
task_id=rmq_message.headers.pop("task_id"),
154-
task_name=rmq_message.headers.pop("task_name"),
155-
message=rmq_message.body,
156-
labels=rmq_message.headers,
157-
)
158-
except (ValueError, LookupError) as exc:
159-
logger.debug(
160-
"Cannot read broker message %s",
161-
exc,
162-
exc_info=True,
163-
)
152+
await queue.consume(self.process_message)
153+
try: # noqa: WPS501
154+
# Wait until terminate
155+
await asyncio.Future()
156+
finally:
157+
await self.shutdown()
164158

165159
async def shutdown(self) -> None:
166160
"""Close all connections on shutdown."""
@@ -173,3 +167,29 @@ async def shutdown(self) -> None:
173167
await self.write_conn.close()
174168
if self.read_conn:
175169
await self.read_conn.close()
170+
171+
async def process_message(self, message: AbstractIncomingMessage) -> None:
172+
"""
173+
Process received message.
174+
175+
This function parses broker message and
176+
calls callback.
177+
178+
:param message: received message.
179+
"""
180+
async with message.process():
181+
try:
182+
broker_message = BrokerMessage(
183+
task_id=message.headers.pop("task_id"),
184+
task_name=message.headers.pop("task_name"),
185+
message=message.body,
186+
labels=message.headers,
187+
)
188+
except (ValueError, LookupError) as exc:
189+
logger.debug(
190+
"Cannot read broker message %s",
191+
exc,
192+
exc_info=True,
193+
)
194+
return
195+
await self.callback(broker_message)

taskiq_aio_pika/tests/test_broker.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from aio_pika import Channel, Message
66
from aio_pika.abc import AbstractExchange, AbstractQueue
77
from aio_pika.exceptions import QueueEmpty
8+
from mock import AsyncMock
89
from taskiq import BrokerMessage
910

1011
from taskiq_aio_pika.broker import AioPikaBroker
@@ -34,9 +35,11 @@ async def test_kick_success(broker: AioPikaBroker) -> None:
3435

3536
await broker.kick(sent)
3637

37-
async for inc_msg in broker.listen():
38-
message = inc_msg
39-
break
38+
callback = AsyncMock()
39+
40+
with pytest.raises(asyncio.TimeoutError):
41+
await asyncio.wait_for(broker.listen(callback), timeout=0.4)
42+
message = callback.call_args_list[0].args[0]
4043

4144
assert message == sent
4245

@@ -97,9 +100,11 @@ async def test_listen(broker: AioPikaBroker, exchange: AbstractExchange) -> None
97100
routing_key="task_name",
98101
)
99102

100-
async for inc_message in broker.listen():
101-
message = inc_message
102-
break
103+
callback = AsyncMock()
104+
105+
with pytest.raises(asyncio.TimeoutError):
106+
await asyncio.wait_for(broker.listen(callback), timeout=0.4)
107+
message = callback.call_args_list[0].args[0]
103108

104109
assert message.message == "test_message"
105110
assert message.labels == {"label1": "label_val"}
@@ -124,11 +129,12 @@ async def test_wrong_format(
124129
Message(b"wrong"),
125130
routing_key=queue.name,
126131
)
132+
callback = AsyncMock()
127133

128134
with pytest.raises(asyncio.TimeoutError):
129135
await asyncio.wait_for(
130-
broker.listen().__anext__(),
131-
0.2, # noqa: WPS432
136+
broker.listen(callback=callback),
137+
0.4,
132138
)
133139

134140
with pytest.raises(QueueEmpty):

0 commit comments

Comments
 (0)