1
1
import pickle
2
+ from abc import abstractmethod
2
3
from logging import getLogger
3
4
from typing import Any , AsyncGenerator , Callable , Optional , TypeVar
4
5
12
13
logger = getLogger ("taskiq.redis_broker" )
13
14
14
15
15
- class RedisBroker (AsyncBroker ):
16
- """Broker that works with Redis."""
16
+ class BaseRedisBroker (AsyncBroker ):
17
+ """Base broker that works with Redis."""
17
18
18
19
def __init__ (
19
20
self ,
@@ -44,31 +45,12 @@ def __init__(
44
45
max_connections = max_connection_pool_size ,
45
46
** connection_kwargs ,
46
47
)
47
-
48
- self .redis_pubsub_channel = queue_name
48
+ self .queue_name = queue_name
49
49
50
50
async def shutdown (self ) -> None :
51
51
"""Closes redis connection pool."""
52
52
await self .connection_pool .disconnect ()
53
53
54
- async def kick (self , message : BrokerMessage ) -> None :
55
- """
56
- Sends a message to the redis broker list.
57
-
58
- This function constructs message for redis
59
- and sends it.
60
-
61
- The message is pickled dict object with message,
62
- task_id, task_name and labels.
63
-
64
- :param message: message to send.
65
- """
66
- async with Redis (connection_pool = self .connection_pool ) as redis_conn :
67
- await redis_conn .publish (
68
- self .redis_pubsub_channel ,
69
- pickle .dumps (message ),
70
- )
71
-
72
54
async def listen (self ) -> AsyncGenerator [BrokerMessage , None ]:
73
55
"""
74
56
Listen redis queue for new messages.
@@ -78,24 +60,60 @@ async def listen(self) -> AsyncGenerator[BrokerMessage, None]:
78
60
79
61
:yields: broker messages.
80
62
"""
63
+ async for message in self ._listen_to_raw_messages ():
64
+ try :
65
+ redis_message = pickle .loads (message )
66
+ if isinstance (redis_message , BrokerMessage ):
67
+ yield redis_message
68
+ except (
69
+ TypeError ,
70
+ AttributeError ,
71
+ pickle .UnpicklingError ,
72
+ ) as exc :
73
+ logger .debug (
74
+ "Cannot read broker message %s" ,
75
+ exc ,
76
+ exc_info = True ,
77
+ )
78
+
79
+ @abstractmethod
80
+ async def _listen_to_raw_messages (self ) -> AsyncGenerator [bytes , None ]:
81
+ """
82
+ Generator for reading raw data from Redis.
83
+
84
+ :yields: raw data.
85
+ """
86
+ yield # type: ignore
87
+
88
+
89
+ class PubSubBroker (BaseRedisBroker ):
90
+ """Broker that works with Redis and broadcasts tasks to all workers."""
91
+
92
+ async def kick (self , message : BrokerMessage ) -> None : # noqa: D102
93
+ async with Redis (connection_pool = self .connection_pool ) as redis_conn :
94
+ await redis_conn .publish (self .queue_name , pickle .dumps (message ))
95
+
96
+ async def _listen_to_raw_messages (self ) -> AsyncGenerator [bytes , None ]:
81
97
async with Redis (connection_pool = self .connection_pool ) as redis_conn :
82
98
redis_pubsub_channel = redis_conn .pubsub ()
83
- await redis_pubsub_channel .subscribe (self .redis_pubsub_channel )
99
+ await redis_pubsub_channel .subscribe (self .queue_name )
84
100
async for message in redis_pubsub_channel .listen ():
85
- if message :
86
- try :
87
- redis_message = pickle .loads (
88
- message ["data" ],
89
- )
90
- if isinstance (redis_message , BrokerMessage ):
91
- yield redis_message
92
- except (
93
- TypeError ,
94
- AttributeError ,
95
- pickle .UnpicklingError ,
96
- ) as exc :
97
- logger .debug (
98
- "Cannot read broker message %s" ,
99
- exc ,
100
- exc_info = True ,
101
- )
101
+ if not message :
102
+ continue
103
+ yield message ["data" ]
104
+
105
+
106
+ class ListQueueBroker (BaseRedisBroker ):
107
+ """Broker that works with Redis and distributes tasks between workers."""
108
+
109
+ async def kick (self , message : BrokerMessage ) -> None : # noqa: D102
110
+ async with Redis (connection_pool = self .connection_pool ) as redis_conn :
111
+ await redis_conn .lpush (self .queue_name , pickle .dumps (message ))
112
+
113
+ async def _listen_to_raw_messages (self ) -> AsyncGenerator [bytes , None ]:
114
+ redis_brpop_data_position = 1
115
+ async with Redis (connection_pool = self .connection_pool ) as redis_conn :
116
+ while True : # noqa: WPS457
117
+ yield (await redis_conn .brpop (self .queue_name ))[
118
+ redis_brpop_data_position
119
+ ]
0 commit comments