Skip to content

Commit a9b7f9d

Browse files
committed
feat: Added RedisSentinelAsyncResultBackend.
1 parent 4c3343f commit a9b7f9d

File tree

3 files changed

+312
-3
lines changed

3 files changed

+312
-3
lines changed

taskiq_redis/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from taskiq_redis.redis_backend import (
33
RedisAsyncClusterResultBackend,
44
RedisAsyncResultBackend,
5+
RedisAsyncSentinelResultBackend,
56
)
67
from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker
78
from taskiq_redis.redis_cluster_broker import ListQueueClusterBroker
@@ -14,6 +15,7 @@
1415
__all__ = [
1516
"RedisAsyncClusterResultBackend",
1617
"RedisAsyncResultBackend",
18+
"RedisAsyncSentinelResultBackend",
1719
"ListQueueBroker",
1820
"PubSubBroker",
1921
"ListQueueClusterBroker",

taskiq_redis/redis_backend.py

Lines changed: 159 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,19 @@
11
import pickle
2-
from typing import Any, Dict, Optional, TypeVar, Union
2+
import sys
3+
from contextlib import asynccontextmanager
4+
from typing import (
5+
TYPE_CHECKING,
6+
Any,
7+
AsyncIterator,
8+
Dict,
9+
List,
10+
Optional,
11+
Tuple,
12+
TypeVar,
13+
Union,
14+
)
315

4-
from redis.asyncio import BlockingConnectionPool, Redis
16+
from redis.asyncio import BlockingConnectionPool, Redis, Sentinel
517
from redis.asyncio.cluster import RedisCluster
618
from taskiq import AsyncResultBackend
719
from taskiq.abc.result_backend import TaskiqResult
@@ -12,6 +24,16 @@
1224
ResultIsMissingError,
1325
)
1426

27+
if sys.version_info >= (3, 10):
28+
from typing import TypeAlias
29+
else:
30+
from typing_extensions import TypeAlias
31+
32+
if TYPE_CHECKING:
33+
_Redis: TypeAlias = Redis[bytes]
34+
else:
35+
_Redis: TypeAlias = Redis
36+
1537
_ReturnType = TypeVar("_ReturnType")
1638

1739

