Skip to content

Commit 1514154

Browse files
committed
refactor read/write connection pool
1 parent db9d81c commit 1514154

File tree

11 files changed

+102
-86
lines changed

11 files changed

+102
-86
lines changed

.github/workflows/cicd.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,7 @@ jobs:
9393
POSTGRES_USER: username
9494
POSTGRES_PASS: password
9595
POSTGRES_DBNAME: postgis
96-
POSTGRES_HOST_READER: localhost
97-
POSTGRES_HOST_WRITER: localhost
96+
POSTGRES_HOST: localhost
9897
POSTGRES_PORT: 5432
9998
PGUSER: username
10099
PGPASSWORD: password

CHANGES.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,25 @@
44

55
### Changed
66

7+
- rename `POSTGRES_HOST_READER` to `POSTGRES_HOST` in config **breaking change**
8+
- rename `reader_connection_string` to `connection_string` in `PostgresSettings` class **breaking change**
79
- add `ENABLE_TRANSACTIONS_EXTENSIONS` env variable to enable `transaction` extensions
810
- disable transaction and bulk_transactions extensions by default **breaking change**
911
- update `stac-fastapi-*` version requirements to `>=5.2,<6.0`
1012
- add pgstac health-check in `/_mgmt/health`
1113

14+
### Added
15+
16+
- add `write_connection_pool` option in `stac_fastapi.pgstac.db.connect_to_db` function
17+
- add `write_postgres_settings` option in `stac_fastapi.pgstac.db.connect_to_db` function to set specific settings for the `writer` DB connection pool
18+
19+
### removed
20+
21+
- `stac_fastapi.pgstac.db.DB` class
22+
- `POSTGRES_HOST_WRITER` in config
23+
- `writer_connection_string` in `PostgresSettings` class
24+
- `testing_connection_string` in `PostgresSettings` class
25+
1226
## [5.0.2] - 2025-04-07
1327

1428
### Fixed

docker-compose.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ services:
1111
- POSTGRES_USER=username
1212
- POSTGRES_PASS=password
1313
- POSTGRES_DBNAME=postgis
14-
- POSTGRES_HOST_READER=database
15-
- POSTGRES_HOST_WRITER=database
14+
- POSTGRES_HOST=database
1615
- POSTGRES_PORT=5432
1716
- WEB_CONCURRENCY=10
1817
- VSI_CACHE=TRUE

stac_fastapi/pgstac/app.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,12 @@
9595

9696
application_extensions = []
9797

98-
if os.environ.get("ENABLE_TRANSACTIONS_EXTENSIONS", "").lower() in ["yes", "true", "1"]:
98+
with_transactions = os.environ.get("ENABLE_TRANSACTIONS_EXTENSIONS", "").lower() in [
99+
"yes",
100+
"true",
101+
"1",
102+
]
103+
if with_transactions:
99104
application_extensions.append(
100105
TransactionExtension(
101106
client=TransactionsClient(),
@@ -150,7 +155,7 @@
150155
@asynccontextmanager
151156
async def lifespan(app: FastAPI):
152157
"""FastAPI Lifespan."""
153-
await connect_to_db(app)
158+
await connect_to_db(app, add_write_connection_pool=with_transactions)
154159
yield
155160
await close_db_connection(app)
156161

stac_fastapi/pgstac/config.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ class PostgresSettings(BaseSettings):
5454
Attributes:
5555
postgres_user: postgres username.
5656
postgres_pass: postgres password.
57-
postgres_host_reader: hostname for the reader connection.
58-
postgres_host_writer: hostname for the writer connection.
57+
postgres_host: hostname for the connection.
5958
postgres_port: database port.
6059
postgres_dbname: database name.
6160
use_api_hydrate: perform hydration of stac items within stac-fastapi.
@@ -64,8 +63,7 @@ class PostgresSettings(BaseSettings):
6463

6564
postgres_user: str
6665
postgres_pass: str
67-
postgres_host_reader: str
68-
postgres_host_writer: str
66+
postgres_host: str
6967
postgres_port: int
7068
postgres_dbname: str
7169

@@ -79,19 +77,9 @@ class PostgresSettings(BaseSettings):
7977
model_config = {"env_file": ".env", "extra": "ignore"}
8078

8179
@property
82-
def reader_connection_string(self):
80+
def connection_string(self):
8381
"""Create reader psql connection string."""
84-
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_reader}:{self.postgres_port}/{self.postgres_dbname}"
85-
86-
@property
87-
def writer_connection_string(self):
88-
"""Create writer psql connection string."""
89-
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/{self.postgres_dbname}"
90-
91-
@property
92-
def testing_connection_string(self):
93-
"""Create testing psql connection string."""
94-
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host_writer}:{self.postgres_port}/pgstactestdb"
82+
return f"postgresql://{self.postgres_user}:{quote(self.postgres_pass)}@{self.postgres_host}:{self.postgres_port}/{self.postgres_dbname}"
9583

