Skip to content

Commit cdb431b

Browse files
Sobes76russ3riusAnton
authored
feat: set/get progress (#130)
Co-authored-by: Pavel Kirilin <[email protected]> Co-authored-by: Anton <[email protected]>
1 parent 20f92b0 commit cdb431b

File tree

6 files changed

+280
-4
lines changed

6 files changed

+280
-4
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
strategy:
3838
matrix:
3939
py_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
40-
pydantic_ver: ["<2", ">=2,<3"]
40+
pydantic_ver: ["<2", ">=2.5,<3"]
4141
os: [ubuntu-latest, windows-latest]
4242
runs-on: "${{ matrix.os }}"
4343
steps:

taskiq/abc/result_backend.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from abc import ABC, abstractmethod
2-
from typing import Generic, TypeVar
2+
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
33

44
from taskiq.result import TaskiqResult
55

6+
if TYPE_CHECKING: # pragma: no cover
7+
from taskiq.depends.progress_tracker import TaskProgress
8+
9+
610
_ReturnType = TypeVar("_ReturnType")
711

812

@@ -50,3 +54,25 @@ async def get_result(
5054
:param with_logs: if True it will download task's logs.
5155
:return: task's return value.
5256
"""
57+
58+
async def set_progress(
59+
self,
60+
task_id: str,
61+
progress: "TaskProgress[Any]",
62+
) -> None:
63+
"""
64+
Saves progress.
65+
66+
:param task_id: task's id.
67+
:param progress: progress of execution.
68+
"""
69+
70+
async def get_progress(
71+
self,
72+
task_id: str,
73+
) -> "Optional[TaskProgress[Any]]":
74+
"""
75+
Gets progress.
76+
77+
:param task_id: task's id.
78+
"""

taskiq/brokers/inmemory_broker.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import asyncio
22
from collections import OrderedDict
33
from concurrent.futures import ThreadPoolExecutor
4-
from typing import Any, AsyncGenerator, Set, TypeVar
4+
from typing import Any, AsyncGenerator, Optional, Set, TypeVar
55

66
from taskiq.abc.broker import AsyncBroker
77
from taskiq.abc.result_backend import AsyncResultBackend, TaskiqResult
8+
from taskiq.depends.progress_tracker import TaskProgress
89
from taskiq.events import TaskiqEvents
910
from taskiq.exceptions import TaskiqError
1011
from taskiq.message import BrokerMessage
@@ -27,6 +28,7 @@ class InmemoryResultBackend(AsyncResultBackend[_ReturnType]):
2728
def __init__(self, max_stored_results: int = 100) -> None:
2829
self.max_stored_results = max_stored_results
2930
self.results: OrderedDict[str, TaskiqResult[_ReturnType]] = OrderedDict()
31+
self.progress: OrderedDict[str, TaskProgress[Any]] = OrderedDict()
3032

3133
async def set_result(self, task_id: str, result: TaskiqResult[_ReturnType]) -> None:
3234
"""
@@ -79,6 +81,37 @@ async def get_result(
7981
"""
8082
return self.results[task_id]
8183

84+
async def set_progress(
85+
self,
86+
task_id: str,
87+
progress: TaskProgress[Any],
88+
) -> None:
89+
"""
90+
Set progress of task exection.
91+
92+
:param task_id: task id
93+
:param progress: task execution progress
94+
"""
95+
if (
96+
self.max_stored_results != -1
97+
and len(self.progress) >= self.max_stored_results
98+
):
99+
self.progress.popitem(last=False)
100+
101+
self.progress[task_id] = progress
102+
103+
async def get_progress(
104+
self,
105+
task_id: str,
106+
) -> Optional[TaskProgress[Any]]:
107+
"""
108+
Get progress of task execution.
109+
110+
:param task_id: task id
111+
:return: progress or None
112+
"""
113+
return self.progress.get(task_id)
114+
82115

83116
class InMemoryBroker(AsyncBroker):
84117
"""

taskiq/depends/progress_tracker.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import enum
2+
from typing import Generic, Optional, Union
3+
4+
from taskiq_dependencies import Depends
5+
from typing_extensions import TypeVar
6+
7+
from taskiq.compat import IS_PYDANTIC2
8+
from taskiq.context import Context
9+
10+
if IS_PYDANTIC2:
11+
from pydantic import BaseModel as GenericModel
12+
else:
13+
from pydantic.generics import GenericModel # type: ignore[no-redef]
14+
15+
16+
_ProgressType = TypeVar("_ProgressType")
17+
18+
19+
class TaskState(str, enum.Enum):
20+
"""State of task execution."""
21+
22+
STARTED = "STARTED"
23+
FAILURE = "FAILURE"
24+
SUCCESS = "SUCCESS"
25+
RETRY = "RETRY"
26+
27+
28+
class TaskProgress(GenericModel, Generic[_ProgressType]):
29+
"""Progress of task execution."""
30+
31+
state: Union[TaskState, str]
32+
meta: Optional[_ProgressType]
33+
34+
35+
class ProgressTracker(Generic[_ProgressType]):
36+
"""Task's dependency to set progress."""
37+
38+
def __init__(
39+
self,
40+
context: Context = Depends(),
41+
) -> None:
42+
self.context = context
43+
44+
async def set_progress(
45+
self,
46+
state: Union[TaskState, str],
47+
meta: Optional[_ProgressType] = None,
48+
) -> None:
49+
"""Set progress.
50+
51+
:param state: TaskState or str
52+
:param meta: progress data
53+
"""
54+
if meta is None:
55+
progress = await self.get_progress()
56+
meta = progress.meta if progress else None
57+
58+
progress = TaskProgress(
59+
state=state,
60+
meta=meta,
61+
)
62+
63+
await self.context.broker.result_backend.set_progress(
64+
self.context.message.task_id,
65+
progress,
66+
)
67+
68+
async def get_progress(self) -> Optional[TaskProgress[_ProgressType]]:
69+
"""Get progress."""
70+
return await self.context.broker.result_backend.get_progress(
71+
self.context.message.task_id,
72+
)

taskiq/task.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import asyncio
22
from abc import ABC, abstractmethod
33
from time import time
4-
from typing import TYPE_CHECKING, Any, Coroutine, Generic, TypeVar, Union
4+
from typing import TYPE_CHECKING, Any, Coroutine, Generic, Optional, Union
5+
6+
from typing_extensions import TypeVar
57

68
from taskiq.exceptions import (
79
ResultGetError,
@@ -11,6 +13,7 @@
1113

1214
if TYPE_CHECKING: # pragma: no cover
1315
from taskiq.abc.result_backend import AsyncResultBackend
16+
from taskiq.depends.progress_tracker import TaskProgress
1417
from taskiq.result import TaskiqResult
1518

1619
_ReturnType = TypeVar("_ReturnType")
@@ -65,6 +68,19 @@ def wait_result(
6568
:return: TaskiqResult.
6669
"""
6770

71+
@abstractmethod
72+
def get_progress(
73+
self,
74+
) -> Union[
75+
"Optional[TaskProgress[Any]]",
76+
Coroutine[Any, Any, "Optional[TaskProgress[Any]]"],
77+
]:
78+
"""
79+
Get task progress.
80+
81+
:return: task's progress.
82+
"""
83+
6884

6985
class AsyncTaskiqTask(_Task[_ReturnType]):
7086
"""AsyncTask for AsyncResultBackend."""
@@ -137,3 +153,11 @@ async def wait_result(
137153
if 0 < timeout < time() - start_time:
138154
raise TaskiqResultTimeoutError
139155
return await self.get_result(with_logs=with_logs)
156+
157+
async def get_progress(self) -> "Optional[TaskProgress[Any]]":
158+
"""
159+
Get task progress.
160+
161+
:return: task's progress.
162+
"""
163+
return await self.result_backend.get_progress(self.task_id)
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from concurrent.futures import ThreadPoolExecutor
2+
from typing import Any, Dict, Optional
3+
4+
import pytest
5+
from pydantic import ValidationError
6+
7+
from taskiq import (
8+
AsyncTaskiqDecoratedTask,
9+
InMemoryBroker,
10+
TaskiqDepends,
11+
TaskiqMessage,
12+
)
13+
from taskiq.abc import AsyncBroker
14+
from taskiq.depends.progress_tracker import ProgressTracker, TaskState
15+
from taskiq.receiver import Receiver
16+
17+
18+
def get_receiver(
19+
broker: Optional[AsyncBroker] = None,
20+
no_parse: bool = False,
21+
max_async_tasks: Optional[int] = None,
22+
) -> Receiver:
23+
"""
24+
Returns receiver with custom broker and args.
25+
26+
:param broker: broker, defaults to None
27+
:param no_parse: parameter to taskiq_args, defaults to False
28+
:param cli_args: Taskiq worker CLI arguments.
29+
:return: new receiver.
30+
"""
31+
if broker is None:
32+
broker = InMemoryBroker()
33+
return Receiver(
34+
broker,
35+
executor=ThreadPoolExecutor(max_workers=10),
36+
validate_params=not no_parse,
37+
max_async_tasks=max_async_tasks,
38+
)
39+
40+
41+
def get_message(
42+
task: AsyncTaskiqDecoratedTask[Any, Any],
43+
task_id: Optional[str] = None,
44+
*args: Any,
45+
labels: Optional[Dict[str, str]] = None,
46+
**kwargs: Dict[str, Any],
47+
) -> TaskiqMessage:
48+
if labels is None:
49+
labels = {}
50+
return TaskiqMessage(
51+
task_id=task_id or task.broker.id_generator(),
52+
task_name=task.task_name,
53+
labels=labels,
54+
args=list(args),
55+
kwargs=kwargs,
56+
)
57+
58+
59+
@pytest.mark.anyio
60+
@pytest.mark.parametrize(
61+
"state,meta",
62+
[
63+
(TaskState.STARTED, "hello world!"),
64+
("retry", "retry error!"),
65+
("custom state", {"Complex": "Value"}),
66+
],
67+
)
68+
async def test_progress_tracker_ctx_raw(state: Any, meta: Any) -> None:
69+
broker = InMemoryBroker()
70+
71+
@broker.task
72+
async def test_func(tes_val: ProgressTracker[Any] = TaskiqDepends()) -> None:
73+
await tes_val.set_progress(state, meta)
74+
75+
kicker = await test_func.kiq()
76+
result = await kicker.wait_result()
77+
78+
assert not result.is_err
79+
progress = await broker.result_backend.get_progress(kicker.task_id)
80+
assert progress is not None
81+
assert progress.meta == meta
82+
assert progress.state == state
83+
84+
85+
@pytest.mark.anyio
86+
async def test_progress_tracker_ctx_none() -> None:
87+
broker = InMemoryBroker()
88+
89+
@broker.task
90+
async def test_func() -> None:
91+
pass
92+
93+
kicker = await test_func.kiq()
94+
result = await kicker.wait_result()
95+
96+
assert not result.is_err
97+
progress = await broker.result_backend.get_progress(kicker.task_id)
98+
assert progress is None
99+
100+
101+
@pytest.mark.anyio
102+
@pytest.mark.parametrize(
103+
"state,meta",
104+
[
105+
(("state", "error"), 1),
106+
],
107+
)
108+
async def test_progress_tracker_validation_error(state: Any, meta: Any) -> None:
109+
broker = InMemoryBroker()
110+
111+
@broker.task
112+
async def test_func(progress: ProgressTracker[int] = TaskiqDepends()) -> None:
113+
await progress.set_progress(state, meta) # type: ignore
114+
115+
kicker = await test_func.kiq()
116+
result = await kicker.wait_result()
117+
with pytest.raises(ValidationError):
118+
result.raise_for_error()
119+
120+
progress = await broker.result_backend.get_progress(kicker.task_id)
121+
assert progress is None

0 commit comments

Comments
 (0)