Skip to content

Commit 1fb82b1

Browse files
committed
feat: webserver lifecycle hooks
1 parent 270d351 commit 1fb82b1

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

redel/server/server.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import shutil
44
from contextlib import asynccontextmanager
55
from pathlib import Path
6-
from typing import Annotated, Awaitable, Callable, Collection, TYPE_CHECKING
6+
from typing import Annotated, Any, Awaitable, Callable, Collection, TYPE_CHECKING
77

88
try:
99
from fastapi import Body, FastAPI, HTTPException, WebSocket, WebSocketDisconnect, WebSocketException
@@ -63,6 +63,8 @@ def __init__(
6363
self.interactive_sessions: dict[str, SessionManager] = {}
6464

6565
# webserver
66+
self._startup_hooks = [] # for running async setup code before .serve takes control of async loop
67+
self._shutdown_hooks = []
6668
self.fastapi = FastAPI(lifespan=self._lifespan)
6769
self.setup_app()
6870

@@ -87,17 +89,30 @@ async def create_new_redel(self, **override_kwargs) -> ReDel:
8789
return ReDel(**(self.redel_proto.get_config() | override_kwargs))
8890
return await self.redel_factory(**override_kwargs)
8991

90-
def serve(self, host="127.0.0.1", port=8000, **kwargs):
92+
def serve(
93+
self,
94+
host="127.0.0.1",
95+
port=8000,
96+
startup_hooks: list[Callable[["VizServer"], Any]] = None,
97+
shutdown_hooks: list[Callable[["VizServer"], Any]] = None,
98+
**kwargs,
99+
):
91100
"""Serve this server at the given IP and port. Blocks until interrupted."""
92101
import uvicorn
93102

103+
if startup_hooks:
104+
self._startup_hooks = startup_hooks
105+
if shutdown_hooks:
106+
self._shutdown_hooks = shutdown_hooks
94107
uvicorn.run(self.fastapi, host=host, port=port, **kwargs)
95108

96109
# ==== fastapi ====
97110
@asynccontextmanager
98111
async def _lifespan(self, _: FastAPI):
99112
_ = asyncio.create_task(self.reindex_saves())
113+
await asyncio.gather(*(hook(self) for hook in self._startup_hooks))
100114
yield
115+
await asyncio.gather(*(hook(self) for hook in self._shutdown_hooks))
101116
await asyncio.gather(*(session.close() for session in self.interactive_sessions.values()))
102117

103118
def setup_app(self):

0 commit comments

Comments
 (0)