diff --git a/CHANGES.md b/CHANGES.md index ec6c5966..cca451a2 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,10 @@ ## [Unreleased] +### Changed + +- move Postgres settings into separate `PostgresSettings` class and defer loading until connecting to database ([#209](https://github.com/stac-utils/stac-fastapi-pgstac/pull/209)) + ## [4.0.3] - 2025-03-10 ### Fixed diff --git a/stac_fastapi/pgstac/config.py b/stac_fastapi/pgstac/config.py index 7b1e65d9..cbc4676a 100644 --- a/stac_fastapi/pgstac/config.py +++ b/stac_fastapi/pgstac/config.py @@ -4,7 +4,7 @@ from urllib.parse import quote_plus as quote from pydantic import BaseModel, field_validator -from pydantic_settings import SettingsConfigDict +from pydantic_settings import BaseSettings, SettingsConfigDict from stac_fastapi.types.config import ApiSettings from stac_fastapi.pgstac.types.base_item_cache import ( @@ -43,7 +43,7 @@ class ServerSettings(BaseModel): model_config = SettingsConfigDict(extra="allow") -class Settings(ApiSettings): +class PostgresSettings(BaseSettings): """Postgres-specific API settings. Attributes: @@ -71,9 +71,28 @@ class Settings(ApiSettings): server_settings: ServerSettings = ServerSettings() + model_config = {"env_file": ".env", "extra": "ignore"} + + @property + def reader_connection_string(self): + """Create reader psql connection string.""" + return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_reader}:{self.postgres_port}/{self.postgres_dbname}" + + @property + def writer_connection_string(self): + """Create writer psql connection string.""" + return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/{self.postgres_dbname}" + + @property + def testing_connection_string(self): + """Create testing psql connection string.""" + return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/pgstactestdb" + + +class Settings(ApiSettings): use_api_hydrate: bool = False - base_item_cache: Type[BaseItemCache] = DefaultBaseItemCache invalid_id_chars: List[str] = DEFAULT_INVALID_ID_CHARS + base_item_cache: Type[BaseItemCache] = DefaultBaseItemCache cors_origins: str = "*" cors_methods: str = "GET,POST,OPTIONS" @@ -89,22 +108,3 @@ def parse_cors_origin(cls, v): def parse_cors_methods(cls, v): """Parse CORS methods.""" return [method.strip() for method in v.split(",")] - - @property - def reader_connection_string(self): - """Create reader psql connection string.""" - return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_reader}:{self.postgres_port}/{self.postgres_dbname}" - - @property - def writer_connection_string(self): - """Create writer psql connection string.""" - return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/{self.postgres_dbname}" - - @property - def testing_connection_string(self): - """Create testing psql connection string.""" - return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/pgstactestdb" - - model_config = SettingsConfigDict( - **{**ApiSettings.model_config, **{"env_nested_delimiter": "__"}} - ) diff --git a/stac_fastapi/pgstac/db.py b/stac_fastapi/pgstac/db.py index 121d44c0..e0e1420a 100644 --- a/stac_fastapi/pgstac/db.py +++ b/stac_fastapi/pgstac/db.py @@ -25,6 +25,8 @@ NotFoundError, ) +from stac_fastapi.pgstac.config import PostgresSettings + async def con_init(conn): """Use orjson for json returns.""" @@ -46,19 +48,25 @@ async def con_init(conn): async def connect_to_db( - app: FastAPI, get_conn: Optional[ConnectionGetter] = None + app: FastAPI, + get_conn: Optional[ConnectionGetter] = None, + postgres_settings: Optional[PostgresSettings] = None, ) -> None: """Create connection pools & connection retriever on application.""" - settings = app.state.settings - if app.state.settings.testing: - readpool = writepool = settings.testing_connection_string + app_settings = app.state.settings + + if not postgres_settings: + postgres_settings = PostgresSettings() + + if app_settings.testing: + readpool = writepool = postgres_settings.testing_connection_string else: - readpool = settings.reader_connection_string - writepool = settings.writer_connection_string + readpool = postgres_settings.reader_connection_string + writepool = postgres_settings.writer_connection_string db = DB() - app.state.readpool = await db.create_pool(readpool, settings) - app.state.writepool = await db.create_pool(writepool, settings) + app.state.readpool = await db.create_pool(readpool, postgres_settings) + app.state.writepool = await db.create_pool(writepool, postgres_settings) app.state.get_connection = get_conn if get_conn else get_connection diff --git a/tests/api/test_api.py b/tests/api/test_api.py index 743a88e4..1a747270 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -20,6 +20,7 @@ from stac_fastapi.extensions.core.fields import FieldsConformanceClasses from stac_fastapi.types import stac as stac_types +from stac_fastapi.pgstac.config import PostgresSettings from stac_fastapi.pgstac.core import CoreCrudClient, Settings from stac_fastapi.pgstac.db import close_db_connection, connect_to_db from stac_fastapi.pgstac.transactions import TransactionsClient @@ -720,13 +721,16 @@ async def get_collection( return await super().get_collection(collection_id, request=request, **kwargs) settings = Settings( + testing=True, + ) + + postgres_settings = PostgresSettings( postgres_user=database.user, postgres_pass=database.password, postgres_host_reader=database.host, postgres_host_writer=database.host, postgres_port=database.port, postgres_dbname=database.dbname, - testing=True, ) extensions = [ @@ -751,7 +755,7 @@ async def get_collection( collections_get_request_model=collection_search_extension.GET, ) app = api.app - await connect_to_db(app) + await connect_to_db(app, postgres_settings=postgres_settings) try: async with AsyncClient(transport=ASGITransport(app=app)) as client: response = await client.post( @@ -786,15 +790,17 @@ async def test_no_extension( loader.load_items(os.path.join(DATA_DIR, "test_item.json")) settings = Settings( + testing=True, + use_api_hydrate=hydrate, + enable_response_models=validation, + ) + postgres_settings = PostgresSettings( postgres_user=database.user, postgres_pass=database.password, postgres_host_reader=database.host, postgres_host_writer=database.host, postgres_port=database.port, postgres_dbname=database.dbname, - testing=True, - use_api_hydrate=hydrate, - enable_response_models=validation, ) extensions = [] post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) @@ -805,7 +811,7 @@ async def test_no_extension( search_post_request_model=post_request_model, ) app = api.app - await connect_to_db(app) + await connect_to_db(app, postgres_settings=postgres_settings) try: async with AsyncClient(transport=ASGITransport(app=app)) as client: landing = await client.get("http://test/") diff --git a/tests/clients/test_postgres.py b/tests/clients/test_postgres.py index 8a8d5ca5..b3501bcb 100644 --- a/tests/clients/test_postgres.py +++ b/tests/clients/test_postgres.py @@ -6,8 +6,10 @@ import pytest from fastapi import Request +from pydantic import ValidationError from stac_pydantic import Collection, Item +from stac_fastapi.pgstac.config import PostgresSettings from stac_fastapi.pgstac.db import close_db_connection, connect_to_db, get_connection # from tests.conftest import MockStarletteRequest @@ -523,6 +525,28 @@ async def test_create_bulk_items_id_mismatch( # assert item.collection == coll.id +async def test_db_setup_works_with_env_vars(api_client, database, monkeypatch): + """Test that the application starts successfully if the POSTGRES_* environment variables are set""" + monkeypatch.setenv("POSTGRES_USER", database.user) + monkeypatch.setenv("POSTGRES_PASS", database.password) + monkeypatch.setenv("POSTGRES_HOST_READER", database.host) + monkeypatch.setenv("POSTGRES_HOST_WRITER", database.host) + monkeypatch.setenv("POSTGRES_PORT", str(database.port)) + monkeypatch.setenv("POSTGRES_DBNAME", database.dbname) + + await connect_to_db(api_client.app) + await close_db_connection(api_client.app) + + +async def test_db_setup_fails_without_env_vars(api_client): + """Test that the application fails to start if database environment variables are not set.""" + try: + await connect_to_db(api_client.app) + except ValidationError: + await close_db_connection(api_client.app) + pytest.raises(ValidationError) + + @asynccontextmanager async def custom_get_connection( request: Request, @@ -536,12 +560,21 @@ async def custom_get_connection( class TestDbConnect: @pytest.fixture - async def app(self, api_client): + async def app(self, api_client, database): """ app fixture override to setup app with a customized db connection getter """ + postgres_settings = PostgresSettings( + postgres_user=database.user, + postgres_pass=database.password, + postgres_host_reader=database.host, + postgres_host_writer=database.host, + postgres_port=database.port, + postgres_dbname=database.dbname, + ) + logger.debug("Customizing app setup") - await connect_to_db(api_client.app, custom_get_connection) + await connect_to_db(api_client.app, custom_get_connection, postgres_settings) yield api_client.app await close_db_connection(api_client.app) diff --git a/tests/conftest.py b/tests/conftest.py index ec411699..7944e8de 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -41,7 +41,7 @@ from stac_fastapi.extensions.third_party import BulkTransactionExtension from stac_pydantic import Collection, Item -from stac_fastapi.pgstac.config import Settings +from stac_fastapi.pgstac.config import PostgresSettings, Settings from stac_fastapi.pgstac.core import CoreCrudClient from stac_fastapi.pgstac.db import close_db_connection, connect_to_db from stac_fastapi.pgstac.extensions import QueryExtension @@ -111,18 +111,12 @@ async def pgstac(database): ], scope="session", ) -def api_client(request, database): +def api_client(request): hydrate, prefix, response_model = request.param api_settings = Settings( - postgres_user=database.user, - postgres_pass=database.password, - postgres_host_reader=database.host, - postgres_host_writer=database.host, - postgres_port=database.port, - postgres_dbname=database.dbname, - use_api_hydrate=hydrate, enable_response_models=response_model, testing=True, + use_api_hydrate=hydrate, ) api_settings.openapi_url = prefix + api_settings.openapi_url @@ -203,11 +197,19 @@ def api_client(request, database): @pytest.fixture(scope="function") -async def app(api_client): +async def app(api_client, database): + postgres_settings = PostgresSettings( + postgres_user=database.user, + postgres_pass=database.password, + postgres_host_reader=database.host, + postgres_host_writer=database.host, + postgres_port=database.port, + postgres_dbname=database.dbname, + ) logger.info("Creating app Fixture") time.time() app = api_client.app - await connect_to_db(app) + await connect_to_db(app, postgres_settings=postgres_settings) yield app @@ -290,14 +292,8 @@ async def load_test2_item(app_client, load_test_data, load_test2_collection): @pytest.fixture( scope="session", ) -def api_client_no_ext(database): +def api_client_no_ext(): api_settings = Settings( - postgres_user=database.user, - postgres_pass=database.password, - postgres_host_reader=database.host, - postgres_host_writer=database.host, - postgres_port=database.port, - postgres_dbname=database.dbname, testing=True, ) return StacApi( @@ -310,11 +306,19 @@ def api_client_no_ext(database): @pytest.fixture(scope="function") -async def app_no_ext(api_client_no_ext): +async def app_no_ext(api_client_no_ext, database): + postgres_settings = PostgresSettings( + postgres_user=database.user, + postgres_pass=database.password, + postgres_host_reader=database.host, + postgres_host_writer=database.host, + postgres_port=database.port, + postgres_dbname=database.dbname, + ) logger.info("Creating app Fixture") time.time() app = api_client_no_ext.app - await connect_to_db(app) + await connect_to_db(app, postgres_settings=postgres_settings) yield app