Skip to content

Healthcheck support #489

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
667 changes: 655 additions & 12 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ pytz = "*"
orjson = { version = "^3", optional = true }
msgpack = { version = "^1.0.7", optional = true }
cbor2 = { version = "^5", optional = true }
# For health checks
aiohttp = { version = "^3.8", optional = true }
izulu = "0.50.0"

[tool.poetry.group.dev.dependencies]
Expand Down Expand Up @@ -77,6 +79,7 @@ reload = ["watchdog", "gitignore-parser"]
orjson = ["orjson"]
msgpack = ["msgpack"]
cbor = ["cbor2"]
health = ["aiohttp"]

[tool.poetry.scripts]
taskiq = "taskiq.__main__:main"
Expand Down
25 changes: 25 additions & 0 deletions taskiq/cli/worker/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ class WorkerArgs:
hardkill_count: int = 3
use_process_pool: bool = False

# Health check arguments
health_check_enable: bool = False
health_check_port: int = 8081
health_check_timeout: float = 30.0

@classmethod
def from_cli(
cls,
Expand Down Expand Up @@ -255,6 +260,26 @@ def from_cli(
default=None,
help="Maximum number of processes in process pool.",
)
parser.add_argument(
"--health-check-enable",
action="store_true",
dest="health_check_enable",
help="Enable HTTP health check endpoints for Kubernetes probes.",
)
parser.add_argument(
"--health-check-port",
type=int,
dest="health_check_port",
default=8081,
help="Port for health check HTTP server.",
)
parser.add_argument(
"--health-check-timeout",
type=float,
dest="health_check_timeout",
default=30.0,
help="Seconds before worker considered unresponsive.",
)

namespace = parser.parse_args(args)
# If there are any patterns specified, remove default.
Expand Down
176 changes: 157 additions & 19 deletions taskiq/cli/worker/process_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
import os
import signal
Expand All @@ -20,6 +21,16 @@

from taskiq.cli.worker.args import WorkerArgs

try:
from taskiq.health.heartbeat import WorkerHeartbeatArray
from taskiq.health.server import HealthCheckServer

health_available = True
except ImportError:
WorkerHeartbeatArray = None # type: ignore[assignment, misc]
HealthCheckServer = None # type: ignore[assignment, misc]
health_available = False

logger = logging.getLogger("taskiq.process-manager")


Expand Down Expand Up @@ -61,13 +72,15 @@ def handle(
workers: List[Process],
args: WorkerArgs,
worker_func: Callable[[WorkerArgs], None],
heartbeat_array: Optional[Any] = None,
) -> None:
"""
This action reloads a single process.

:param workers: known children processes.
:param args: args for new process.
:param worker_func: function that is used to start worker processes.
:param heartbeat_array: optional heartbeat array for health checks.
"""
if self.worker_num < 0 or self.worker_num >= len(workers):
logger.warning("Unknown worker id.")
Expand All @@ -76,18 +89,45 @@ def handle(
try:
worker.terminate()
except ValueError:
logger.debug(f"Process {worker.name} is already terminated.")
logger.debug("Process %s is already terminated.", worker.name)
# Waiting worker shutdown.
worker.join()
event: EventType = Event()
new_process = Process(
target=worker_func,
kwargs={"args": args},
name=f"worker-{self.worker_num}",
daemon=False,
)

# Create wrapper function if health checks enabled
if heartbeat_array is not None:

def make_worker_wrapper(
worker_id: int,
heartbeat_array: Any,
) -> Callable[[], None]:
def _wrapper() -> None:
from taskiq.cli.worker.run import start_listen

start_listen(args, worker_id, heartbeat_array)

return _wrapper

worker_wrapper = make_worker_wrapper(self.worker_num, heartbeat_array)

new_process = Process(
target=worker_wrapper,
name=f"worker-{self.worker_num}",
daemon=False,
)
else:
new_process = Process(
target=worker_func,
kwargs={"args": args},
name=f"worker-{self.worker_num}",
daemon=False,
)
new_process.start()
logger.info(f"Process {new_process.name} restarted with pid {new_process.pid}")
logger.info(
"Process %s restarted with pid %s",
new_process.name,
new_process.pid,
)
workers[self.worker_num] = new_process
_wait_for_worker_startup(new_process, event)

Expand All @@ -98,6 +138,12 @@ class ShutdownAction(ProcessActionBase):


def _wait_for_worker_startup(process: Process, event: EventType) -> None:
"""Wait for worker process to start up.

Args:
process: The worker process to wait for
event: Event that signals worker startup (currently unused)
"""
while process.is_alive():
with suppress(TimeoutError):
event.wait(0.1)
Expand Down Expand Up @@ -138,7 +184,7 @@ def _signal_handler(signum: int, _frame: Any) -> None:
if current_process().name.startswith("worker"):
raise KeyboardInterrupt

logger.debug(f"Got signal {signum}.")
logger.debug("Got signal %d.", signum)
action_queue.put(action_to_send)
logger.warning("Workers are scheduled for shutdown.")

Expand All @@ -163,10 +209,27 @@ def __init__(
self.worker_function = worker_function
self.action_queue: "Queue[ProcessActionBase]" = Queue(-1)
self.args = args

# Initialize heartbeat system if health checks enabled
self.heartbeat_array: Optional[Any] = None
self.health_server: Optional[Any] = None
self._health_server_task: Optional[asyncio.Task[None]] = None

if args.health_check_enable and health_available:
self.heartbeat_array = WorkerHeartbeatArray(args.workers) # type: ignore[misc]
self.health_server = HealthCheckServer( # type: ignore[misc]
port=args.health_check_port,
heartbeat_array=self.heartbeat_array,
timeout=args.health_check_timeout,
)
elif args.health_check_enable and not health_available:
logger.warning(
"Health checks requested but health module not available",
)
if args.reload and observer is not None:
watch_paths = args.reload_dirs if args.reload_dirs else ["."]
for path_to_watch in watch_paths:
logger.debug(f"Watching directory: {path_to_watch}")
logger.debug("Watching directory: %s", path_to_watch)
observer.schedule(
FileWatcher(
callback=schedule_workers_reload,
Expand All @@ -189,17 +252,52 @@ def __init__(

self.workers: List[Process] = []

async def start_health_server(self) -> None:
    """
    Start the health check server on the running event loop.

    Awaits ``self.health_server.start()``. Does nothing when no health
    server is configured, i.e. ``self.health_server`` is ``None``.
    """
    # Compare against None explicitly (as done for heartbeat_array) so that a
    # configured server is started regardless of how its truthiness evaluates.
    if self.health_server is not None:
        await self.health_server.start()

async def stop_health_server(self) -> None:
    """
    Stop the health check server.

    Awaits ``self.health_server.stop()``. Does nothing when no health
    server is configured, i.e. ``self.health_server`` is ``None``.
    """
    # Compare against None explicitly (as done for heartbeat_array) so that a
    # configured server is stopped regardless of how its truthiness evaluates.
    if self.health_server is not None:
        await self.health_server.stop()

def prepare_workers(self) -> None:
"""Spawn multiple processes."""
events: List[EventType] = []
for process in range(self.args.workers):
event = Event()
work_proc = Process(
target=self.worker_function,
kwargs={"args": self.args},
name=f"worker-{process}",
daemon=False,
)

# Create wrapper function if health checks enabled
if self.heartbeat_array is not None:
# Use a factory function to avoid closure issues
def make_worker_wrapper(
worker_id: int,
heartbeat_array: Any,
) -> Callable[[], None]:
def _wrapper() -> None:
# Import start_listen locally to avoid circular imports
from taskiq.cli.worker.run import start_listen

start_listen(self.args, worker_id, heartbeat_array)

return _wrapper

worker_wrapper = make_worker_wrapper(process, self.heartbeat_array)

work_proc = Process(
target=worker_wrapper,
name=f"worker-{process}",
daemon=False,
)
else:
# Normal case without health checks
work_proc = Process(
target=self.worker_function,
kwargs={"args": self.args},
name=f"worker-{process}",
daemon=False,
)
work_proc.start()
logger.info(
"Started process worker-%d with pid %s ",
Expand All @@ -213,6 +311,37 @@ def prepare_workers(self) -> None:
for worker, event in zip(self.workers, events):
_wait_for_worker_startup(worker, event)

def _start_health_server_if_needed(self) -> None:
    """
    Start the health check server on a background thread, if enabled.

    Does nothing when ``self.health_server`` is ``None``. Otherwise a daemon
    thread is spawned that runs a private asyncio event loop, awaits
    ``self.health_server.start()`` on it and then keeps the loop running.

    The loop and the thread are remembered on the instance
    (``_health_loop`` / ``_health_thread``) so that the server can later be
    shut down from another thread instead of being abandoned.
    """
    health_server = self.health_server
    if health_server is None:
        return

    import threading

    loop = asyncio.new_event_loop()

    def run_health_server() -> None:
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(health_server.start())
            # Keep the loop alive so the server keeps serving requests.
            loop.run_forever()
        except Exception as exc:
            # logger.exception keeps the traceback, which a plain
            # logger.error with only the message would drop.
            logger.exception("Health server error: %s", exc)
        finally:
            loop.close()

    health_thread = threading.Thread(target=run_health_server, daemon=True)
    # Publish the handles before starting so a stop request never sees a
    # running thread without its loop.
    self._health_loop = loop
    self._health_thread = health_thread
    health_thread.start()
    logger.info("Health check server started in background thread")


def _stop_health_server_if_needed(self) -> None:
    """
    Stop the health check server, if it is running.

    Does nothing when ``self.health_server`` is ``None``. When the background
    thread and its event loop are known and still alive, the server's
    ``stop()`` coroutine is scheduled on that loop, the loop is asked to stop
    and the thread is joined. Both waits are bounded so shutdown cannot hang.
    Otherwise the daemon thread is left to end together with the process.
    """
    health_server = self.health_server
    if health_server is None:
        return

    # The attributes only exist once the background thread has been started.
    loop = getattr(self, "_health_loop", None)
    thread = getattr(self, "_health_thread", None)
    if loop is None or thread is None or not thread.is_alive():
        logger.info("Health check server will stop with main process")
        return

    wait_seconds = 5.0
    try:
        # The loop belongs to the background thread, so the coroutine has to
        # be handed over with the thread-safe API.
        asyncio.run_coroutine_threadsafe(
            health_server.stop(),
            loop,
        ).result(timeout=wait_seconds)
    except Exception as exc:
        logger.exception("Health server error: %s", exc)
    finally:
        try:
            loop.call_soon_threadsafe(loop.stop)
        except RuntimeError:
            # The loop was already closed by the background thread.
            pass
        thread.join(timeout=wait_seconds)
    logger.info("Health check server stopped")

def start(self) -> Optional[int]: # noqa: C901
"""
Start managing child processes.
Expand Down Expand Up @@ -243,13 +372,16 @@ def start(self) -> Optional[int]: # noqa: C901
"""
restarts = 0
self.prepare_workers()

# Start health server if enabled
self._start_health_server_if_needed()
while True:
sleep(1)
reloaded_workers = set()
# We bulk_process all pending events.
while not self.action_queue.empty():
action = self.action_queue.get()
logging.debug(f"Got event: {action}")
logging.debug("Got event: %s", action)
if isinstance(action, ReloadAllAction):
action.handle(
workers_num=len(self.workers),
Expand All @@ -268,18 +400,24 @@ def start(self) -> Optional[int]: # noqa: C901
# If we just reloaded this worker, skip handling.
if action.worker_num in reloaded_workers:
continue
action.handle(self.workers, self.args, self.worker_function)
action.handle(
self.workers,
self.args,
self.worker_function,
self.heartbeat_array,
)
reloaded_workers.add(action.worker_num)
elif isinstance(action, ShutdownAction):
logger.debug("Process manager closed, killing workers.")
for worker in self.workers:
if worker.pid:
os.kill(worker.pid, signal.SIGINT)
self._stop_health_server_if_needed()
return None

for worker_num, worker in enumerate(self.workers):
if not worker.is_alive():
logger.info(f"{worker.name} is dead. Scheduling reload.")
logger.info("%s is dead. Scheduling reload.", worker.name)
self.action_queue.put(
ReloadOneAction(
worker_num=worker_num,
Expand Down
Loading