Skip to content

Commit 9721cfb

Browse files
committed
Added filter step.
Signed-off-by: Pavel Kirilin <[email protected]>
1 parent 3119f5f commit 9721cfb

File tree

7 files changed

+274
-5
lines changed

7 files changed

+274
-5
lines changed

taskiq_pipelines/abc.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ async def act(
4343
self,
4444
broker: AsyncBroker,
4545
step_number: int,
46+
parent_task_id: str,
4647
task_id: str,
4748
pipe_data: str,
4849
result: "TaskiqResult[Any]",
@@ -57,6 +58,7 @@ async def act(
5758
5859
:param broker: current broker.
5960
:param step_number: current step number.
61+
:param parent_task_id: current task id.
6062
:param task_id: task_id to use.
6163
:param pipe_data: serialized pipeline must be in labels.
6264
:param result: result of a previous task.

taskiq_pipelines/middleware.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ async def post_execute( # noqa: C901, WPS212
6363
await next_step.act(
6464
broker=self.broker,
6565
step_number=current_step_num + 1,
66+
parent_task_id=message.task_id,
6667
task_id=next_step_data.task_id,
6768
pipe_data=pipeline_data,
6869
result=result,

taskiq_pipelines/pipeliner.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing_extensions import ParamSpec
99

1010
from taskiq_pipelines.constants import CURRENT_STEP, PIPELINE_DATA
11-
from taskiq_pipelines.steps import MapperStep, SequentialStep, parse_step
11+
from taskiq_pipelines.steps import FilterStep, MapperStep, SequentialStep, parse_step
1212

1313
_ReturnType = TypeVar("_ReturnType")
1414
_FuncParams = ParamSpec("_FuncParams")
@@ -182,6 +182,78 @@ def map(
182182
)
183183
return self
184184

185+
@overload
186+
def filter(
187+
self: "Pipeline[_FuncParams, _ReturnType]",
188+
task: Union[
189+
AsyncKicker[Any, Coroutine[Any, Any, bool]],
190+
AsyncTaskiqDecoratedTask[Any, Coroutine[Any, Any, bool]],
191+
],
192+
param_name: Optional[str] = None,
193+
skip_errors: bool = False,
194+
check_interval: float = 0.5,
195+
**additional_kwargs: Any,
196+
) -> "Pipeline[_FuncParams, _ReturnType]":
197+
...
198+
199+
@overload
200+
def filter(
201+
self: "Pipeline[_FuncParams, _ReturnType]",
202+
task: Union[
203+
AsyncKicker[Any, bool],
204+
AsyncTaskiqDecoratedTask[Any, bool],
205+
],
206+
param_name: Optional[str] = None,
207+
skip_errors: bool = False,
208+
check_interval: float = 0.5,
209+
**additional_kwargs: Any,
210+
) -> "Pipeline[_FuncParams, _ReturnType]":
211+
...
212+
213+
def filter(
214+
self,
215+
task: Union[
216+
AsyncKicker[Any, Any],
217+
AsyncTaskiqDecoratedTask[Any, Any],
218+
],
219+
param_name: Optional[str] = None,
220+
skip_errors: bool = False,
221+
check_interval: float = 0.5,
222+
**additional_kwargs: Any,
223+
) -> Any:
224+
"""
225+
Add filter step.
226+
227+
This step is executed on a list of items,
228+
like map.
229+
230+
It runs many small subtasks for each item
231+
in sequence and if task returns true,
232+
the result is added to the final list.
233+
234+
:param task: task to execute on every item.
235+
:param param_name: parameter name to pass item into, defaults to None
236+
:param skip_errors: skip errors if any, defaults to False
237+
:param check_interval: how often the result of all subtasks is checked,
238+
defaults to 0.5
239+
:param additional_kwargs: additional function's kwargs.
240+
:return: pipeline with filtering step.
241+
"""
242+
self.steps.append(
243+
DumpedStep(
244+
step_type=FilterStep.step_name,
245+
step_data=FilterStep.from_task(
246+
task=task,
247+
param_name=param_name,
248+
skip_errors=skip_errors,
249+
check_interval=check_interval,
250+
**additional_kwargs,
251+
).dumps(),
252+
task_id="",
253+
),
254+
)
255+
return self
256+
185257
def dumps(self) -> str:
186258
"""
187259
Dumps current pipeline as string.

taskiq_pipelines/steps/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from logging import getLogger
33

44
from taskiq_pipelines.abc import AbstractStep
5+
from taskiq_pipelines.steps.filter import FilterStep
56
from taskiq_pipelines.steps.mapper import MapperStep
67
from taskiq_pipelines.steps.sequential import SequentialStep
78

@@ -19,4 +20,5 @@ def parse_step(step_type: str, step_data: str) -> AbstractStep:
1920
__all__ = [
2021
"MapperStep",
2122
"SequentialStep",
23+
"FilterStep",
2224
]

taskiq_pipelines/steps/filter.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
import asyncio
2+
from typing import Any, Dict, Iterable, List, Optional, Union
3+
4+
import pydantic
5+
from taskiq import AsyncBroker, TaskiqError, TaskiqResult
6+
from taskiq.brokers.shared_broker import async_shared_broker
7+
from taskiq.context import Context, default_context
8+
from taskiq.decor import AsyncTaskiqDecoratedTask
9+
from taskiq.kicker import AsyncKicker
10+
11+
from taskiq_pipelines.abc import AbstractStep
12+
from taskiq_pipelines.constants import CURRENT_STEP, PIPELINE_DATA
13+
from taskiq_pipelines.exceptions import AbortPipeline
14+
15+
16+
@async_shared_broker.task(task_name="taskiq_pipelines.shared.filter_tasks")
async def filter_tasks(  # noqa: C901, WPS210, WPS231
    task_ids: List[str],
    parent_task_id: str,
    check_interval: float,
    context: Context = default_context,
    skip_errors: bool = False,
) -> List[Any]:
    """
    Filter results of subtasks.

    It takes a list of task ids and the parent task id.

    After all subtasks are completed it gets the
    result of the parent task, and if a subtask's
    result of execution can be converted to True,
    the corresponding item from the original
    result is added to the resulting array.

    :param task_ids: ordered list of task ids.
    :param parent_task_id: task id of a parent task.
    :param check_interval: how often checks are performed.
    :param context: context of the execution, defaults to default_context
    :param skip_errors: skip errors of subtasks, defaults to False
    :raises TaskiqError: if any subtask has returned error.
    :return: filtered results.
    """
    ordered_ids = task_ids[:]
    tasks_set = set(task_ids)
    while tasks_set:
        # Poll only the tasks that are still pending. Iterating a snapshot
        # of the set avoids both mutating-while-iterating and redundant
        # result-backend round-trips for already-completed tasks.
        for task_id in list(tasks_set):  # noqa: WPS327
            if await context.broker.result_backend.is_result_ready(task_id):
                tasks_set.discard(task_id)
        if tasks_set:
            # Sleep only when something is still pending, so we don't
            # add a needless check_interval of latency at the end.
            await asyncio.sleep(check_interval)

    results = await context.broker.result_backend.get_result(parent_task_id)
    filtered_results = []
    # ordered_ids preserves the original submission order, so filtered
    # items come out in the same order as the parent task's results.
    for task_id, value in zip(  # type: ignore # noqa: WPS352, WPS440
        ordered_ids,
        results.return_value,
    ):
        result = await context.broker.result_backend.get_result(task_id)
        if result.is_err:
            if skip_errors:
                continue
            raise TaskiqError(f"Task {task_id} returned error. Filtering failed.")
        if result.return_value:
            filtered_results.append(value)
    return filtered_results
69+
70+
71+
class FilterStep(pydantic.BaseModel, AbstractStep, step_name="filter"):
    """Task to filter results."""

    # Name of the task to run for every item.
    task_name: str
    # Labels forwarded to every spawned subtask.
    labels: Dict[str, str]
    # If set, each item is passed as this keyword argument.
    param_name: Optional[str]
    # Extra kwargs forwarded to every subtask invocation.
    additional_kwargs: Dict[str, Any]
    # If True, errored subtasks are dropped instead of failing the pipeline.
    skip_errors: bool
    # How often (seconds) the collector polls subtask results.
    check_interval: float

    def dumps(self) -> str:
        """
        Dumps step as string.

        :return: returns json.
        """
        return self.json()

    @classmethod
    def loads(cls, data: str) -> "FilterStep":
        """
        Parses filter step from string.

        :param data: dumped data.
        :return: parsed step.
        """
        return pydantic.parse_raw_as(FilterStep, data)

    async def act(
        self,
        broker: AsyncBroker,
        step_number: int,
        parent_task_id: str,
        task_id: str,
        pipe_data: str,
        result: "TaskiqResult[Any]",
    ) -> None:
        """
        Run filter action.

        This function creates many small filter steps,
        and then collects all results in one big filtered array,
        using 'filter_tasks' shared task.

        :param broker: current broker.
        :param step_number: current step number.
        :param parent_task_id: task_id of the previous step.
        :param task_id: task_id to use in this step.
        :param pipe_data: serialized pipeline.
        :param result: result of the previous task.
        :raises AbortPipeline: if result is not iterable.
        """
        if not isinstance(result.return_value, Iterable):
            raise AbortPipeline("Result of the previous task is not iterable.")
        sub_task_ids = []
        for item in result.return_value:
            kicker: "AsyncKicker[Any, Any]" = AsyncKicker(
                task_name=self.task_name,
                broker=broker,
                labels=self.labels,
            )
            if self.param_name:
                # Build per-item kwargs instead of mutating the shared
                # additional_kwargs dict on the step instance, which would
                # leak the last item into the step's state.
                task = await kicker.kiq(
                    **{**self.additional_kwargs, self.param_name: item},
                )
            else:
                task = await kicker.kiq(item, **self.additional_kwargs)
            sub_task_ids.append(task.task_id)

        # The collector task waits for all subtasks and assembles the
        # filtered list; pipeline labels let the middleware continue
        # the pipeline after it completes.
        await filter_tasks.kicker().with_task_id(task_id).with_broker(
            broker,
        ).with_labels(
            **{CURRENT_STEP: step_number, PIPELINE_DATA: pipe_data},  # type: ignore
        ).kiq(
            sub_task_ids,
            parent_task_id,
            check_interval=self.check_interval,
            skip_errors=self.skip_errors,
        )

    @classmethod
    def from_task(
        cls,
        task: Union[
            AsyncKicker[Any, Any],
            AsyncTaskiqDecoratedTask[Any, Any],
        ],
        param_name: Optional[str],
        skip_errors: bool,
        check_interval: float,
        **additional_kwargs: Any,
    ) -> "FilterStep":
        """
        Create new filter step from task.

        :param task: task to execute.
        :param param_name: parameter name.
        :param skip_errors: don't fail collector
            task on errors.
        :param check_interval: how often tasks are checked.
        :param additional_kwargs: additional function's kwargs.
        :return: new filter step.
        """
        if isinstance(task, AsyncTaskiqDecoratedTask):
            kicker = task.kicker()
        else:
            kicker = task
        message = kicker._prepare_message()  # noqa: WPS437
        return FilterStep(
            task_name=message.task_name,
            labels=message.labels,
            param_name=param_name,
            additional_kwargs=additional_kwargs,
            skip_errors=skip_errors,
            check_interval=check_interval,
        )

taskiq_pipelines/steps/mapper.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from taskiq_pipelines.exceptions import AbortPipeline
1818

1919

20-
@async_shared_broker.task(task_name="taskiq_pipelines.wait_tasks")
21-
async def wait_tasks( # noqa: C901
20+
@async_shared_broker.task(task_name="taskiq_pipelines.shared.wait_tasks")
21+
async def wait_tasks( # noqa: C901, WPS231
2222
task_ids: List[str],
2323
check_interval: float,
2424
context: Context = default_context,
@@ -44,9 +44,12 @@ async def wait_tasks( # noqa: C901
4444
ordered_ids = task_ids[:]
4545
tasks_set = set(task_ids)
4646
while tasks_set:
47-
for task_id in task_ids:
47+
for task_id in task_ids: # noqa: WPS327
4848
if await context.broker.result_backend.is_result_ready(task_id):
49-
tasks_set.remove(task_id)
49+
try:
50+
tasks_set.remove(task_id)
51+
except LookupError:
52+
continue
5053
await asyncio.sleep(check_interval)
5154

5255
results = []
@@ -92,6 +95,7 @@ async def act(
9295
self,
9396
broker: AsyncBroker,
9497
step_number: int,
98+
parent_task_id: str,
9599
task_id: str,
96100
pipe_data: str,
97101
result: "TaskiqResult[Any]",
@@ -109,6 +113,7 @@ async def act(
109113
:param broker: current broker.
110114
:param step_number: current step number.
111115
:param task_id: waiter task_id.
116+
:param parent_task_id: task_id of the previous step.
112117
:param pipe_data: serialized pipeline.
113118
:param result: result of the previous task.
114119
:raises AbortPipeline: if the result of the

taskiq_pipelines/steps/sequential.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ async def act(
4444
self,
4545
broker: AsyncBroker,
4646
step_number: int,
47+
parent_task_id: str,
4748
task_id: str,
4849
pipe_data: str,
4950
result: "TaskiqResult[Any]",
@@ -61,6 +62,7 @@ async def act(
6162
6263
:param broker: current broker.
6364
:param step_number: current step number.
65+
:param parent_task_id: current step's task id.
6466
:param task_id: new task id.
6567
:param pipe_data: serialized pipeline.
6668
:param result: result of the previous task.

0 commit comments

Comments
 (0)