Skip to content

Commit 7825e8a

Browse files
committed
Added automatic worker restarts.
Signed-off-by: Pavel Kirilin <[email protected]>
1 parent 9004b02 commit 7825e8a

File tree

4 files changed

+152
-17
lines changed

4 files changed

+152
-17
lines changed

taskiq/__main__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
import asyncio
2-
31
from taskiq.cli.args import TaskiqArgs
42
from taskiq.cli.worker import run_worker
53

64

75
def main() -> None:
86
"""Main entrypoint for CLI."""
97
args = TaskiqArgs.from_cli()
10-
asyncio.run(run_worker(args))
8+
run_worker(args)
119

1210

1311
if __name__ == "__main__":

taskiq/cli/args.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import enum
22
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
33
from dataclasses import dataclass
4+
from typing import List, Optional
45

56

67
class LogLevel(str, enum.Enum): # noqa: WPS600
@@ -24,13 +25,15 @@ class TaskiqArgs:
2425
log_level: str
2526
workers: int
2627
log_collector_format: str
28+
max_threadpool_threads: int
2729
no_parse: bool
2830

2931
@classmethod
30-
def from_cli(cls) -> "TaskiqArgs":
32+
def from_cli(cls, args: Optional[List[str]] = None) -> "TaskiqArgs": # noqa: WPS213
3133
"""
3234
Construct TaskiqArgs instance from CLI arguments.
3335
36+
:param args: list of args as for cli.
3437
:return: TaskiqArgs instance.
3538
"""
3639
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
@@ -98,6 +101,14 @@ def from_cli(cls) -> "TaskiqArgs":
98101
" with pydantic."
99102
),
100103
)
104+
parser.add_argument(
105+
"--max-threadpool-threads",
106+
type=int,
107+
help="Maximum number of threads for executing sync functions.",
108+
)
101109

102-
namespace = parser.parse_args()
110+
# The original branch was inverted: an explicitly supplied list was
# ignored and sys.argv was always used. parse_args(None) already falls
# back to sys.argv[1:], so passing args through unconditionally covers
# both the CLI case and the programmatic case.
namespace = parser.parse_args(args)
103114
return TaskiqArgs(**namespace.__dict__)

taskiq/cli/task_runner.py renamed to taskiq/cli/async_task_runner.py

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import asyncio
22
import inspect
33
import io
4+
import signal
5+
import sys
6+
from concurrent.futures import Executor, ThreadPoolExecutor
47
from logging import getLogger
58
from time import time
6-
from typing import Any, Callable, Dict, Optional
9+
from typing import Any, Callable, Dict, NoReturn, Optional
710

811
from pydantic import parse_obj_as
912