@@ -267,3 +289,138 @@ async def get_result(
267289
taskiq_result.log = None
268290

269291
return taskiq_result
292+
293+
294+
class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]):
295+
"""Async result based on redis sentinel."""
296+
297+
def __init__(
298+
self,
299+
sentinels: List[Tuple[str, int]],
300+
master_name: str,
301+
keep_results: bool = True,
302+
result_ex_time: Optional[int] = None,
303+
result_px_time: Optional[int] = None,
304+
min_other_sentinels: int = 0,
305+
sentinel_kwargs: Optional[Any] = None,
306+
**connection_kwargs: Any,
307+
) -> None:
308+
"""
309+
Constructs a new result backend.
310+
311+
:param sentinels: list of sentinel host and ports pairs.
312+
:param master_name: sentinel master name.
313+
:param keep_results: flag to not remove results from Redis after reading.
314+
:param result_ex_time: expire time in seconds for result.
315+
:param result_px_time: expire time in milliseconds for result.
316+
:param max_connection_pool_size: maximum number of connections in pool.
317+
:param connection_kwargs: additional arguments for redis BlockingConnectionPool.
318+
319+
:raises DuplicateExpireTimeSelectedError: if result_ex_time
320+
and result_px_time are selected.
321+
:raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
322+
and result_px_time are equal zero.
323+
"""
324+
self.sentinel = Sentinel(
325+
sentinels=sentinels,
326+
min_other_sentinels=min_other_sentinels,
327+
sentinel_kwargs=sentinel_kwargs,
328+
**connection_kwargs,
329+
)
330+
self.master_name = master_name
331+
self.keep_results = keep_results
332+
self.result_ex_time = result_ex_time
333+
self.result_px_time = result_px_time
334+
335+
unavailable_conditions = any(
336+
(
337+
self.result_ex_time is not None and self.result_ex_time <= 0,
338+
self.result_px_time is not None and self.result_px_time <= 0,
339+
),
340+
)
341+
if unavailable_conditions:
342+
raise ExpireTimeMustBeMoreThanZeroError(
343+
"You must select one expire time param and it must be more than zero.",
344+
)
345+
346+
if self.result_ex_time and self.result_px_time:
347+
raise DuplicateExpireTimeSelectedError(
348+
"Choose either result_ex_time or result_px_time.",
349+
)
350+
351+
@asynccontextmanager
352+
async def _acquire_master_conn(self) -> AsyncIterator[_Redis]:
353+
async with self.sentinel.master_for(self.master_name) as redis_conn:
354+
yield redis_conn
355+
356+
async def set_result(
357+
self,
358+
task_id: str,
359+
result: TaskiqResult[_ReturnType],
360+
) -> None:
361+
"""
362+
Sets task result in redis.
363+
364+
Dumps TaskiqResult instance into the bytes and writes
365+
it to redis.
366+
367+
:param task_id: ID of the task.
368+
:param result: TaskiqResult instance.
369+
"""
370+
redis_set_params: Dict[str, Union[str, bytes, int]] = {
371+
"name": task_id,
372+
"value": pickle.dumps(result),
373+
}
374+
if self.result_ex_time:
375+
redis_set_params["ex"] = self.result_ex_time
376+
elif self.result_px_time:
377+
redis_set_params["px"] = self.result_px_time
378+
379+
async with self._acquire_master_conn() as redis:
380+
await redis.set(**redis_set_params) # type: ignore
381+
382+
async def is_result_ready(self, task_id: str) -> bool:
383+
"""
384+
Returns whether the result is ready.
385+
386+
:param task_id: ID of the task.
387+
388+
:returns: True if the result is ready else False.
389+
"""
390+
async with self._acquire_master_conn() as redis:
391+
return bool(await redis.exists(task_id))
392+
393+
async def get_result(
394+
self,
395+
task_id: str,
396+
with_logs: bool = False,
397+
) -> TaskiqResult[_ReturnType]:
398+
"""
399+
Gets result from the task.
400+
401+
:param task_id: task's id.
402+
:param with_logs: if True it will download task's logs.
403+
:raises ResultIsMissingError: if there is no result when trying to get it.
404+
:return: task's return value.
405+
"""
406+
async with self._acquire_master_conn() as redis:
407+
if self.keep_results:
408+
result_value = await redis.get(
409+
name=task_id,
410+
)
411+
else:
412+
result_value = await redis.getdel(
413+
name=task_id,
414+
)
415+
416+
if result_value is None:
417+
raise ResultIsMissingError
418+
419+
taskiq_result: TaskiqResult[_ReturnType] = pickle.loads( # noqa: S301
420+
result_value,
421+
)
422+
423+
if not with_logs:
424+
taskiq_result.log = None
425+
426+
return taskiq_result

tests/test_result_backend.py

Lines changed: 151 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
import asyncio
22
import uuid
3+
from typing import List, Tuple
34

45
import pytest
56
from taskiq import TaskiqResult
67

7-
from taskiq_redis import RedisAsyncClusterResultBackend, RedisAsyncResultBackend
8+
from taskiq_redis import (
9+
RedisAsyncClusterResultBackend,
10+
RedisAsyncResultBackend,
11+
RedisAsyncSentinelResultBackend,
12+
)
813
from taskiq_redis.exceptions import ResultIsMissingError
914

1015

