Skip to content

Commit 0eb74c3

Browse files
author
Юдыцкий Игорь
committed
feat: AbortPipeline error propagated to the last step task result of pipeline
1 parent ba056d7 commit 0eb74c3

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

taskiq_pipelines/middleware.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,23 +108,24 @@ async def on_error(
108108
return
109109
if current_step_num == len(steps) - 1:
110110
return
111-
await self.fail_pipeline(steps[-1].task_id)
111+
await self.fail_pipeline(steps[-1].task_id, result.error)
112112

113-
async def fail_pipeline(self, last_task_id: str) -> None:
113+
async def fail_pipeline(self, last_task_id: str, abort: AbortPipeline | None = None) -> None:
114114
"""
115115
This function aborts pipeline.
116116
117117
This is done by setting error result for
118118
the last task in the pipeline.
119119
120120
:param last_task_id: id of the last task.
121+
:param abort: caught earlier exception or default
121122
"""
122123
await self.broker.result_backend.set_result(
123124
last_task_id,
124125
TaskiqResult(
125126
is_err=True,
126127
return_value=None, # type: ignore
127-
error=AbortPipeline("Execution aborted."),
128+
error=abort or AbortPipeline("Execution aborted."),
128129
execution_time=0,
129130
log="Error found while executing pipeline.",
130131
),

tests/test_steps.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import sys
12
from typing import List
23

34
import pytest
45
from taskiq import InMemoryBroker
56

6-
from taskiq_pipelines import Pipeline, PipelineMiddleware
7+
from taskiq_pipelines import Pipeline, PipelineMiddleware, AbortPipeline
78

89

910
@pytest.mark.anyio
@@ -42,3 +43,32 @@ def double(i: int) -> int:
4243
sent = await pipe.kiq(4)
4344
res = await sent.wait_result()
4445
assert res.return_value == list(map(double, ranger(4)))
46+
47+
48+
@pytest.mark.anyio
49+
async def test_abort_pipeline() -> None:
50+
"""Test AbortPipeline."""
51+
broker = InMemoryBroker().with_middlewares(PipelineMiddleware())
52+
text = 'task was aborted'
53+
54+
@broker.task
55+
def normal_task(i: bool) -> bool:
56+
return i
57+
58+
@broker.task
59+
def aborting_task(i: int) -> bool:
60+
if i:
61+
raise AbortPipeline(text)
62+
return True
63+
64+
pipe = Pipeline(broker, aborting_task).call_next(normal_task)
65+
sent = await pipe.kiq(0)
66+
res = await sent.wait_result()
67+
assert res.is_err is False
68+
assert res.return_value is True
69+
assert res.error is None
70+
sent = await pipe.kiq(1)
71+
res = await sent.wait_result()
72+
assert res.is_err is True
73+
assert res.return_value is None
74+
assert res.error.args[0] == text

0 commit comments

Comments
 (0)