1313 Union ,
1414)
1515
16- import attr
1716import orjson
18- from asyncpg import Connection , exceptions
17+ from asyncpg import Connection , Pool , exceptions
1918from buildpg import V , asyncpg , render
20- from fastapi import FastAPI , Request
19+ from fastapi import FastAPI , HTTPException , Request
2120from stac_fastapi .types .errors import (
2221 ConflictError ,
2322 DatabaseError ,
@@ -47,33 +46,46 @@ async def con_init(conn):
4746ConnectionGetter = Callable [[Request , Literal ["r" , "w" ]], AsyncIterator [Connection ]]
4847
4948
49+ async def _create_pool (settings : PostgresSettings ) -> Pool :
50+ """Create a connection pool."""
51+ return await asyncpg .create_pool (
52+ settings .connection_string ,
53+ min_size = settings .db_min_conn_size ,
54+ max_size = settings .db_max_conn_size ,
55+ max_queries = settings .db_max_queries ,
56+ max_inactive_connection_lifetime = settings .db_max_inactive_conn_lifetime ,
57+ init = con_init ,
58+ server_settings = settings .server_settings .model_dump (),
59+ )
60+
61+
5062async def connect_to_db (
5163 app : FastAPI ,
5264 get_conn : Optional [ConnectionGetter ] = None ,
5365 postgres_settings : Optional [PostgresSettings ] = None ,
66+ add_write_connection_pool : bool = False ,
67+ write_postgres_settings : Optional [PostgresSettings ] = None ,
5468) -> None :
5569 """Create connection pools & connection retriever on application."""
56- app_settings = app .state .settings
57-
5870 if not postgres_settings :
5971 postgres_settings = PostgresSettings ()
6072
61- if app_settings .testing :
62- readpool = writepool = postgres_settings .testing_connection_string
63- else :
64- readpool = postgres_settings .reader_connection_string
65- writepool = postgres_settings .writer_connection_string
73+ app .state .readpool = await _create_pool (postgres_settings )
74+
75+ if add_write_connection_pool :
76+ if not write_postgres_settings :
77+ write_postgres_settings = PostgresSettings ()
78+
79+ app .state .writepool = await _create_pool (write_postgres_settings )
6680
67- db = DB ()
68- app .state .readpool = await db .create_pool (readpool , postgres_settings )
69- app .state .writepool = await db .create_pool (writepool , postgres_settings )
7081 app .state .get_connection = get_conn if get_conn else get_connection
7182
7283
7384async def close_db_connection (app : FastAPI ) -> None :
7485 """Close connection."""
7586 await app .state .readpool .close ()
76- await app .state .writepool .close ()
87+ if pool := getattr (app .state , "writepool" , None ):
88+ await pool .close ()
7789
7890
7991@asynccontextmanager
@@ -82,7 +94,15 @@ async def get_connection(
8294 readwrite : Literal ["r" , "w" ] = "r" ,
8395) -> AsyncIterator [Connection ]:
8496 """Retrieve connection from database conection pool."""
85- pool = request .app .state .writepool if readwrite == "w" else request .app .state .readpool
97+ pool = request .app .state .readpool
98+ if readwrite == "w" :
99+ pool = getattr (request .app .state , "writepool" , None )
100+ if not pool :
101+ raise HTTPException (
102+ status_code = 500 ,
103+ detail = "Could not find connection pool for write operations" ,
104+ )
105+
86106 with translate_pgstac_errors ():
87107 async with pool .acquire () as conn :
88108 yield conn
@@ -131,25 +151,3 @@ def translate_pgstac_errors() -> Generator[None, None, None]:
131151 raise DatabaseError from e
132152 except exceptions .ForeignKeyViolationError as e :
133153 raise ForeignKeyError from e
134-
135-
136- @attr .s
137- class DB :
138- """DB class that can be used with context manager."""
139-
140- connection_string = attr .ib (default = None )
141- _pool = attr .ib (default = None )
142- _connection = attr .ib (default = None )
143-
144- async def create_pool (self , connection_string : str , settings ):
145- """Create a connection pool."""
146- pool = await asyncpg .create_pool (
147- connection_string ,
148- min_size = settings .db_min_conn_size ,
149- max_size = settings .db_max_conn_size ,
150- max_queries = settings .db_max_queries ,
151- max_inactive_connection_lifetime = settings .db_max_inactive_conn_lifetime ,
152- init = con_init ,
153- server_settings = settings .server_settings .model_dump (),
154- )
155- return pool
0 commit comments