@@ -288,3 +293,148 @@ async def test_keep_results_after_reading_cluster(redis_cluster_url: str) -> Non
288293
res2 = await result_backend.get_result(task_id=task_id)
289294
assert res1 == res2
290295
await result_backend.shutdown()
296+
297+
298+
@pytest.mark.anyio
299+
async def test_set_result_success_sentinel(
300+
redis_sentinels: List[Tuple[str, int]],
301+
redis_sentinel_master_name: str,
302+
) -> None:
303+
"""
304+
Tests that results can be set without errors in cluster mode.
305+
306+
:param redis_sentinels: list of host and port pairs.
307+
:param redis_sentinel_master_name: redis sentinel master name string.
308+
"""
309+
result_backend = RedisAsyncSentinelResultBackend( # type: ignore
310+
sentinels=redis_sentinels,
311+
master_name=redis_sentinel_master_name,
312+
)
313+
task_id = uuid.uuid4().hex
314+
result: "TaskiqResult[int]" = TaskiqResult(
315+
is_err=True,
316+
log="My Log",
317+
return_value=11,
318+
execution_time=112.2,
319+
)
320+
await result_backend.set_result(
321+
task_id=task_id,
322+
result=result,
323+
)
324+
325+
fetched_result = await result_backend.get_result(
326+
task_id=task_id,
327+
with_logs=True,
328+
)
329+
assert fetched_result.log == "My Log"
330+
assert fetched_result.return_value == 11
331+
assert fetched_result.execution_time == 112.2
332+
assert fetched_result.is_err
333+
await result_backend.shutdown()
334+
335+
336+
@pytest.mark.anyio
337+
async def test_fetch_without_logs_sentinel(
338+
redis_sentinels: List[Tuple[str, int]],
339+
redis_sentinel_master_name: str,
340+
) -> None:
341+
"""
342+
Check if fetching value without logs works fine.
343+
344+
:param redis_sentinels: list of host and port pairs.
345+
:param redis_sentinel_master_name: redis sentinel master name string.
346+
"""
347+
result_backend = RedisAsyncSentinelResultBackend( # type: ignore
348+
sentinels=redis_sentinels,
349+
master_name=redis_sentinel_master_name,
350+
)
351+
task_id = uuid.uuid4().hex
352+
result: "TaskiqResult[int]" = TaskiqResult(
353+
is_err=True,
354+
log="My Log",
355+
return_value=11,
356+
execution_time=112.2,
357+
)
358+
await result_backend.set_result(
359+
task_id=task_id,
360+
result=result,
361+
)
362+
363+
fetched_result = await result_backend.get_result(
364+
task_id=task_id,
365+
with_logs=False,
366+
)
367+
assert fetched_result.log is None
368+
assert fetched_result.return_value == 11
369+
assert fetched_result.execution_time == 112.2
370+
assert fetched_result.is_err
371+
await result_backend.shutdown()
372+
373+
374+
@pytest.mark.anyio
375+
async def test_remove_results_after_reading_sentinel(
376+
redis_sentinels: List[Tuple[str, int]],
377+
redis_sentinel_master_name: str,
378+
) -> None:
379+
"""
380+
Check if removing results after reading works fine.
381+
382+
:param redis_sentinels: list of host and port pairs.
383+
:param redis_sentinel_master_name: redis sentinel master name string.
384+
"""
385+
result_backend = RedisAsyncSentinelResultBackend( # type: ignore
386+
sentinels=redis_sentinels,
387+
master_name=redis_sentinel_master_name,
388+
keep_results=False,
389+
)
390+
task_id = uuid.uuid4().hex
391+
result: "TaskiqResult[int]" = TaskiqResult(
392+
is_err=True,
393+
log="My Log",
394+
return_value=11,
395+
execution_time=112.2,
396+
)
397+
await result_backend.set_result(
398+
task_id=task_id,
399+
result=result,
400+
)
401+
402+
await result_backend.get_result(task_id=task_id)
403+
with pytest.raises(ResultIsMissingError):
404+
await result_backend.get_result(task_id=task_id)
405+
406+
await result_backend.shutdown()
407+
408+
409+
@pytest.mark.anyio
410+
async def test_keep_results_after_reading_sentinel(
411+
redis_sentinels: List[Tuple[str, int]],
412+
redis_sentinel_master_name: str,
413+
) -> None:
414+
"""
415+
Check if keeping results after reading works fine.
416+
417+
:param redis_sentinels: list of host and port pairs.
418+
:param redis_sentinel_master_name: redis sentinel master name string.
419+
"""
420+
result_backend = RedisAsyncSentinelResultBackend( # type: ignore
421+
sentinels=redis_sentinels,
422+
master_name=redis_sentinel_master_name,
423+
keep_results=True,
424+
)
425+
task_id = uuid.uuid4().hex
426+
result: "TaskiqResult[int]" = TaskiqResult(
427+
is_err=True,
428+
log="My Log",
429+
return_value=11,
430+
execution_time=112.2,
431+
)
432+
await result_backend.set_result(
433+
task_id=task_id,
434+
result=result,
435+
)
436+
437+
res1 = await result_backend.get_result(task_id=task_id)
438+
res2 = await result_backend.get_result(task_id=task_id)
439+
assert res1 == res2
440+
await result_backend.shutdown()

0 commit comments

Comments
 (0)