From 812d3ab55d216ce305efbf2cdbf995de75be617e Mon Sep 17 00:00:00 2001 From: hrodmn Date: Tue, 4 Mar 2025 12:24:17 -0600 Subject: [PATCH 1/4] make postgres settings values optional (but validate later) --- stac_fastapi/pgstac/config.py | 37 ++++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/stac_fastapi/pgstac/config.py b/stac_fastapi/pgstac/config.py index 7b1e65d9..4ec38a72 100644 --- a/stac_fastapi/pgstac/config.py +++ b/stac_fastapi/pgstac/config.py @@ -1,6 +1,6 @@ """Postgres API configuration.""" -from typing import List, Type +from typing import List, Optional, Type from urllib.parse import quote_plus as quote from pydantic import BaseModel, field_validator @@ -57,12 +57,12 @@ class Settings(ApiSettings): invalid_id_chars: list of characters that are not allowed in item or collection ids. """ - postgres_user: str - postgres_pass: str - postgres_host_reader: str - postgres_host_writer: str - postgres_port: int - postgres_dbname: str + postgres_user: Optional[str] = None + postgres_pass: Optional[str] = None + postgres_host_reader: Optional[str] = None + postgres_host_writer: Optional[str] = None + postgres_port: Optional[int] = None + postgres_dbname: Optional[str] = None db_min_conn_size: int = 10 db_max_conn_size: int = 10 @@ -93,18 +93,41 @@ def parse_cors_methods(cls, v): @property def reader_connection_string(self): """Create reader psql connection string.""" + self._validate_postgres_settings() return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_reader}:{self.postgres_port}/{self.postgres_dbname}" @property def writer_connection_string(self): """Create writer psql connection string.""" + self._validate_postgres_settings() return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/{self.postgres_dbname}" @property def testing_connection_string(self): """Create testing psql connection string.""" + self._validate_postgres_settings() return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/pgstactestdb" + def _validate_postgres_settings(self) -> None: + """Validate that required PostgreSQL settings are configured.""" + required_settings = [ + "postgres_host_writer", + "postgres_host_reader", + "postgres_user", + "postgres_pass", + "postgres_port", + "postgres_dbname", + ] + + missing = [ + setting for setting in required_settings if getattr(self, setting) is None + ] + + if missing: + raise ValueError( + f"Missing required PostgreSQL settings: {', '.join(missing)}", + ) + model_config = SettingsConfigDict( **{**ApiSettings.model_config, **{"env_nested_delimiter": "__"}} ) From a6e121abb02ec35f454fec6e9c305ecbd31d9618 Mon Sep 17 00:00:00 2001 From: hrodmn Date: Wed, 5 Mar 2025 11:51:17 -0600 Subject: [PATCH 2/4] split Postgres settings into separate PostgresSettings class --- stac_fastapi/pgstac/config.py | 81 ++++++++++++---------------------- stac_fastapi/pgstac/db.py | 24 ++++++---- tests/api/test_api.py | 18 +++++--- tests/clients/test_postgres.py | 33 +++++++++++++- tests/conftest.py | 22 ++++----- 5 files changed, 100 insertions(+), 78 deletions(-) diff --git a/stac_fastapi/pgstac/config.py b/stac_fastapi/pgstac/config.py index 4ec38a72..cbc4676a 100644 --- a/stac_fastapi/pgstac/config.py +++ b/stac_fastapi/pgstac/config.py @@ -1,10 +1,10 @@ """Postgres API configuration.""" -from typing import List, Optional, Type +from typing import List, Type from urllib.parse import quote_plus as quote from pydantic import BaseModel, field_validator -from pydantic_settings import SettingsConfigDict +from pydantic_settings import BaseSettings, SettingsConfigDict from stac_fastapi.types.config import ApiSettings from stac_fastapi.pgstac.types.base_item_cache import ( @@ -43,7 +43,7 @@ class ServerSettings(BaseModel): model_config = SettingsConfigDict(extra="allow") -class Settings(ApiSettings): +class PostgresSettings(BaseSettings): """Postgres-specific API settings. Attributes: @@ -57,12 +57,12 @@ class Settings(ApiSettings): invalid_id_chars: list of characters that are not allowed in item or collection ids. """ - postgres_user: Optional[str] = None - postgres_pass: Optional[str] = None - postgres_host_reader: Optional[str] = None - postgres_host_writer: Optional[str] = None - postgres_port: Optional[int] = None - postgres_dbname: Optional[str] = None + postgres_user: str + postgres_pass: str + postgres_host_reader: str + postgres_host_writer: str + postgres_port: int + postgres_dbname: str db_min_conn_size: int = 10 db_max_conn_size: int = 10 @@ -71,9 +71,28 @@ class Settings(ApiSettings): server_settings: ServerSettings = ServerSettings() + model_config = {"env_file": ".env", "extra": "ignore"} + + @property + def reader_connection_string(self): + """Create reader psql connection string.""" + return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_reader}:{self.postgres_port}/{self.postgres_dbname}" + + @property + def writer_connection_string(self): + """Create writer psql connection string.""" + return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/{self.postgres_dbname}" + + @property + def testing_connection_string(self): + """Create testing psql connection string.""" + return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/pgstactestdb" + + +class Settings(ApiSettings): use_api_hydrate: bool = False - base_item_cache: Type[BaseItemCache] = DefaultBaseItemCache invalid_id_chars: List[str] = DEFAULT_INVALID_ID_CHARS + base_item_cache: Type[BaseItemCache] = DefaultBaseItemCache cors_origins: str = "*" cors_methods: str = "GET,POST,OPTIONS" @@ -89,45 +108,3 @@ def parse_cors_origin(cls, v): def parse_cors_methods(cls, v): """Parse CORS methods.""" return [method.strip() for method in v.split(",")] - - @property - def reader_connection_string(self): - """Create reader psql connection string.""" - self._validate_postgres_settings() - return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_reader}:{self.postgres_port}/{self.postgres_dbname}" - - @property - def writer_connection_string(self): - """Create writer psql connection string.""" - self._validate_postgres_settings() - return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/{self.postgres_dbname}" - - @property - def testing_connection_string(self): - """Create testing psql connection string.""" - self._validate_postgres_settings() - return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/pgstactestdb" - - def _validate_postgres_settings(self) -> None: - """Validate that required PostgreSQL settings are configured.""" - required_settings = [ - "postgres_host_writer", - "postgres_host_reader", - "postgres_user", - "postgres_pass", - "postgres_port", - "postgres_dbname", - ] - - missing = [ - setting for setting in required_settings if getattr(self, setting) is None - ] - - if missing: - raise ValueError( - f"Missing required PostgreSQL settings: {', '.join(missing)}", - ) - - model_config = SettingsConfigDict( - **{**ApiSettings.model_config, **{"env_nested_delimiter": "__"}} - ) diff --git a/stac_fastapi/pgstac/db.py b/stac_fastapi/pgstac/db.py index 121d44c0..e0e1420a 100644 --- a/stac_fastapi/pgstac/db.py +++ b/stac_fastapi/pgstac/db.py @@ -25,6 +25,8 @@ NotFoundError, ) +from stac_fastapi.pgstac.config import PostgresSettings + async def con_init(conn): """Use orjson for json returns.""" @@ -46,19 +48,25 @@ async def con_init(conn): async def connect_to_db( - app: FastAPI, get_conn: Optional[ConnectionGetter] = None + app: FastAPI, + get_conn: Optional[ConnectionGetter] = None, + postgres_settings: Optional[PostgresSettings] = None, ) -> None: """Create connection pools & connection retriever on application.""" - settings = app.state.settings - if app.state.settings.testing: - readpool = writepool = settings.testing_connection_string + app_settings = app.state.settings + + if not postgres_settings: + postgres_settings = PostgresSettings() + + if app_settings.testing: + readpool = writepool = postgres_settings.testing_connection_string else: - readpool = settings.reader_connection_string - writepool = settings.writer_connection_string + readpool = postgres_settings.reader_connection_string + writepool = postgres_settings.writer_connection_string db = DB() - app.state.readpool = await db.create_pool(readpool, settings) - app.state.writepool = await db.create_pool(writepool, settings) + app.state.readpool = await db.create_pool(readpool, postgres_settings) + app.state.writepool = await db.create_pool(writepool, postgres_settings) app.state.get_connection = get_conn if get_conn else get_connection diff --git a/tests/api/test_api.py b/tests/api/test_api.py index 743a88e4..1a747270 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -20,6 +20,7 @@ from stac_fastapi.extensions.core.fields import FieldsConformanceClasses from stac_fastapi.types import stac as stac_types +from stac_fastapi.pgstac.config import PostgresSettings from stac_fastapi.pgstac.core import CoreCrudClient, Settings from stac_fastapi.pgstac.db import close_db_connection, connect_to_db from stac_fastapi.pgstac.transactions import TransactionsClient @@ -720,13 +721,16 @@ async def get_collection( return await super().get_collection(collection_id, request=request, **kwargs) settings = Settings( + testing=True, + ) + + postgres_settings = PostgresSettings( postgres_user=database.user, postgres_pass=database.password, postgres_host_reader=database.host, postgres_host_writer=database.host, postgres_port=database.port, postgres_dbname=database.dbname, - testing=True, ) extensions = [ @@ -751,7 +755,7 @@ async def get_collection( collections_get_request_model=collection_search_extension.GET, ) app = api.app - await connect_to_db(app) + await connect_to_db(app, postgres_settings=postgres_settings) try: async with AsyncClient(transport=ASGITransport(app=app)) as client: response = await client.post( @@ -786,15 +790,17 @@ async def test_no_extension( loader.load_items(os.path.join(DATA_DIR, "test_item.json")) settings = Settings( + testing=True, + use_api_hydrate=hydrate, + enable_response_models=validation, + ) + postgres_settings = PostgresSettings( postgres_user=database.user, postgres_pass=database.password, postgres_host_reader=database.host, postgres_host_writer=database.host, postgres_port=database.port, postgres_dbname=database.dbname, - testing=True, - use_api_hydrate=hydrate, - enable_response_models=validation, ) extensions = [] post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) @@ -805,7 +811,7 @@ async def test_no_extension( search_post_request_model=post_request_model, ) app = api.app - await connect_to_db(app) + await connect_to_db(app, postgres_settings=postgres_settings) try: async with AsyncClient(transport=ASGITransport(app=app)) as client: landing = await client.get("http://test/") diff --git a/tests/clients/test_postgres.py b/tests/clients/test_postgres.py index 8a8d5ca5..690e9e26 100644 --- a/tests/clients/test_postgres.py +++ b/tests/clients/test_postgres.py @@ -6,8 +6,10 @@ import pytest from fastapi import Request +from pydantic import ValidationError from stac_pydantic import Collection, Item +from stac_fastapi.pgstac.config import PostgresSettings from stac_fastapi.pgstac.db import close_db_connection, connect_to_db, get_connection # from tests.conftest import MockStarletteRequest @@ -534,14 +536,41 @@ async def custom_get_connection( yield conn +async def test_db_setup_works_with_env_vars(api_client, database, monkeypatch): + """Test that the application starts successfully if the POSTGRES_* environment variables are set""" + monkeypatch.setenv("POSTGRES_USER", database.user) + monkeypatch.setenv("POSTGRES_PASS", database.password) + monkeypatch.setenv("POSTGRES_HOST_READER", database.host) + monkeypatch.setenv("POSTGRES_HOST_WRITER", database.host) + monkeypatch.setenv("POSTGRES_PORT", str(database.port)) + monkeypatch.setenv("POSTGRES_DBNAME", database.dbname) + + await connect_to_db(api_client.app) + + +async def test_db_setup_fails_without_env_vars(api_client): + """Test that the application fails to start if database environment variables are not set.""" + with pytest.raises(ValidationError): + await connect_to_db(api_client.app) + + class TestDbConnect: @pytest.fixture - async def app(self, api_client): + async def app(self, api_client, database): """ app fixture override to setup app with a customized db connection getter """ + postgres_settings = PostgresSettings( + postgres_user=database.user, + postgres_pass=database.password, + postgres_host_reader=database.host, + postgres_host_writer=database.host, + postgres_port=database.port, + postgres_dbname=database.dbname, + ) + logger.debug("Customizing app setup") - await connect_to_db(api_client.app, custom_get_connection) + await connect_to_db(api_client.app, custom_get_connection, postgres_settings) yield api_client.app await close_db_connection(api_client.app) diff --git a/tests/conftest.py b/tests/conftest.py index ce456534..ba49100f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,7 +42,7 @@ from stac_fastapi.extensions.third_party import BulkTransactionExtension from stac_pydantic import Collection, Item -from stac_fastapi.pgstac.config import Settings +from stac_fastapi.pgstac.config import PostgresSettings, Settings from stac_fastapi.pgstac.core import CoreCrudClient from stac_fastapi.pgstac.db import close_db_connection, connect_to_db from stac_fastapi.pgstac.extensions import QueryExtension @@ -120,15 +120,9 @@ async def pgstac(database): def api_client(request, database): hydrate, prefix, response_model = request.param api_settings = Settings( - postgres_user=database.user, - postgres_pass=database.password, - postgres_host_reader=database.host, - postgres_host_writer=database.host, - postgres_port=database.port, - postgres_dbname=database.dbname, - use_api_hydrate=hydrate, enable_response_models=response_model, testing=True, + use_api_hydrate=hydrate, ) api_settings.openapi_url = prefix + api_settings.openapi_url @@ -209,11 +203,19 @@ def api_client(request, database): @pytest.fixture(scope="function") -async def app(api_client): +async def app(api_client, database): + postgres_settings = PostgresSettings( + postgres_user=database.user, + postgres_pass=database.password, + postgres_host_reader=database.host, + postgres_host_writer=database.host, + postgres_port=database.port, + postgres_dbname=database.dbname, + ) logger.info("Creating app Fixture") time.time() app = api_client.app - await connect_to_db(app) + await connect_to_db(app, postgres_settings=postgres_settings) yield app From 2b9d2ec77a26f97826f140ee2138e2bde946c74f Mon Sep 17 00:00:00 2001 From: hrodmn Date: Wed, 5 Mar 2025 11:53:39 -0600 Subject: [PATCH 3/4] update changelog --- CHANGES.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGES.md b/CHANGES.md index 1f4f1568..4e84606f 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,10 @@ ## [Unreleased] +### Changed + +- move Postgres settings into separate `PostgresSettings` class and defer loading until connecting to database ([#209](https://github.com/stac-utils/stac-fastapi-pgstac/pull/209)) + ## [4.0.2] - 2025-02-18 ### Fixed From ad080fb42ca3b802e4996bfaf1288625151aaa34 Mon Sep 17 00:00:00 2001 From: hrodmn Date: Wed, 5 Mar 2025 13:22:37 -0600 Subject: [PATCH 4/4] fix a few tests --- tests/clients/test_postgres.py | 28 ++++++++++++++++------------ tests/conftest.py | 22 ++++++++++++---------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/tests/clients/test_postgres.py b/tests/clients/test_postgres.py index 690e9e26..b3501bcb 100644 --- a/tests/clients/test_postgres.py +++ b/tests/clients/test_postgres.py @@ -525,17 +525,6 @@ async def test_create_bulk_items_id_mismatch( # assert item.collection == coll.id -@asynccontextmanager -async def custom_get_connection( - request: Request, - readwrite: Literal["r", "w"], -): - """An example of customizing the connection getter""" - async with get_connection(request, readwrite) as conn: - await conn.execute("SELECT set_config('api.test', 'added-config', false)") - yield conn - - async def test_db_setup_works_with_env_vars(api_client, database, monkeypatch): """Test that the application starts successfully if the POSTGRES_* environment variables are set""" monkeypatch.setenv("POSTGRES_USER", database.user) @@ -546,12 +535,27 @@ async def test_db_setup_works_with_env_vars(api_client, database, monkeypatch): monkeypatch.setenv("POSTGRES_DBNAME", database.dbname) await connect_to_db(api_client.app) + await close_db_connection(api_client.app) async def test_db_setup_fails_without_env_vars(api_client): """Test that the application fails to start if database environment variables are not set.""" - with pytest.raises(ValidationError): + try: await connect_to_db(api_client.app) + except ValidationError: + await close_db_connection(api_client.app) + pytest.raises(ValidationError) + + +@asynccontextmanager +async def custom_get_connection( + request: Request, + readwrite: Literal["r", "w"], +): + """An example of customizing the connection getter""" + async with get_connection(request, readwrite) as conn: + await conn.execute("SELECT set_config('api.test', 'added-config', false)") + yield conn class TestDbConnect: diff --git a/tests/conftest.py b/tests/conftest.py index ba49100f..79f64e8f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -117,7 +117,7 @@ async def pgstac(database): ], scope="session", ) -def api_client(request, database): +def api_client(request): hydrate, prefix, response_model = request.param api_settings = Settings( enable_response_models=response_model, @@ -298,14 +298,8 @@ async def load_test2_item(app_client, load_test_data, load_test2_collection): @pytest.fixture( scope="session", ) -def api_client_no_ext(database): +def api_client_no_ext(): api_settings = Settings( - postgres_user=database.user, - postgres_pass=database.password, - postgres_host_reader=database.host, - postgres_host_writer=database.host, - postgres_port=database.port, - postgres_dbname=database.dbname, testing=True, ) return StacApi( @@ -318,11 +312,19 @@ def api_client_no_ext(database): @pytest.fixture(scope="function") -async def app_no_ext(api_client_no_ext): +async def app_no_ext(api_client_no_ext, database): + postgres_settings = PostgresSettings( + postgres_user=database.user, + postgres_pass=database.password, + postgres_host_reader=database.host, + postgres_host_writer=database.host, + postgres_port=database.port, + postgres_dbname=database.dbname, + ) logger.info("Creating app Fixture") time.time() app = api_client_no_ext.app - await connect_to_db(app) + await connect_to_db(app, postgres_settings=postgres_settings) yield app