Skip to content

Commit 1424f59

Browse files
committed
feat: Add get_progress and set_progress to redis result backend
Uses as standard suffix on the redis key (hardcoded as "__progress") to store progress results
1 parent 1982791 commit 1424f59

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed

taskiq_redis/redis_backend.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from taskiq.abc.result_backend import TaskiqResult
2020
from taskiq.abc.serializer import TaskiqSerializer
2121
from taskiq.compat import model_dump, model_validate
22+
from taskiq.depends.progress_tracker import TaskProgress
2223
from taskiq.serializers import PickleSerializer
2324

2425
from taskiq_redis.exceptions import (
@@ -41,6 +42,8 @@
4142

4243
_ReturnType = TypeVar("_ReturnType")
4344

45+
PROGRESS_KEY_SUFFIX = "__progress"
46+
4447

4548
class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]):
4649
"""Async result based on redis."""
@@ -174,6 +177,52 @@ async def get_result(
174177

175178
return taskiq_result
176179

180+
async def set_progress(
181+
self,
182+
task_id: str,
183+
progress: TaskProgress[_ReturnType],
184+
) -> None:
185+
"""
186+
Sets task progress in redis.
187+
188+
Dumps TaskProgress instance into the bytes and writes
189+
it to redis with a standard suffix on the task_id as the key
190+
191+
:param task_id: ID of the task.
192+
:param result: task's TaskProgress instance.
193+
"""
194+
redis_set_params: Dict[str, Union[str, int, bytes]] = {
195+
"name": task_id + PROGRESS_KEY_SUFFIX,
196+
"value": self.serializer.dumpb(model_dump(progress)),
197+
}
198+
if self.result_ex_time:
199+
redis_set_params["ex"] = self.result_ex_time
200+
elif self.result_px_time:
201+
redis_set_params["px"] = self.result_px_time
202+
203+
async with Redis(connection_pool=self.redis_pool) as redis:
204+
await redis.set(**redis_set_params) # type: ignore
205+
206+
async def get_progress(self, task_id: str) -> TaskProgress[_ReturnType] | None:
207+
"""
208+
Gets progress results from the task.
209+
210+
:param task_id: task's id.
211+
:return: task's TaskProgress instance.
212+
"""
213+
async with Redis(connection_pool=self.redis_pool) as redis:
214+
result_value = await redis.get(
215+
name=task_id + PROGRESS_KEY_SUFFIX,
216+
)
217+
218+
if result_value is None:
219+
return None
220+
221+
return model_validate(
222+
TaskProgress[_ReturnType],
223+
self.serializer.loadb(result_value),
224+
)
225+
177226

178227
class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
179228
"""Async result backend based on redis cluster."""

tests/test_result_backend.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pytest
66
from taskiq import TaskiqResult
7+
from taskiq.depends.progress_tracker import TaskProgress, TaskState
78

89
from taskiq_redis import (
910
RedisAsyncClusterResultBackend,
@@ -438,3 +439,80 @@ async def test_keep_results_after_reading_sentinel(
438439
res2 = await result_backend.get_result(task_id=task_id)
439440
assert res1 == res2
440441
await result_backend.shutdown()
442+
443+
444+
@pytest.mark.anyio
445+
async def test_set_progress(redis_url: str) -> None:
446+
"""
447+
Test that set_progress/get_progress works.
448+
449+
:param redis_url: redis URL.
450+
"""
451+
result_backend = RedisAsyncResultBackend( # type: ignore
452+
redis_url=redis_url,
453+
)
454+
task_id = uuid.uuid4().hex
455+
456+
test_progress_1 = TaskProgress(
457+
state=TaskState.STARTED,
458+
meta={"message": "quarter way", "pct": 25},
459+
)
460+
test_progress_2 = TaskProgress(
461+
state=TaskState.STARTED,
462+
meta={"message": "half way", "pct": 50},
463+
)
464+
465+
# Progress starts as None
466+
assert await result_backend.get_progress(task_id=task_id) is None
467+
468+
# Setting the first time persists
469+
await result_backend.set_progress(task_id=task_id, progress=test_progress_1)
470+
471+
fetched_result = await result_backend.get_progress(task_id=task_id)
472+
assert fetched_result == test_progress_1
473+
474+
# Setting the second time replaces the first
475+
await result_backend.set_progress(task_id=task_id, progress=test_progress_2)
476+
477+
fetched_result = await result_backend.get_progress(task_id=task_id)
478+
assert fetched_result == test_progress_2
479+
480+
await result_backend.shutdown()
481+
482+
@pytest.mark.anyio
483+
async def test_set_progress_cluster(redis_cluster_url: str) -> None:
484+
"""
485+
Test that set_progress/get_progress works in cluster mode.
486+
487+
:param redis_url: redis URL.
488+
"""
489+
result_backend = RedisAsyncClusterResultBackend( # type: ignore
490+
redis_url=redis_cluster_url,
491+
)
492+
task_id = uuid.uuid4().hex
493+
494+
test_progress_1 = TaskProgress(
495+
state=TaskState.STARTED,
496+
meta={"message": "quarter way", "pct": 25},
497+
)
498+
test_progress_2 = TaskProgress(
499+
state=TaskState.STARTED,
500+
meta={"message": "half way", "pct": 50},
501+
)
502+
503+
# Progress starts as None
504+
assert await result_backend.get_progress(task_id=task_id) is None
505+
506+
# Setting the first time persists
507+
await result_backend.set_progress(task_id=task_id, progress=test_progress_1)
508+
509+
fetched_result = await result_backend.get_progress(task_id=task_id)
510+
assert fetched_result == test_progress_1
511+
512+
# Setting the second time replaces the first
513+
await result_backend.set_progress(task_id=task_id, progress=test_progress_2)
514+
515+
fetched_result = await result_backend.get_progress(task_id=task_id)
516+
assert fetched_result == test_progress_2
517+
518+
await result_backend.shutdown()

0 commit comments

Comments
 (0)