Skip to content

Commit 4192945

Browse files
committed
Updated broker's interface.
Signed-off-by: Pavel Kirilin <[email protected]>
1 parent ab0aa8a commit 4192945

File tree

3 files changed

+42
-66
lines changed

3 files changed

+42
-66
lines changed

taskiq_aio_pika/broker.py

Lines changed: 28 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
11
import asyncio
22
from datetime import timedelta
33
from logging import getLogger
4-
from typing import Any, Callable, Coroutine, Dict, Optional, TypeVar
4+
from typing import Any, AsyncGenerator, Callable, Dict, Optional, TypeVar
55

66
from aio_pika import DeliveryMode, ExchangeType, Message, connect_robust
7-
from aio_pika.abc import (
8-
AbstractChannel,
9-
AbstractIncomingMessage,
10-
AbstractQueue,
11-
AbstractRobustConnection,
12-
)
7+
from aio_pika.abc import AbstractChannel, AbstractQueue, AbstractRobustConnection
138
from taskiq.abc.broker import AsyncBroker
149
from taskiq.abc.result_backend import AsyncResultBackend
1510
from taskiq.message import BrokerMessage
@@ -219,30 +214,42 @@ async def kick(self, message: BrokerMessage) -> None:
219214
routing_key=self._delay_queue_name,
220215
)
221216

222-
async def listen(
223-
self,
224-
callback: Callable[[BrokerMessage], Coroutine[Any, Any, None]],
225-
) -> None:
217+
async def listen(self) -> AsyncGenerator[BrokerMessage, None]: # noqa: WPS210
226218
"""
227219
Listen to queue.
228220
229-
This function listens to queue and calls
230-
callback on every new message.
221+
This function listens to queue and
222+
yields every new message.
231223
232-
:param callback: function to call on new message.
224+
:yields: parsed broker message.
233225
:raises ValueError: if startup wasn't called.
234226
"""
235-
self.callback = callback
236227
if self.read_channel is None:
237228
raise ValueError("Call startup before starting listening.")
238229
await self.read_channel.set_qos(prefetch_count=self._qos)
239230
queue = await self.declare_queues(self.read_channel)
240-
await queue.consume(self.process_message)
241-
try: # noqa: WPS501
242-
# Wait until terminate
243-
await asyncio.Future()
244-
finally:
245-
await self.shutdown()
231+
async with queue.iterator() as iterator:
232+
async for message in iterator:
233+
async with message.process():
234+
headers = {}
235+
for header_name, header_value in message.headers.items():
236+
headers[header_name] = str(header_value)
237+
try:
238+
broker_message = BrokerMessage(
239+
task_id=headers.pop("task_id"),
240+
task_name=headers.pop("task_name"),
241+
message=message.body,
242+
labels=headers,
243+
)
244+
except (ValueError, LookupError) as exc:
245+
logger.warning(
246+
"Cannot read broker message %s",
247+
exc,
248+
exc_info=True,
249+
)
250+
continue
251+
252+
yield broker_message
246253

247254
async def shutdown(self) -> None:
248255
"""Close all connections on shutdown."""
@@ -255,32 +262,3 @@ async def shutdown(self) -> None:
255262
await self.write_conn.close()
256263
if self.read_conn:
257264
await self.read_conn.close()
258-
259-
async def process_message(self, message: AbstractIncomingMessage) -> None:
260-
"""
261-
Process received message.
262-
263-
This function parses broker message and
264-
calls callback.
265-
266-
:param message: received message.
267-
"""
268-
async with message.process():
269-
headers = {}
270-
for header_name, header_value in message.headers.items():
271-
headers[header_name] = str(header_value)
272-
try:
273-
broker_message = BrokerMessage(
274-
task_id=headers.pop("task_id"),
275-
task_name=headers.pop("task_name"),
276-
message=message.body,
277-
labels=headers,
278-
)
279-
except (ValueError, LookupError) as exc:
280-
logger.warning(
281-
"Cannot read broker message %s",
282-
exc,
283-
exc_info=True,
284-
)
285-
return
286-
await self.callback(broker_message)
File renamed without changes.

taskiq_aio_pika/tests/test_broker.py renamed to tests/test_broker.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,22 @@
44
import pytest
55
from aio_pika import Channel, Message
66
from aio_pika.exceptions import QueueEmpty
7-
from mock import AsyncMock
87
from taskiq import BrokerMessage
98

109
from taskiq_aio_pika.broker import AioPikaBroker
1110

1211

12+
async def get_first_task(broker: AioPikaBroker) -> BrokerMessage: # type: ignore
13+
"""
14+
Get first message from the queue.
15+
16+
:param broker: async message broker.
17+
:return: first message from listen method
18+
"""
19+
async for message in broker.listen():
20+
return message
21+
22+
1323
@pytest.mark.anyio
1424
async def test_kick_success(broker: AioPikaBroker) -> None:
1525
"""
@@ -34,11 +44,7 @@ async def test_kick_success(broker: AioPikaBroker) -> None:
3444

3545
await broker.kick(sent)
3646

37-
callback = AsyncMock()
38-
39-
with pytest.raises(asyncio.TimeoutError):
40-
await asyncio.wait_for(broker.listen(callback), timeout=0.4)
41-
message = callback.call_args_list[0].args[0]
47+
message = await asyncio.wait_for(get_first_task(broker), timeout=0.4)
4248

4349
assert message == sent
4450

@@ -103,11 +109,7 @@ async def test_listen(
103109
routing_key="task_name",
104110
)
105111

106-
callback = AsyncMock()
107-
108-
with pytest.raises(asyncio.TimeoutError):
109-
await asyncio.wait_for(broker.listen(callback), timeout=0.4)
110-
message = callback.call_args_list[0].args[0]
112+
message = await asyncio.wait_for(get_first_task(broker), timeout=0.4)
111113

112114
assert message.message == "test_message"
113115
assert message.labels == {"label1": "label_val"}
@@ -133,13 +135,9 @@ async def test_wrong_format(
133135
Message(b"wrong"),
134136
routing_key=queue_name,
135137
)
136-
callback = AsyncMock()
137138

138139
with pytest.raises(asyncio.TimeoutError):
139-
await asyncio.wait_for(
140-
broker.listen(callback=callback),
141-
0.4,
142-
)
140+
await asyncio.wait_for(get_first_task(broker), 0.4)
143141

144142
with pytest.raises(QueueEmpty):
145143
await queue.get()

0 commit comments

Comments
 (0)