Skip to content

Commit 7535c6c

Browse files
committed
Updated broker.
Signed-off-by: Pavel Kirilin <[email protected]>
1 parent a7171e2 commit 7535c6c

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

poetry.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

taskiq_aio_pika/taskiq/brokers/aio_pika_broker.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from asyncio import AbstractEventLoop
22
from logging import getLogger
3-
from typing import Any, AsyncGenerator, Optional, TypeVar
3+
from typing import Any, AsyncGenerator, Callable, Optional, TypeVar
44

55
from aio_pika import Channel, ExchangeType, Message, connect_robust
66
from aio_pika.abc import AbstractChannel, AbstractRobustConnection
@@ -18,6 +18,7 @@ class AioPikaBroker(AsyncBroker):
1818
def __init__(
1919
self,
2020
result_backend: Optional[AsyncResultBackend[_T]] = None,
21+
task_id_generator: Optional[Callable[[], str]] = None,
2122
qos: int = 10,
2223
loop: Optional[AbstractEventLoop] = None,
2324
max_channel_pool_size: int = 2,
@@ -28,7 +29,7 @@ def __init__(
2829
*connection_args: Any,
2930
**connection_kwargs: Any,
3031
) -> None:
31-
super().__init__(result_backend)
32+
super().__init__(result_backend, task_id_generator)
3233

3334
async def _get_rmq_connection() -> AbstractRobustConnection:
3435
return await connect_robust(*connection_args, **connection_kwargs)
@@ -72,7 +73,7 @@ async def kick(self, message: BrokerMessage) -> None:
7273
headers={
7374
"task_id": message.task_id,
7475
"task_name": message.task_name,
75-
**message.headers,
76+
**message.labels,
7677
},
7778
)
7879
async with self.channel_pool.acquire() as channel:
@@ -88,10 +89,10 @@ async def listen(self) -> AsyncGenerator[BrokerMessage, None]:
8889
async with rmq_message.process():
8990
try:
9091
yield BrokerMessage(
91-
task_id=rmq_message.headers["task_id"],
92-
task_name=rmq_message.headers["task_name"],
92+
task_id=rmq_message.headers.pop("task_id"),
93+
task_name=rmq_message.headers.pop("task_name"),
9394
message=rmq_message.body,
94-
headers=rmq_message.headers,
95+
labels=rmq_message.headers,
9596
)
9697
except (ValueError, LookupError) as exc:
9798
logger.debug(
@@ -101,4 +102,5 @@ async def listen(self) -> AsyncGenerator[BrokerMessage, None]:
101102
)
102103

103104
async def shutdown(self) -> None:
105+
await super().shutdown()
104106
await self.connection_pool.close()

0 commit comments

Comments
 (0)