99import textwrap
1010import threading
1111import traceback
12+ from contextlib import asynccontextmanager
1213from datetime import datetime , timezone
1314from enum import Enum , auto , unique
14- from typing import TYPE_CHECKING , Any , Awaitable , Callable , Dict , Optional , Type
15+ from typing import (
16+ TYPE_CHECKING ,
17+ Any ,
18+ AsyncGenerator ,
19+ Awaitable ,
20+ Callable ,
21+ Dict ,
22+ Optional ,
23+ Type ,
24+ )
1525
1626import structlog
1727import uvicorn
@@ -120,9 +130,32 @@ def create_app( # pylint: disable=too-many-arguments,too-many-locals,too-many-s
120130 is_build : bool = False ,
121131 await_explicit_shutdown : bool = False , # pylint: disable=redefined-outer-name
122132) -> MyFastAPI :
133+ started_at = datetime .now (tz = timezone .utc )
134+
135+ @asynccontextmanager
136+ async def lifespan (app : MyFastAPI ) -> AsyncGenerator [None , None ]:
137+ # Startup code (was previously in @app.on_event("startup"))
138+ # check for early setup failures
139+ if (
140+ app .state .setup_result
141+ and app .state .setup_result .status == schema .Status .FAILED
142+ ):
143+ # signal shutdown if interactive run
144+ if shutdown_event and not await_explicit_shutdown :
145+ shutdown_event .set ()
146+ else :
147+ setup_task = runner .setup ()
148+ setup_task .add_done_callback (_handle_setup_done )
149+
150+ yield
151+
152+ # Shutdown code (was previously in @app.on_event("shutdown"))
153+ worker .terminate ()
154+
123155 app = MyFastAPI ( # pylint: disable=redefined-outer-name
124156 title = "Cog" , # TODO: mention model name?
125157 # version=None # TODO
158+ lifespan = lifespan ,
126159 )
127160
128161 def custom_openapi () -> Dict [str , Any ]:
@@ -149,7 +182,6 @@ def custom_openapi() -> Dict[str, Any]:
149182
150183 app .state .health = Health .STARTING
151184 app .state .setup_result = None
152- started_at = datetime .now (tz = timezone .utc )
153185
154186 # shutdown is needed no matter what happens
155187 @app .post ("/shutdown" )
@@ -318,24 +350,6 @@ def cancel_training(
318350 add_setup_failed_routes (app , started_at , msg )
319351 return app
320352
321- @app .on_event ("startup" )
322- def startup () -> None :
323- # check for early setup failures
324- if (
325- app .state .setup_result
326- and app .state .setup_result .status == schema .Status .FAILED
327- ):
328- # signal shutdown if interactive run
329- if shutdown_event and not await_explicit_shutdown :
330- shutdown_event .set ()
331- else :
332- setup_task = runner .setup ()
333- setup_task .add_done_callback (_handle_setup_done )
334-
335- @app .on_event ("shutdown" )
336- def shutdown () -> None :
337- worker .terminate ()
338-
339353 @app .get ("/" )
340354 async def root () -> Any :
341355 return index_document
0 commit comments