@@ -99,6 +102,7 @@ async def run_task( # noqa: WPS210
99102
signature: Optional[inspect.Signature],
100103
message: TaskiqMessage,
101104
cli_args: TaskiqArgs,
105+
executor: Optional[Executor] = None,
102106
) -> TaskiqResult[Any]:
103107
"""
104108
This function actually executes functions.
@@ -118,6 +122,7 @@ async def run_task( # noqa: WPS210
118122
:param signature: signature of an original function.
119123
:param message: received message.
120124
:param cli_args: CLI arguments for worker.
125+
:param executor: executor to run sync tasks.
121126
:return: result of execution.
122127
"""
123128
loop = asyncio.get_running_loop()
@@ -133,7 +138,7 @@ async def run_task( # noqa: WPS210
133138
returned = await target(*message.args, **message.kwargs)
134139
else:
135140
returned = await loop.run_in_executor(
136-
None,
141+
executor,
137142
run_sync,
138143
target,
139144
message,
@@ -156,7 +161,64 @@ async def run_task( # noqa: WPS210
156161
)
157162

158163

159-
async def async_listen_messages( # noqa: C901, WPS210
164+
def exit_process(task: "asyncio.Task[Any]") -> NoReturn:
    """
    Exit from the current process.

    This function is meant to be used as a done-callback
    of the broker.shutdown() task.

    If the task raised an exception, we log it with stack trace.
    If the task was cancelled, we log a warning.
    If it returned a value, we log it.

    After this, we cancel all tasks of the task's event loop
    and exit. Exit code is 0 on clean shutdown and 1 if the
    shutdown task failed or was cancelled.

    :param task: finished broker.shutdown task.
    """
    # NOTE: the annotation above is quoted on purpose. asyncio.Task
    # can be subscripted at runtime only since Python 3.9.
    exitcode = 0
    if task.cancelled():
        # task.result() would raise CancelledError here. It is not
        # an Exception subclass on Python 3.8+, so it would escape
        # the handler below and the process would never exit.
        logger.warning("Broker shutdown task was cancelled!")
        exitcode = 1
    else:
        try:
            result = task.result()
        except Exception as exc:
            logger.warning("Exception was found while shutting down!")
            logger.warning(exc, exc_info=True)
            exitcode = 1
        else:
            if result is not None:
                logger.info("Broker returned value on shutdown: '%s'", result)

    # Cancel everything that is still scheduled in the loop
    # this task belongs to.
    loop = task.get_loop()
    for running_task in asyncio.all_tasks(loop):
        running_task.cancel()

    logger.info("Killing worker process.")
    sys.exit(exitcode)
195+
196+
197+
def signal_handler(broker: AsyncBroker) -> None:
    """
    Exit signal handler.

    On the first call this handler schedules broker.shutdown()
    as an asyncio task and registers exit_process as its
    done-callback, so the process exits after the shutdown
    task is done. Any further call is a no-op.

    :param broker: current broker.
    """
    if getattr(broker, "_is_shutting_down", False):
        # We're already shutting down the broker.
        return

    # We set this flag to not call this method twice.
    # Since we add an asynchronous task in loop
    # it can wait for execution for some time.
    # We want to execute shutdown only once.
    broker._is_shutting_down = True  # type: ignore # noqa: WPS437
    task = asyncio.create_task(broker.shutdown())
    task.add_done_callback(exit_process)
    # The event loop keeps only weak references to tasks,
    # so we store a strong reference to be sure the shutdown
    # task is not garbage collected before it completes.
    broker._shutdown_task = task  # type: ignore # noqa: WPS437
219+
220+
221+
async def async_listen_messages( # noqa: C901, WPS210, WPS213
160222
broker: AsyncBroker,
161223
cli_args: TaskiqArgs,
162224
) -> None:
@@ -169,8 +231,23 @@ async def async_listen_messages( # noqa: C901, WPS210
169231
:param broker: broker to listen to.
170232
:param cli_args: CLI arguments for worker.
171233
"""
234+
loop = asyncio.get_event_loop()
235+
loop.add_signal_handler(
236+
signal.SIGTERM,
237+
signal_handler,
238+
broker,
239+
)
240+
loop.add_signal_handler(
241+
signal.SIGINT,
242+
signal_handler,
243+
broker,
244+
)
245+
172246
logger.info("Runing startup event.")
173247
await broker.startup()
248+
executor = ThreadPoolExecutor(
249+
max_workers=cli_args.max_threadpool_threads,
250+
)
174251
logger.info("Listening started.")
175252
task_registry: Dict[str, Callable[..., Any]] = {}
176253
task_signatures: Dict[str, inspect.Signature] = {}
@@ -197,6 +274,7 @@ async def async_listen_messages( # noqa: C901, WPS210
197274
task_signatures.get(message.task_name),
198275
message,
199276
cli_args,
277+
executor,
200278
)
201279
try:
202280
await broker.result_backend.set_result(message.task_id, result)

taskiq/cli/worker.py

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import asyncio
2+
import signal
23
from importlib import import_module
34
from logging import basicConfig, getLevelName, getLogger
45
from multiprocessing import Process
56
from pathlib import Path
6-
from typing import Any
7+
from time import sleep
8+
from typing import Any, List
79

810
from taskiq.abc.broker import AsyncBroker
911
from taskiq.cli.args import TaskiqArgs
10-
from taskiq.cli.task_runner import async_listen_messages
12+
from taskiq.cli.async_task_runner import async_listen_messages
1113

1214
logger = getLogger("taskiq.worker")
1315

@@ -86,7 +88,7 @@ def start_listen(args: TaskiqArgs) -> None:
8688
raise ValueError("Unknown broker type. Please use AsyncBroker instance.")
8789

8890

89-
async def run_worker(args: TaskiqArgs) -> None:
91+
def run_worker(args: TaskiqArgs) -> None: # noqa: C901, WPS210, WPS213
9092
"""
9193
This function starts worker processes.
9294
@@ -100,22 +102,68 @@ async def run_worker(args: TaskiqArgs) -> None:
100102
format=("[%(asctime)s][%(levelname)-7s][%(processName)s] %(message)s"),
101103
)
102104
logger.info("Starting %s worker processes." % args.workers)
103-
worker_processes = []
104-
for worker_num in range(args.workers):
105+
worker_processes: List[Process] = []
106+
for process in range(args.workers):
105107
work_proc = Process(
106108
target=start_listen,
107109
kwargs={"args": args},
108-
name=f"worker-{worker_num}",
110+
name=f"worker-{process}",
109111
)
110112
work_proc.start()
111113
logger.debug(
112114
"Started process worker-%d with pid %s "
113115
% (
114-
worker_num,
116+
process,
115117
work_proc.pid,
116118
),
117119
)
118120
worker_processes.append(work_proc)
119121

120-
for wp in worker_processes:
121-
wp.join()
122+
# This flag signals whether dead worker processes should be restarted.
123+
do_restarts = True
124+
125+
def signal_handler(_signal: int, _frame: Any) -> None:
126+
"""
127+
This handler is used only by main process.
128+
129+
If the OS sent you SIGINT or SIGTERM,
130+
we should kill all spawned processes.
131+
132+
:param _signal: incoming signal.
133+
:param _frame: current execution frame.
134+
"""
135+
nonlocal do_restarts # noqa: WPS420
136+
nonlocal worker_processes # noqa: WPS420
137+
138+
do_restarts = False # noqa: WPS442
139+
for process in worker_processes: # noqa: WPS442
140+
# This is how we send SIGTERM to child
141+
# processes.
142+
process.terminate()
143+
process.join()
144+
145+
signal.signal(signal.SIGINT, signal_handler)
146+
signal.signal(signal.SIGTERM, signal_handler)
147+
148+
while worker_processes and do_restarts:
149+
# List of processes to remove.
150+
sleep(1)
151+
process_to_remove = []
152+
for worker_id, worker in enumerate(worker_processes):
153+
if worker.is_alive():
154+
continue
155+
if worker.exitcode is not None and worker.exitcode > 0 and do_restarts:
156+
logger.info("Trying to restart the worker-%s" % worker_id)
157+
worker_processes[worker_id] = Process(
158+
target=start_listen,
159+
kwargs={"args": args},
160+
name=f"worker-{worker_id}",
161+
)
162+
worker_processes[worker_id].start()
163+
else:
164+
logger.info("Worker-%s has finished." % worker_id)
165+
worker.join()
166+
process_to_remove.append(worker)
167+
168+
for dead_process in process_to_remove:
169+
worker_processes.remove(dead_process)

0 commit comments

Comments
 (0)