Skip to content

Commit bc45417

Browse files
pinguingmanzazymking
authored andcommitted
implement redis cluster broker
1 parent 7ba249f commit bc45417

File tree

3 files changed

+112
-1
lines changed

3 files changed

+112
-1
lines changed

taskiq_redis/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
RedisAsyncResultBackend,
55
)
66
from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker
7+
from taskiq_redis.redis_cluster_broker import ListQueueClusterBroker
78
from taskiq_redis.schedule_source import RedisScheduleSource
89

910
__all__ = [
1011
"RedisAsyncClusterResultBackend",
1112
"RedisAsyncResultBackend",
1213
"ListQueueBroker",
1314
"PubSubBroker",
15+
"ListQueueClusterBroker",
1416
"RedisScheduleSource",
1517
]

taskiq_redis/redis_cluster_broker.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from typing import Any, AsyncGenerator, Callable, Optional, TypeVar
2+
3+
from redis.asyncio import RedisCluster
4+
from taskiq.abc.broker import AsyncBroker
5+
from taskiq.abc.result_backend import AsyncResultBackend
6+
from taskiq.message import BrokerMessage
7+
8+
_T = TypeVar("_T")
9+
10+
11+
class BaseRedisClusterBroker(AsyncBroker):
12+
"""Base broker that works with Redis Cluster."""
13+
14+
def __init__(
15+
self,
16+
url: str,
17+
task_id_generator: Optional[Callable[[], str]] = None,
18+
result_backend: Optional[AsyncResultBackend[_T]] = None,
19+
queue_name: str = "taskiq",
20+
max_connection_pool_size: int = 2**31,
21+
**connection_kwargs: Any,
22+
) -> None:
23+
"""
24+
Constructs a new broker.
25+
26+
:param url: url to redis.
27+
:param task_id_generator: custom task_id generator.
28+
:param result_backend: custom result backend.
29+
:param queue_name: name for a list in redis.
30+
:param max_connection_pool_size: maximum number of connections in pool.
31+
:param connection_kwargs: additional arguments for aio-redis ConnectionPool.
32+
"""
33+
super().__init__(
34+
result_backend=result_backend,
35+
task_id_generator=task_id_generator,
36+
)
37+
38+
self.redis: RedisCluster[bytes] = RedisCluster.from_url(
39+
url=url,
40+
max_connections=max_connection_pool_size,
41+
**connection_kwargs,
42+
)
43+
44+
self.queue_name = queue_name
45+
46+
async def shutdown(self) -> None:
47+
"""Closes redis connection pool."""
48+
await self.redis.aclose() # type: ignore[attr-defined]
49+
await super().shutdown()
50+
51+
52+
class ListQueueClusterBroker(BaseRedisClusterBroker):
53+
"""Broker that works with Redis Cluster and distributes tasks between workers."""
54+
55+
async def kick(self, message: BrokerMessage) -> None:
56+
"""
57+
Put a message in a list.
58+
59+
This method appends a message to the list of all messages.
60+
61+
:param message: message to append.
62+
"""
63+
await self.redis.lpush(self.queue_name, message.message)
64+
65+
async def listen(self) -> AsyncGenerator[bytes, None]:
66+
"""
67+
Listen redis queue for new messages.
68+
69+
This function listens to the queue
70+
and yields new messages if they have BrokerMessage type.
71+
72+
:yields: broker messages.
73+
"""
74+
redis_brpop_data_position = 1
75+
while True:
76+
yield (await self.redis.brpop([self.queue_name]))[
77+
redis_brpop_data_position
78+
]

tests/test_broker.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
from taskiq import AckableMessage, AsyncBroker, BrokerMessage
77

8-
from taskiq_redis import ListQueueBroker, PubSubBroker
8+
from taskiq_redis import ListQueueBroker, PubSubBroker, ListQueueClusterBroker
99

1010

1111
def test_no_url_should_raise_typeerror() -> None:
@@ -96,3 +96,34 @@ async def test_list_queue_broker(
9696
worker1_task.cancel()
9797
worker2_task.cancel()
9898
await broker.shutdown()
99+
100+
101+
@pytest.mark.anyio
102+
async def test_list_queue_cluster_broker(
103+
valid_broker_message: BrokerMessage,
104+
redis_cluster_url: str,
105+
) -> None:
106+
"""
107+
Test that messages are published and read correctly by ListQueueClusterBroker.
108+
109+
We create two workers that listen and send a message to them.
110+
Expect only one worker to receive the same message we sent.
111+
"""
112+
113+
print(f"redis_cluster_url: {redis_cluster_url}")
114+
broker = ListQueueClusterBroker(
115+
url=redis_cluster_url, queue_name=uuid.uuid4().hex
116+
)
117+
worker1_task = asyncio.create_task(get_message(broker))
118+
worker2_task = asyncio.create_task(get_message(broker))
119+
await asyncio.sleep(0.3)
120+
121+
await broker.kick(valid_broker_message)
122+
await asyncio.sleep(0.3)
123+
124+
assert worker1_task.done() != worker2_task.done()
125+
message = worker1_task.result() if worker1_task.done() else worker2_task.result()
126+
assert message == valid_broker_message.message
127+
worker1_task.cancel()
128+
worker2_task.cancel()
129+
await broker.shutdown()

0 commit comments

Comments
 (0)