Skip to content

Commit ab3c911

Browse files
authored
Exception propagation (#141)
1 parent b03811f commit ab3c911

File tree

5 files changed

+55
-2
lines changed

5 files changed

+55
-2
lines changed

docs/guide/state-and-deps.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,33 @@ Taskiq have an ability to add new first-level dependencies using brokers.
181181
The AsyncBroker interface has a function called `add_dependency_context` and you can add
182182
more default dependencies to the taskiq. This may be useful for libraries if you want to
183183
add new dependencies to users.
184+
185+
186+
### Exception handling
187+
188+
Dependencies can handle exceptions that happen in tasks. This feature is handy if you want your system to be more atomic.
189+
190+
For example, if you open a database transaction in your dependency and want to commit it only if the function is completed successfully.
191+
192+
```python
193+
async def get_transaction(db_driver: DBDriver = TaskiqDepends(get_driver)) -> AsyncGenerator[Transaction, None]:
194+
trans = db_driver.begin_transaction():
195+
try:
196+
# Here we give transaction to our dependant function.
197+
yield trans
198+
# If exception was found in dependant function,
199+
# we rollback our transaction.
200+
except Exception:
201+
await trans.rollback()
202+
return
203+
# Here we commit if everything is fine.
204+
await trans.commit()
205+
```
206+
207+
If you don't want to propagate exceptions in dependencies, you can add `--no-propagate-errors` option to `worker` command.
208+
209+
```bash
210+
taskiq worker my_file:broker --no-propagate-errors
211+
```
212+
213+
In this case, no exception will ever going to be propagated to any dependency.

taskiq/brokers/inmemory_broker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,13 @@ class InMemoryBroker(AsyncBroker):
8888
It's useful for local development, if you don't want to setup real broker.
8989
"""
9090

91-
def __init__(
91+
def __init__( # noqa: WPS211
9292
self,
9393
sync_tasks_pool_size: int = 4,
9494
max_stored_results: int = 100,
9595
cast_types: bool = True,
9696
max_async_tasks: int = 30,
97+
propagate_exceptions: bool = True,
9798
) -> None:
9899
super().__init__()
99100
self.result_backend = InmemoryResultBackend(
@@ -105,6 +106,7 @@ def __init__(
105106
executor=self.executor,
106107
validate_params=cast_types,
107108
max_async_tasks=max_async_tasks,
109+
propagate_exceptions=propagate_exceptions,
108110
)
109111
self._running_tasks: "Set[asyncio.Task[Any]]" = set()
110112

taskiq/cli/worker/args.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class WorkerArgs:
4141
receiver: str = "taskiq.receiver:Receiver"
4242
receiver_arg: List[Tuple[str, str]] = field(default_factory=list)
4343
max_prefetch: int = 0
44+
no_propagate_errors: bool = False
4445

4546
@classmethod
4647
def from_cli( # noqa: WPS213
@@ -138,6 +139,16 @@ def from_cli( # noqa: WPS213
138139
" with pydantic."
139140
),
140141
)
142+
parser.add_argument(
143+
"--no-propagate-errors",
144+
action="store_true",
145+
dest="no_propagate_errors",
146+
help=(
147+
"If this parameter is on,"
148+
" all errors that happen in tasks "
149+
" won't be propagated to generator dependencies."
150+
),
151+
)
141152
parser.add_argument(
142153
"--max-threadpool-threads",
143154
type=int,

taskiq/cli/worker/run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None:
133133
validate_params=not args.no_parse,
134134
max_async_tasks=args.max_async_tasks,
135135
max_prefetch=args.max_prefetch,
136+
propagate_exceptions=not args.no_propagate_errors,
136137
**receiver_args,
137138
)
138139
loop.run_until_complete(receiver.listen())

taskiq/receiver/receiver.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,15 @@ def __init__( # noqa: WPS211
4646
validate_params: bool = True,
4747
max_async_tasks: "Optional[int]" = None,
4848
max_prefetch: int = 0,
49+
propagate_exceptions: bool = True,
4950
) -> None:
5051
self.broker = broker
5152
self.executor = executor
5253
self.validate_params = validate_params
5354
self.task_signatures: Dict[str, inspect.Signature] = {}
5455
self.task_hints: Dict[str, Dict[str, Any]] = {}
5556
self.dependency_graphs: Dict[str, DependencyGraph] = {}
57+
self.propagate_exceptions = propagate_exceptions
5658
for task in self.broker.available_tasks.values():
5759
self.task_signatures[task.task_name] = inspect.signature(task.original_func)
5860
self.task_hints[task.task_name] = get_type_hints(task.original_func)
@@ -213,7 +215,14 @@ async def run_task( # noqa: C901, WPS210
213215
# Stop the timer.
214216
execution_time = time() - start_time
215217
if dep_ctx:
216-
await dep_ctx.close()
218+
args = (None, None, None)
219+
if found_exception and self.propagate_exceptions:
220+
args = ( # type: ignore
221+
type(found_exception),
222+
found_exception,
223+
found_exception.__traceback__,
224+
)
225+
await dep_ctx.close(*args)
217226

218227
# Assemble result.
219228
result: "TaskiqResult[Any]" = TaskiqResult(

0 commit comments

Comments
 (0)