Skip to content

Commit a6e121a

Browse files
committed
split Postgres settings into separate PostgresSettings class
1 parent 812d3ab commit a6e121a

File tree

5 files changed

+100
-78
lines changed

5 files changed

+100
-78
lines changed

stac_fastapi/pgstac/config.py

Lines changed: 29 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""Postgres API configuration."""
22

3-
from typing import List, Optional, Type
3+
from typing import List, Type
44
from urllib.parse import quote_plus as quote
55

66
from pydantic import BaseModel, field_validator
7-
from pydantic_settings import SettingsConfigDict
7+
from pydantic_settings import BaseSettings, SettingsConfigDict
88
from stac_fastapi.types.config import ApiSettings
99

1010
from stac_fastapi.pgstac.types.base_item_cache import (
@@ -43,7 +43,7 @@ class ServerSettings(BaseModel):
4343
model_config = SettingsConfigDict(extra="allow")
4444

4545

46-
class Settings(ApiSettings):
46+
class PostgresSettings(BaseSettings):
4747
"""Postgres-specific API settings.
4848
4949
Attributes:
@@ -57,12 +57,12 @@ class Settings(ApiSettings):
5757
invalid_id_chars: list of characters that are not allowed in item or collection ids.
5858
"""
5959

60-
postgres_user: Optional[str] = None
61-
postgres_pass: Optional[str] = None
62-
postgres_host_reader: Optional[str] = None
63-
postgres_host_writer: Optional[str] = None
64-
postgres_port: Optional[int] = None
65-
postgres_dbname: Optional[str] = None
60+
postgres_user: str
61+
postgres_pass: str
62+
postgres_host_reader: str
63+
postgres_host_writer: str
64+
postgres_port: int
65+
postgres_dbname: str
6666

6767
db_min_conn_size: int = 10
6868
db_max_conn_size: int = 10
@@ -71,9 +71,28 @@ class Settings(ApiSettings):
7171

7272
server_settings: ServerSettings = ServerSettings()
7373

74+
model_config = {"env_file": ".env", "extra": "ignore"}
75+
76+
@property
77+
def reader_connection_string(self):
78+
"""Create reader psql connection string."""
79+
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_reader}:{self.postgres_port}/{self.postgres_dbname}"
80+
81+
@property
82+
def writer_connection_string(self):
83+
"""Create writer psql connection string."""
84+
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/{self.postgres_dbname}"
85+
86+
@property
87+
def testing_connection_string(self):
88+
"""Create testing psql connection string."""
89+
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/pgstactestdb"
90+
91+
92+
class Settings(ApiSettings):
7493
use_api_hydrate: bool = False
75-
base_item_cache: Type[BaseItemCache] = DefaultBaseItemCache
7694
invalid_id_chars: List[str] = DEFAULT_INVALID_ID_CHARS
95+
base_item_cache: Type[BaseItemCache] = DefaultBaseItemCache
7796

7897
cors_origins: str = "*"
7998
cors_methods: str = "GET,POST,OPTIONS"
@@ -89,45 +108,3 @@ def parse_cors_origin(cls, v):
89108
def parse_cors_methods(cls, v):
90109
"""Parse CORS methods."""
91110
return [method.strip() for method in v.split(",")]
92-
93-
@property
94-
def reader_connection_string(self):
95-
"""Create reader psql connection string."""
96-
self._validate_postgres_settings()
97-
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_reader}:{self.postgres_port}/{self.postgres_dbname}"
98-
99-
@property
100-
def writer_connection_string(self):
101-
"""Create writer psql connection string."""
102-
self._validate_postgres_settings()
103-
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/{self.postgres_dbname}"
104-
105-
@property
106-
def testing_connection_string(self):
107-
"""Create testing psql connection string."""
108-
self._validate_postgres_settings()
109-
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/pgstactestdb"
110-
111-
def _validate_postgres_settings(self) -> None:
112-
"""Validate that required PostgreSQL settings are configured."""
113-
required_settings = [
114-
"postgres_host_writer",
115-
"postgres_host_reader",
116-
"postgres_user",
117-
"postgres_pass",
118-
"postgres_port",
119-
"postgres_dbname",
120-
]
121-
122-
missing = [
123-
setting for setting in required_settings if getattr(self, setting) is None
124-
]
125-
126-
if missing:
127-
raise ValueError(
128-
f"Missing required PostgreSQL settings: {', '.join(missing)}",
129-
)
130-
131-
model_config = SettingsConfigDict(
132-
**{**ApiSettings.model_config, **{"env_nested_delimiter": "__"}}
133-
)

stac_fastapi/pgstac/db.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
NotFoundError,
2626
)
2727

28+
from stac_fastapi.pgstac.config import PostgresSettings
29+
2830

2931
async def con_init(conn):
3032
"""Use orjson for json returns."""
@@ -46,19 +48,25 @@ async def con_init(conn):
4648

4749

4850
async def connect_to_db(
49-
app: FastAPI, get_conn: Optional[ConnectionGetter] = None
51+
app: FastAPI,
52+
get_conn: Optional[ConnectionGetter] = None,
53+
postgres_settings: Optional[PostgresSettings] = None,
5054
) -> None:
5155
"""Create connection pools & connection retriever on application."""
52-
settings = app.state.settings
53-
if app.state.settings.testing:
54-
readpool = writepool = settings.testing_connection_string
56+
app_settings = app.state.settings
57+
58+
if not postgres_settings:
59+
postgres_settings = PostgresSettings()
60+
61+
if app_settings.testing:
62+
readpool = writepool = postgres_settings.testing_connection_string
5563
else:
56-
readpool = settings.reader_connection_string
57-
writepool = settings.writer_connection_string
64+
readpool = postgres_settings.reader_connection_string
65+
writepool = postgres_settings.writer_connection_string
5866

5967
db = DB()
60-
app.state.readpool = await db.create_pool(readpool, settings)
61-
app.state.writepool = await db.create_pool(writepool, settings)
68+
app.state.readpool = await db.create_pool(readpool, postgres_settings)
69+
app.state.writepool = await db.create_pool(writepool, postgres_settings)
6270
app.state.get_connection = get_conn if get_conn else get_connection
6371

6472

tests/api/test_api.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from stac_fastapi.extensions.core.fields import FieldsConformanceClasses
2121
from stac_fastapi.types import stac as stac_types
2222

23+
from stac_fastapi.pgstac.config import PostgresSettings
2324
from stac_fastapi.pgstac.core import CoreCrudClient, Settings
2425
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
2526
from stac_fastapi.pgstac.transactions import TransactionsClient
@@ -720,13 +721,16 @@ async def get_collection(
720721
return await super().get_collection(collection_id, request=request, **kwargs)
721722

722723
settings = Settings(
724+
testing=True,
725+
)
726+
727+
postgres_settings = PostgresSettings(
723728
postgres_user=database.user,
724729
postgres_pass=database.password,
725730
postgres_host_reader=database.host,
726731
postgres_host_writer=database.host,
727732
postgres_port=database.port,
728733
postgres_dbname=database.dbname,
729-
testing=True,
730734
)
731735

732736
extensions = [
@@ -751,7 +755,7 @@ async def get_collection(
751755
collections_get_request_model=collection_search_extension.GET,
752756
)
753757
app = api.app
754-
await connect_to_db(app)
758+
await connect_to_db(app, postgres_settings=postgres_settings)
755759
try:
756760
async with AsyncClient(transport=ASGITransport(app=app)) as client:
757761
response = await client.post(
@@ -786,15 +790,17 @@ async def test_no_extension(
786790
loader.load_items(os.path.join(DATA_DIR, "test_item.json"))
787791

788792
settings = Settings(
793+
testing=True,
794+
use_api_hydrate=hydrate,
795+
enable_response_models=validation,
796+
)
797+
postgres_settings = PostgresSettings(
789798
postgres_user=database.user,
790799
postgres_pass=database.password,
791800
postgres_host_reader=database.host,
792801
postgres_host_writer=database.host,
793802
postgres_port=database.port,
794803
postgres_dbname=database.dbname,
795-
testing=True,
796-
use_api_hydrate=hydrate,
797-
enable_response_models=validation,
798804
)
799805
extensions = []
800806
post_request_model = create_post_request_model(extensions, base_model=PgstacSearch)
@@ -805,7 +811,7 @@ async def test_no_extension(
805811
search_post_request_model=post_request_model,
806812
)
807813
app = api.app
808-
await connect_to_db(app)
814+
await connect_to_db(app, postgres_settings=postgres_settings)
809815
try:
810816
async with AsyncClient(transport=ASGITransport(app=app)) as client:
811817
landing = await client.get("http://test/")

tests/clients/test_postgres.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66

77
import pytest
88
from fastapi import Request
9+
from pydantic import ValidationError
910
from stac_pydantic import Collection, Item
1011

12+
from stac_fastapi.pgstac.config import PostgresSettings
1113
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db, get_connection
1214

1315
# from tests.conftest import MockStarletteRequest
@@ -534,14 +536,41 @@ async def custom_get_connection(
534536
yield conn
535537

536538

539+
async def test_db_setup_works_with_env_vars(api_client, database, monkeypatch):
540+
"""Test that the application starts successfully if the POSTGRES_* environment variables are set"""
541+
monkeypatch.setenv("POSTGRES_USER", database.user)
542+
monkeypatch.setenv("POSTGRES_PASS", database.password)
543+
monkeypatch.setenv("POSTGRES_HOST_READER", database.host)
544+
monkeypatch.setenv("POSTGRES_HOST_WRITER", database.host)
545+
monkeypatch.setenv("POSTGRES_PORT", str(database.port))
546+
monkeypatch.setenv("POSTGRES_DBNAME", database.dbname)
547+
548+
await connect_to_db(api_client.app)
549+
550+
551+
async def test_db_setup_fails_without_env_vars(api_client):
552+
"""Test that the application fails to start if database environment variables are not set."""
553+
with pytest.raises(ValidationError):
554+
await connect_to_db(api_client.app)
555+
556+
537557
class TestDbConnect:
538558
@pytest.fixture
539-
async def app(self, api_client):
559+
async def app(self, api_client, database):
540560
"""
541561
app fixture override to setup app with a customized db connection getter
542562
"""
563+
postgres_settings = PostgresSettings(
564+
postgres_user=database.user,
565+
postgres_pass=database.password,
566+
postgres_host_reader=database.host,
567+
postgres_host_writer=database.host,
568+
postgres_port=database.port,
569+
postgres_dbname=database.dbname,
570+
)
571+
543572
logger.debug("Customizing app setup")
544-
await connect_to_db(api_client.app, custom_get_connection)
573+
await connect_to_db(api_client.app, custom_get_connection, postgres_settings)
545574
yield api_client.app
546575
await close_db_connection(api_client.app)
547576

tests/conftest.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from stac_fastapi.extensions.third_party import BulkTransactionExtension
4343
from stac_pydantic import Collection, Item
4444

45-
from stac_fastapi.pgstac.config import Settings
45+
from stac_fastapi.pgstac.config import PostgresSettings, Settings
4646
from stac_fastapi.pgstac.core import CoreCrudClient
4747
from stac_fastapi.pgstac.db import close_db_connection, connect_to_db
4848
from stac_fastapi.pgstac.extensions import QueryExtension
@@ -120,15 +120,9 @@ async def pgstac(database):
120120
def api_client(request, database):
121121
hydrate, prefix, response_model = request.param
122122
api_settings = Settings(
123-
postgres_user=database.user,
124-
postgres_pass=database.password,
125-
postgres_host_reader=database.host,
126-
postgres_host_writer=database.host,
127-
postgres_port=database.port,
128-
postgres_dbname=database.dbname,
129-
use_api_hydrate=hydrate,
130123
enable_response_models=response_model,
131124
testing=True,
125+
use_api_hydrate=hydrate,
132126
)
133127

134128
api_settings.openapi_url = prefix + api_settings.openapi_url
@@ -209,11 +203,19 @@ def api_client(request, database):
209203

210204

211205
@pytest.fixture(scope="function")
212-
async def app(api_client):
206+
async def app(api_client, database):
207+
postgres_settings = PostgresSettings(
208+
postgres_user=database.user,
209+
postgres_pass=database.password,
210+
postgres_host_reader=database.host,
211+
postgres_host_writer=database.host,
212+
postgres_port=database.port,
213+
postgres_dbname=database.dbname,
214+
)
213215
logger.info("Creating app Fixture")
214216
time.time()
215217
app = api_client.app
216-
await connect_to_db(app)
218+
await connect_to_db(app, postgres_settings=postgres_settings)
217219

218220
yield app
219221

0 commit comments

Comments
 (0)