diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index f64095f..13f2c05 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -6,8 +6,10 @@ """ import os +from contextlib import asynccontextmanager from brotli_asgi import BrotliMiddleware +from fastapi import FastAPI from fastapi.responses import ORJSONResponse from stac_fastapi.api.app import StacApi from stac_fastapi.api.middleware import CORSMiddleware, ProxyHeaderMiddleware @@ -18,6 +20,7 @@ create_post_request_model, create_request_model, ) +from stac_fastapi.api.openapi import update_openapi from stac_fastapi.extensions.core import ( FieldsExtension, FilterExtension, @@ -101,7 +104,26 @@ post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) get_request_model = create_get_request_model(extensions) + +@asynccontextmanager +async def lifespan(app: FastAPI): + """FastAPI Lifespan.""" + await connect_to_db(app) + yield + await close_db_connection(app) + + +fastapp = FastAPI( + openapi_url=settings.openapi_url, + docs_url=settings.docs_url, + redoc_url=None, + root_path=getattr(settings, "root_path", None), + lifespan=lifespan, +) + + api = StacApi( + app=update_openapi(fastapp), settings=settings, extensions=extensions + [collection_search_extension] if collection_search_extension @@ -125,18 +147,6 @@ app = api.app -@app.on_event("startup") -async def startup_event(): - """Connect to database on startup.""" - await connect_to_db(app) - - -@app.on_event("shutdown") -async def shutdown_event(): - """Close database connection.""" - await close_db_connection(app) - - def run(): """Run app from command line using uvicorn if available.""" try: