1
1
from asyncio import AbstractEventLoop
2
2
from logging import getLogger
3
- from typing import Any , AsyncGenerator , Optional , TypeVar
3
+ from typing import Any , AsyncGenerator , Callable , Optional , TypeVar
4
4
5
5
from aio_pika import Channel , ExchangeType , Message , connect_robust
6
6
from aio_pika .abc import AbstractChannel , AbstractRobustConnection
@@ -18,6 +18,7 @@ class AioPikaBroker(AsyncBroker):
18
18
def __init__ (
19
19
self ,
20
20
result_backend : Optional [AsyncResultBackend [_T ]] = None ,
21
+ task_id_generator : Optional [Callable [[], str ]] = None ,
21
22
qos : int = 10 ,
22
23
loop : Optional [AbstractEventLoop ] = None ,
23
24
max_channel_pool_size : int = 2 ,
@@ -28,7 +29,7 @@ def __init__(
28
29
* connection_args : Any ,
29
30
** connection_kwargs : Any ,
30
31
) -> None :
31
- super ().__init__ (result_backend )
32
+ super ().__init__ (result_backend , task_id_generator )
32
33
33
34
async def _get_rmq_connection () -> AbstractRobustConnection :
34
35
return await connect_robust (* connection_args , ** connection_kwargs )
@@ -72,7 +73,7 @@ async def kick(self, message: BrokerMessage) -> None:
72
73
headers = {
73
74
"task_id" : message .task_id ,
74
75
"task_name" : message .task_name ,
75
- ** message .headers ,
76
+ ** message .labels ,
76
77
},
77
78
)
78
79
async with self .channel_pool .acquire () as channel :
@@ -88,10 +89,10 @@ async def listen(self) -> AsyncGenerator[BrokerMessage, None]:
88
89
async with rmq_message .process ():
89
90
try :
90
91
yield BrokerMessage (
91
- task_id = rmq_message .headers [ "task_id" ] ,
92
- task_name = rmq_message .headers [ "task_name" ] ,
92
+ task_id = rmq_message .headers . pop ( "task_id" ) ,
93
+ task_name = rmq_message .headers . pop ( "task_name" ) ,
93
94
message = rmq_message .body ,
94
- headers = rmq_message .headers ,
95
+ labels = rmq_message .headers ,
95
96
)
96
97
except (ValueError , LookupError ) as exc :
97
98
logger .debug (
@@ -101,4 +102,5 @@ async def listen(self) -> AsyncGenerator[BrokerMessage, None]:
101
102
)
102
103
103
104
async def shutdown (self ) -> None :
105
+ await super ().shutdown ()
104
106
await self .connection_pool .close ()
0 commit comments