diff --git a/README.md b/README.md index 99e4bdb..19e9847 100644 --- a/README.md +++ b/README.md @@ -157,4 +157,3 @@ broker = AioPikaBroker( ``` This will ensure that the queue is created with your custom arguments, in addition to the broker's defaults. - diff --git a/taskiq_aio_pika/broker.py b/taskiq_aio_pika/broker.py index a3f3939..e4254fb 100644 --- a/taskiq_aio_pika/broker.py +++ b/taskiq_aio_pika/broker.py @@ -266,7 +266,14 @@ async def kick(self, message: BrokerMessage) -> None: self._exchange_name, ensure=False, ) - await exchange.publish(rmq_message, routing_key=message.task_name) + + routing_key = message.task_name + + # Because direct exchange uses exact routing key for routing + if self._exchange_type == ExchangeType.DIRECT: + routing_key = self._routing_key + + await exchange.publish(rmq_message, routing_key=routing_key) elif self._delayed_message_exchange_plugin: rmq_message.headers["x-delay"] = int(delay * 1000) exchange = await self.write_channel.get_exchange( diff --git a/tests/test_broker.py b/tests/test_broker.py index e1c7119..1c462ac 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -2,7 +2,7 @@ import uuid import pytest -from aio_pika import Channel, Message +from aio_pika import Channel, ExchangeType, Message from aio_pika.exceptions import QueueEmpty from taskiq import AckableMessage, BrokerMessage from taskiq.utils import maybe_awaitable @@ -10,15 +10,16 @@ from taskiq_aio_pika.broker import AioPikaBroker -async def get_first_task(broker: AioPikaBroker) -> AckableMessage: # type: ignore +async def get_first_task(broker: AioPikaBroker) -> AckableMessage: """ Get first message from the queue. :param broker: async message broker. :return: first message from listen method """ - async for message in broker.listen(): # noqa: RET503 + async for message in broker.listen(): return message + return None # type: ignore @pytest.mark.anyio @@ -219,3 +220,50 @@ async def test_delayed_message_with_plugin( await asyncio.sleep(2) assert await main_queue.get() + + +@pytest.mark.anyio +async def test_direct_kick( + broker: AioPikaBroker, + test_channel: Channel, + queue_name: str, + exchange_name: str, +) -> None: + """ + Test that messages are published and read correctly. + + We kick the message and then try to listen to the queue, + and check that message we got is the same as we sent. + """ + queue = await test_channel.get_queue(queue_name) + exchange = await test_channel.get_exchange(exchange_name) + await queue.delete() + await exchange.delete() + + broker._declare_exchange = True + broker._exchange_type = ExchangeType.DIRECT + broker._routing_key = "direct_routing_key" + + await broker.startup() + + await test_channel.get_queue(queue_name, ensure=True) + await test_channel.get_exchange(exchange_name, ensure=True) + + task_id = uuid.uuid4().hex + task_name = uuid.uuid4().hex + + sent = BrokerMessage( + task_id=task_id, + task_name=task_name, + message=b"my_msg", + labels={ + "label1": "val1", + }, + ) + + await broker.kick(sent) + + message = await asyncio.wait_for(get_first_task(broker), timeout=0.4) + + assert message.data == sent.message + await maybe_awaitable(message.ack())