Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## [Unreleased]

### Changed

- move Postgres settings into separate `PostgresSettings` class and defer loading until connecting to database ([#209](https://github.com/stac-utils/stac-fastapi-pgstac/pull/209))

## [4.0.3] - 2025-03-10

### Fixed
Expand Down
44 changes: 22 additions & 22 deletions stac_fastapi/pgstac/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from urllib.parse import quote_plus as quote

from pydantic import BaseModel, field_validator
from pydantic_settings import SettingsConfigDict
from pydantic_settings import BaseSettings, SettingsConfigDict
from stac_fastapi.types.config import ApiSettings

from stac_fastapi.pgstac.types.base_item_cache import (
Expand Down Expand Up @@ -43,7 +43,7 @@ class ServerSettings(BaseModel):
model_config = SettingsConfigDict(extra="allow")


class Settings(ApiSettings):
class PostgresSettings(BaseSettings):
"""Postgres-specific API settings.

Attributes:
Expand Down Expand Up @@ -71,9 +71,28 @@ class Settings(ApiSettings):

server_settings: ServerSettings = ServerSettings()

model_config = {"env_file": ".env", "extra": "ignore"}

@property
def reader_connection_string(self):
"""Create reader psql connection string."""
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_reader}:{self.postgres_port}/{self.postgres_dbname}"

@property
def writer_connection_string(self):
"""Create writer psql connection string."""
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/{self.postgres_dbname}"

@property
def testing_connection_string(self):
"""Create testing psql connection string."""
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/pgstactestdb"


class Settings(ApiSettings):
use_api_hydrate: bool = False
base_item_cache: Type[BaseItemCache] = DefaultBaseItemCache
invalid_id_chars: List[str] = DEFAULT_INVALID_ID_CHARS
base_item_cache: Type[BaseItemCache] = DefaultBaseItemCache

cors_origins: str = "*"
cors_methods: str = "GET,POST,OPTIONS"
Expand All @@ -89,22 +108,3 @@ def parse_cors_origin(cls, v):
def parse_cors_methods(cls, v):
"""Parse CORS methods."""
return [method.strip() for method in v.split(",")]

@property
def reader_connection_string(self):
"""Create reader psql connection string."""
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_reader}:{self.postgres_port}/{self.postgres_dbname}"

@property
def writer_connection_string(self):
"""Create writer psql connection string."""
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/{self.postgres_dbname}"

@property
def testing_connection_string(self):
"""Create testing psql connection string."""
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/pgstactestdb"

model_config = SettingsConfigDict(
**{**ApiSettings.model_config, **{"env_nested_delimiter": "__"}}
)
24 changes: 16 additions & 8 deletions stac_fastapi/pgstac/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
NotFoundError,
)

from stac_fastapi.pgstac.config import PostgresSettings


async def con_init(conn):
"""Use orjson for json returns."""
Expand All @@ -46,19 +48,25 @@ async def con_init(conn):


async def connect_to_db(
app: FastAPI, get_conn: Optional[ConnectionGetter] = None
app: FastAPI,
get_conn: Optional[ConnectionGetter] = None,
postgres_settings: Optional[PostgresSettings] = None,
) -> None:
"""Create connection pools & connection retriever on application."""
settings = app.state.settings
if app.state.settings.testing:
readpool = writepool = settings.testing_connection_string
app_settings = app.state.settings

if not postgres_settings:
postgres_settings = PostgresSettings()

if app_settings.testing:
readpool = writepool = postgres_settings.testing_connection_string
else:
readpool = settings.reader_connection_string
writepool = settings.writer_connection_string
readpool = postgres_settings.reader_connection_string
writepool = postgres_settings.writer_connection_string

db = DB()
app.state.readpool = await db.create_pool(readpool, settings)
app.state.writepool = await db.create_pool(writepool, settings)
app.state.readpool = await db.create_pool(readpool, postgres_settings)
app.state.writepool = await db.create_pool(writepool, postgres_settings)
app.state.get_connection = get_conn if get_conn else get_connection


Expand Down
18 changes: 12 additions & 6 deletions tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from stac_fastapi.extensions.core.fields import FieldsConformanceClasses
from stac_fastapi.types import stac as stac_types

from stac_fastapi.pgstac.config import PostgresSettings
from stac_fastapi.pgstac.core import CoreCrudClient, Settings
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
from stac_fastapi.pgstac.transactions import TransactionsClient
Expand Down Expand Up @@ -720,13 +721,16 @@ async def get_collection(
return await super().get_collection(collection_id, request=request, **kwargs)

settings = Settings(
testing=True,
)

postgres_settings = PostgresSettings(
postgres_user=database.user,
postgres_pass=database.password,
postgres_host_reader=database.host,
postgres_host_writer=database.host,
postgres_port=database.port,
postgres_dbname=database.dbname,
testing=True,
)

extensions = [
Expand All @@ -751,7 +755,7 @@ async def get_collection(
collections_get_request_model=collection_search_extension.GET,
)
app = api.app
await connect_to_db(app)
await connect_to_db(app, postgres_settings=postgres_settings)
try:
async with AsyncClient(transport=ASGITransport(app=app)) as client:
response = await client.post(
Expand Down Expand Up @@ -786,15 +790,17 @@ async def test_no_extension(
loader.load_items(os.path.join(DATA_DIR, "test_item.json"))

settings = Settings(
testing=True,
use_api_hydrate=hydrate,
enable_response_models=validation,
)
postgres_settings = PostgresSettings(
postgres_user=database.user,
postgres_pass=database.password,
postgres_host_reader=database.host,
postgres_host_writer=database.host,
postgres_port=database.port,
postgres_dbname=database.dbname,
testing=True,
use_api_hydrate=hydrate,
enable_response_models=validation,
)
extensions = []
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
Expand All @@ -805,7 +811,7 @@ async def test_no_extension(
search_post_request_model=post_request_model,
)
app = api.app
await connect_to_db(app)
await connect_to_db(app, postgres_settings=postgres_settings)
try:
async with AsyncClient(transport=ASGITransport(app=app)) as client:
landing = await client.get("http://test/")
Expand Down
37 changes: 35 additions & 2 deletions tests/clients/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

import pytest
from fastapi import Request
from pydantic import ValidationError
from stac_pydantic import Collection, Item

from stac_fastapi.pgstac.config import PostgresSettings
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db, get_connection

# from tests.conftest import MockStarletteRequest
Expand Down Expand Up @@ -523,6 +525,28 @@ async def test_create_bulk_items_id_mismatch(
# assert item.collection == coll.id


async def test_db_setup_works_with_env_vars(api_client, database, monkeypatch):
"""Test that the application starts successfully if the POSTGRES_* environment variables are set"""
monkeypatch.setenv("POSTGRES_USER", database.user)
monkeypatch.setenv("POSTGRES_PASS", database.password)
monkeypatch.setenv("POSTGRES_HOST_READER", database.host)
monkeypatch.setenv("POSTGRES_HOST_WRITER", database.host)
monkeypatch.setenv("POSTGRES_PORT", str(database.port))
monkeypatch.setenv("POSTGRES_DBNAME", database.dbname)

await connect_to_db(api_client.app)
await close_db_connection(api_client.app)


async def test_db_setup_fails_without_env_vars(api_client):
"""Test that the application fails to start if database environment variables are not set."""
try:
await connect_to_db(api_client.app)
except ValidationError:
await close_db_connection(api_client.app)
pytest.raises(ValidationError)


@asynccontextmanager
async def custom_get_connection(
request: Request,
Expand All @@ -536,12 +560,21 @@ async def custom_get_connection(

class TestDbConnect:
@pytest.fixture
async def app(self, api_client):
async def app(self, api_client, database):
"""
app fixture override to setup app with a customized db connection getter
"""
postgres_settings = PostgresSettings(
postgres_user=database.user,
postgres_pass=database.password,
postgres_host_reader=database.host,
postgres_host_writer=database.host,
postgres_port=database.port,
postgres_dbname=database.dbname,
)

logger.debug("Customizing app setup")
await connect_to_db(api_client.app, custom_get_connection)
await connect_to_db(api_client.app, custom_get_connection, postgres_settings)
yield api_client.app
await close_db_connection(api_client.app)

Expand Down
44 changes: 24 additions & 20 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from stac_fastapi.extensions.third_party import BulkTransactionExtension
from stac_pydantic import Collection, Item

from stac_fastapi.pgstac.config import Settings
from stac_fastapi.pgstac.config import PostgresSettings, Settings
from stac_fastapi.pgstac.core import CoreCrudClient
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
from stac_fastapi.pgstac.extensions import QueryExtension
Expand Down Expand Up @@ -111,18 +111,12 @@ async def pgstac(database):
],
scope="session",
)
def api_client(request, database):
def api_client(request):
hydrate, prefix, response_model = request.param
api_settings = Settings(
postgres_user=database.user,
postgres_pass=database.password,
postgres_host_reader=database.host,
postgres_host_writer=database.host,
postgres_port=database.port,
postgres_dbname=database.dbname,
use_api_hydrate=hydrate,
enable_response_models=response_model,
testing=True,
use_api_hydrate=hydrate,
)

api_settings.openapi_url = prefix + api_settings.openapi_url
Expand Down Expand Up @@ -203,11 +197,19 @@ def api_client(request, database):


@pytest.fixture(scope="function")
async def app(api_client):
async def app(api_client, database):
postgres_settings = PostgresSettings(
postgres_user=database.user,
postgres_pass=database.password,
postgres_host_reader=database.host,
postgres_host_writer=database.host,
postgres_port=database.port,
postgres_dbname=database.dbname,
)
logger.info("Creating app Fixture")
time.time()
app = api_client.app
await connect_to_db(app)
await connect_to_db(app, postgres_settings=postgres_settings)

yield app

Expand Down Expand Up @@ -290,14 +292,8 @@ async def load_test2_item(app_client, load_test_data, load_test2_collection):
@pytest.fixture(
scope="session",
)
def api_client_no_ext(database):
def api_client_no_ext():
api_settings = Settings(
postgres_user=database.user,
postgres_pass=database.password,
postgres_host_reader=database.host,
postgres_host_writer=database.host,
postgres_port=database.port,
postgres_dbname=database.dbname,
testing=True,
)
return StacApi(
Expand All @@ -310,11 +306,19 @@ def api_client_no_ext(database):


@pytest.fixture(scope="function")
async def app_no_ext(api_client_no_ext):
async def app_no_ext(api_client_no_ext, database):
postgres_settings = PostgresSettings(
postgres_user=database.user,
postgres_pass=database.password,
postgres_host_reader=database.host,
postgres_host_writer=database.host,
postgres_port=database.port,
postgres_dbname=database.dbname,
)
logger.info("Creating app Fixture")
time.time()
app = api_client_no_ext.app
await connect_to_db(app)
await connect_to_db(app, postgres_settings=postgres_settings)

yield app

Expand Down
Loading