9684

9785
class Settings(ApiSettings):

stac_fastapi/pgstac/db.py

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,10 @@
1313
Union,
1414
)
1515

16-
import attr
1716
import orjson
18-
from asyncpg import Connection, exceptions
17+
from asyncpg import Connection, Pool, exceptions
1918
from buildpg import V, asyncpg, render
20-
from fastapi import FastAPI, Request
19+
from fastapi import FastAPI, HTTPException, Request
2120
from stac_fastapi.types.errors import (
2221
ConflictError,
2322
DatabaseError,
@@ -47,33 +46,46 @@ async def con_init(conn):
4746
ConnectionGetter = Callable[[Request, Literal["r", "w"]], AsyncIterator[Connection]]
4847

4948

49+
async def _create_pool(settings: PostgresSettings) -> Pool:
50+
"""Create a connection pool."""
51+
return await asyncpg.create_pool(
52+
settings.connection_string,
53+
min_size=settings.db_min_conn_size,
54+
max_size=settings.db_max_conn_size,
55+
max_queries=settings.db_max_queries,
56+
max_inactive_connection_lifetime=settings.db_max_inactive_conn_lifetime,
57+
init=con_init,
58+
server_settings=settings.server_settings.model_dump(),
59+
)
60+
61+
5062
async def connect_to_db(
5163
app: FastAPI,
5264
get_conn: Optional[ConnectionGetter] = None,
5365
postgres_settings: Optional[PostgresSettings] = None,
66+
add_write_connection_pool: bool = False,
67+
write_postgres_settings: Optional[PostgresSettings] = None,
5468
) -> None:
5569
"""Create connection pools & connection retriever on application."""
56-
app_settings = app.state.settings
57-
5870
if not postgres_settings:
5971
postgres_settings = PostgresSettings()
6072

61-
if app_settings.testing:
62-
readpool = writepool = postgres_settings.testing_connection_string
63-
else:
64-
readpool = postgres_settings.reader_connection_string
65-
writepool = postgres_settings.writer_connection_string
73+
app.state.readpool = await _create_pool(postgres_settings)
74+
75+
if add_write_connection_pool:
76+
if not write_postgres_settings:
77+
write_postgres_settings = PostgresSettings()
78+
79+
app.state.writepool = await _create_pool(write_postgres_settings)
6680

67-
db = DB()
68-
app.state.readpool = await db.create_pool(readpool, postgres_settings)
69-
app.state.writepool = await db.create_pool(writepool, postgres_settings)
7081
app.state.get_connection = get_conn if get_conn else get_connection
7182

7283

7384
async def close_db_connection(app: FastAPI) -> None:
7485
"""Close connection."""
7586
await app.state.readpool.close()
76-
await app.state.writepool.close()
87+
if pool := getattr(app.state, "writepool", None):
88+
await pool.close()
7789

7890

7991
@asynccontextmanager
@@ -82,7 +94,15 @@ async def get_connection(
8294
readwrite: Literal["r", "w"] = "r",
8395
) -> AsyncIterator[Connection]:
8496
"""Retrieve connection from database conection pool."""
85-
pool = request.app.state.writepool if readwrite == "w" else request.app.state.readpool
97+
pool = request.app.state.readpool
98+
if readwrite == "w":
99+
pool = getattr(request.app.state, "writepool", None)
100+
if not pool:
101+
raise HTTPException(
102+
status_code=500,
103+
detail="Could not find connection pool for write operations",
104+
)
105+
86106
with translate_pgstac_errors():
87107
async with pool.acquire() as conn:
88108
yield conn
@@ -131,25 +151,3 @@ def translate_pgstac_errors() -> Generator[None, None, None]:
131151
raise DatabaseError from e
132152
except exceptions.ForeignKeyViolationError as e:
133153
raise ForeignKeyError from e
134-
135-
136-
@attr.s
137-
class DB:
138-
"""DB class that can be used with context manager."""
139-
140-
connection_string = attr.ib(default=None)
141-
_pool = attr.ib(default=None)
142-
_connection = attr.ib(default=None)
143-
144-
async def create_pool(self, connection_string: str, settings):
145-
"""Create a connection pool."""
146-
pool = await asyncpg.create_pool(
147-
connection_string,
148-
min_size=settings.db_min_conn_size,
149-
max_size=settings.db_max_conn_size,
150-
max_queries=settings.db_max_queries,
151-
max_inactive_connection_lifetime=settings.db_max_inactive_conn_lifetime,
152-
init=con_init,
153-
server_settings=settings.server_settings.model_dump(),
154-
)
155-
return pool

stac_fastapi/pgstac/extensions/filter.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@ async def get_queryables(
2525
under OGC CQL but it is allowed by the STAC API Filter Extension
2626
https://github.com/radiantearth/stac-api-spec/tree/master/fragments/filter#queryables
2727
"""
28-
pool = request.app.state.readpool
29-
30-
async with pool.acquire() as conn:
28+
async with request.app.state.get_connection(request, "r") as conn:
3129
q, p = render(
3230
"""
3331
SELECT * FROM get_queryables(:collection::text);

tests/api/test_api.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -742,8 +742,7 @@ async def get_collection(
742742
postgres_settings = PostgresSettings(
743743
postgres_user=database.user,
744744
postgres_pass=database.password,
745-
postgres_host_reader=database.host,
746-
postgres_host_writer=database.host,
745+
postgres_host=database.host,
747746
postgres_port=database.port,
748747
postgres_dbname=database.dbname,
749748
)
@@ -770,7 +769,12 @@ async def get_collection(
770769
collections_get_request_model=collection_search_extension.GET,
771770
)
772771
app = api.app
773-
await connect_to_db(app, postgres_settings=postgres_settings)
772+
await connect_to_db(
773+
app,
774+
postgres_settings=postgres_settings,
775+
add_write_connection_pool=True,
776+
write_postgres_settings=postgres_settings,
777+
)
774778
try:
775779
async with AsyncClient(transport=ASGITransport(app=app)) as client:
776780
response = await client.post(
@@ -812,8 +816,7 @@ async def test_no_extension(
812816
postgres_settings = PostgresSettings(
813817
postgres_user=database.user,
814818
postgres_pass=database.password,
815-
postgres_host_reader=database.host,
816-
postgres_host_writer=database.host,
819+
postgres_host=database.host,
817820
postgres_port=database.port,
818821
postgres_dbname=database.dbname,
819822
)
@@ -826,7 +829,12 @@ async def test_no_extension(
826829
search_post_request_model=post_request_model,
827830
)
828831
app = api.app
829-
await connect_to_db(app, postgres_settings=postgres_settings)
832+
await connect_to_db(
833+
app,
834+
postgres_settings=postgres_settings,
835+
add_write_connection_pool=True,
836+
write_postgres_settings=postgres_settings,
837+
)
830838
try:
831839
async with AsyncClient(transport=ASGITransport(app=app)) as client:
832840
landing = await client.get("http://test/")

tests/clients/test_postgres.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,7 @@ async def test_db_setup_works_with_env_vars(api_client, database, monkeypatch):
529529
"""Test that the application starts successfully if the POSTGRES_* environment variables are set"""
530530
monkeypatch.setenv("POSTGRES_USER", database.user)
531531
monkeypatch.setenv("POSTGRES_PASS", database.password)
532-
monkeypatch.setenv("POSTGRES_HOST_READER", database.host)
533-
monkeypatch.setenv("POSTGRES_HOST_WRITER", database.host)
532+
monkeypatch.setenv("POSTGRES_HOST", database.host)
534533
monkeypatch.setenv("POSTGRES_PORT", str(database.port))
535534
monkeypatch.setenv("POSTGRES_DBNAME", database.dbname)
536535

@@ -567,8 +566,7 @@ async def app(self, api_client, database):
567566
postgres_settings = PostgresSettings(
568567
postgres_user=database.user,
569568
postgres_pass=database.password,
570-
postgres_host_reader=database.host,
571-
postgres_host_writer=database.host,
569+
postgres_host=database.host,
572570
postgres_port=database.port,
573571
postgres_dbname=database.dbname,
574572
)

tests/conftest.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -202,15 +202,19 @@ async def app(api_client, database):
202202
postgres_settings = PostgresSettings(
203203
postgres_user=database.user,
204204
postgres_pass=database.password,
205-
postgres_host_reader=database.host,
206-
postgres_host_writer=database.host,
205+
postgres_host=database.host,
207206
postgres_port=database.port,
208207
postgres_dbname=database.dbname,
209208
)
210209
logger.info("Creating app Fixture")
211210
time.time()
212211
app = api_client.app
213-
await connect_to_db(app, postgres_settings=postgres_settings)
212+
await connect_to_db(
213+
app,
214+
postgres_settings=postgres_settings,
215+
add_write_connection_pool=True,
216+
write_postgres_settings=postgres_settings,
217+
)
214218

215219
yield app
216220

@@ -306,14 +310,18 @@ async def app_no_ext(database):
306310
postgres_settings = PostgresSettings(
307311
postgres_user=database.user,
308312
postgres_pass=database.password,
309-
postgres_host_reader=database.host,
310-
postgres_host_writer=database.host,
313+
postgres_host=database.host,
311314
postgres_port=database.port,
312315
postgres_dbname=database.dbname,
313316
)
314317
logger.info("Creating app Fixture")
315318
time.time()
316-
await connect_to_db(api_client_no_ext.app, postgres_settings=postgres_settings)
319+
await connect_to_db(
320+
api_client_no_ext.app,
321+
postgres_settings=postgres_settings,
322+
add_write_connection_pool=True,
323+
write_postgres_settings=postgres_settings,
324+
)
317325
yield api_client_no_ext.app
318326
await close_db_connection(api_client_no_ext.app)
319327

@@ -343,14 +351,17 @@ async def app_no_transaction(database):
343351
postgres_settings = PostgresSettings(
344352
postgres_user=database.user,
345353
postgres_pass=database.password,
346-
postgres_host_reader=database.host,
347-
postgres_host_writer=database.host,
354+
postgres_host=database.host,
348355
postgres_port=database.port,
349356
postgres_dbname=database.dbname,
350357
)
351358
logger.info("Creating app Fixture")
352359
time.time()
353-
await connect_to_db(api.app, postgres_settings=postgres_settings)
360+
await connect_to_db(
361+
api.app,
362+
postgres_settings=postgres_settings,
363+
add_write_connection_pool=False,
364+
)
354365
yield api.app
355366
await close_db_connection(api.app)
356367

@@ -371,8 +382,7 @@ async def default_app(database, monkeypatch):
371382
"""Test default stac-fastapi-pgstac application."""
372383
monkeypatch.setenv("POSTGRES_USER", database.user)
373384
monkeypatch.setenv("POSTGRES_PASS", database.password)
374-
monkeypatch.setenv("POSTGRES_HOST_READER", database.host)
375-
monkeypatch.setenv("POSTGRES_HOST_WRITER", database.host)
385+
monkeypatch.setenv("POSTGRES_HOST", database.host)
376386
monkeypatch.setenv("POSTGRES_PORT", str(database.port))
377387
monkeypatch.setenv("POSTGRES_DBNAME", database.dbname)
378388
monkeypatch.delenv("ENABLED_EXTENSIONS", raising=False)
@@ -383,7 +393,7 @@ async def default_app(database, monkeypatch):
383393

384394
from stac_fastapi.pgstac.app import app
385395

386-
await connect_to_db(app)
396+
await connect_to_db(app, add_write_connection_pool=True)
387397
yield app
388398
await close_db_connection(app)
389399

0 commit comments

Comments
 (0)