Skip to content

Commit b762571

Browse files
authored
Merge pull request #20 from yudytskiy/feature-abort-pipeline-propagate
feat: AbortPipeline error propagated to the last step task result of pipeline
2 parents ba056d7 + 9e0285c commit b762571

File tree

2 files changed

+39
-5
lines changed

2 files changed

+39
-5
lines changed

taskiq_pipelines/middleware.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from logging import getLogger
2-
from typing import Any, List
2+
from typing import Any, List, Optional
33

44
import pydantic
55
from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult
@@ -108,23 +108,28 @@ async def on_error(
108108
return
109109
if current_step_num == len(steps) - 1:
110110
return
111-
await self.fail_pipeline(steps[-1].task_id)
111+
await self.fail_pipeline(steps[-1].task_id, result.error)
112112

113-
async def fail_pipeline(self, last_task_id: str) -> None:
113+
async def fail_pipeline(
114+
self,
115+
last_task_id: str,
116+
abort: Optional[BaseException] = None,
117+
) -> None:
114118
"""
115119
This function aborts pipeline.
116120
117121
This is done by setting error result for
118122
the last task in the pipeline.
119123
120124
:param last_task_id: id of the last task.
125+
:param abort: caught earlier exception or default
121126
"""
122127
await self.broker.result_backend.set_result(
123128
last_task_id,
124129
TaskiqResult(
125130
is_err=True,
126131
return_value=None, # type: ignore
127-
error=AbortPipeline("Execution aborted."),
132+
error=abort or AbortPipeline("Execution aborted."),
128133
execution_time=0,
129134
log="Error found while executing pipeline.",
130135
),

tests/test_steps.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
from taskiq import InMemoryBroker
55

6-
from taskiq_pipelines import Pipeline, PipelineMiddleware
6+
from taskiq_pipelines import AbortPipeline, Pipeline, PipelineMiddleware
77

88

99
@pytest.mark.anyio
@@ -42,3 +42,32 @@ def double(i: int) -> int:
4242
sent = await pipe.kiq(4)
4343
res = await sent.wait_result()
4444
assert res.return_value == list(map(double, ranger(4)))
45+
46+
47+
@pytest.mark.anyio
48+
async def test_abort_pipeline() -> None:
49+
"""Test AbortPipeline."""
50+
broker = InMemoryBroker().with_middlewares(PipelineMiddleware())
51+
text = "task was aborted"
52+
53+
@broker.task
54+
def normal_task(i: bool) -> bool:
55+
return i
56+
57+
@broker.task
58+
def aborting_task(i: int) -> bool:
59+
if i:
60+
raise AbortPipeline(text)
61+
return True
62+
63+
pipe = Pipeline(broker, aborting_task).call_next(normal_task)
64+
sent = await pipe.kiq(0)
65+
res = await sent.wait_result()
66+
assert res.is_err is False
67+
assert res.return_value is True
68+
assert res.error is None
69+
sent = await pipe.kiq(1)
70+
res = await sent.wait_result()
71+
assert res.is_err is True
72+
assert res.return_value is None
73+
assert res.error.args[0] == text

0 commit comments

Comments
 (0)