1
- from typing import Any , AsyncGenerator , Dict , Optional , TypeVar
1
+ from asyncio import AbstractEventLoop
2
+ from typing import Any , AsyncGenerator , Optional , TypeVar
3
+
4
+ from aio_pika import Channel , ExchangeType , Message , connect_robust
5
+ from aio_pika .abc import AbstractChannel , AbstractRobustConnection
6
+ from aio_pika .pool import Pool
2
7
from taskiq .abc .broker import AsyncBroker
3
8
from taskiq .abc .result_backend import AsyncResultBackend
4
- from taskiq .message import TaskiqMessage
5
- from aio_pika .abc import AbstractRobustConnection
6
- from aio_pika .pool import Pool
7
- from asyncio import AbstractEventLoop
8
- from aio_pika import connect_robust , Channel , Message , ExchangeType
9
+ from taskiq .message import BrokerMessage
9
10
10
11
_T = TypeVar ("_T" )
11
12
@@ -35,7 +36,7 @@ async def _get_rmq_connection() -> AbstractRobustConnection:
35
36
loop = loop ,
36
37
)
37
38
38
- async def get_channel () -> Channel :
39
+ async def get_channel () -> AbstractChannel :
39
40
async with self .connection_pool .acquire () as connection :
40
41
return await connection .channel ()
41
42
@@ -62,30 +63,30 @@ async def startup(self) -> None:
62
63
queue = await channel .declare_queue (self .queue_name )
63
64
await queue .bind (exchange = exchange , routing_key = "*" )
64
65
65
- async def kick (self , message : TaskiqMessage ) -> None :
66
+ async def kick (self , message : BrokerMessage ) -> None :
66
67
rmq_msg = Message (
67
- body = message .json ().encode (),
68
- content_type = "application/json" ,
68
+ body = message .message .encode (),
69
69
headers = {
70
70
"task_id" : message .task_id ,
71
71
"task_name" : message .task_name ,
72
+ ** message .headers ,
72
73
},
73
74
)
74
75
async with self .channel_pool .acquire () as channel :
75
76
exchange = await channel .get_exchange (self .exchange_name , ensure = False )
76
77
await exchange .publish (rmq_msg , routing_key = message .task_id )
77
78
78
- async def listen (self ) -> AsyncGenerator [TaskiqMessage , None ]:
79
+ async def listen (self ) -> AsyncGenerator [BrokerMessage , None ]:
79
80
async with self .channel_pool .acquire () as channel :
80
81
await channel .set_qos (prefetch_count = self .qos )
81
82
queue = await channel .get_queue (self .queue_name , ensure = False )
82
83
async with queue .iterator () as queue_iter :
83
84
async for rmq_message in queue_iter :
84
85
async with rmq_message .process ():
85
86
try :
86
- yield TaskiqMessage .parse_raw (
87
+ yield BrokerMessage .parse_raw (
87
88
rmq_message .body ,
88
- content_type = rmq_message .content_type ,
89
+ content_type = rmq_message .content_type or "" ,
89
90
)
90
91
except ValueError :
91
92
continue
0 commit comments