11"""Base library for database interaction with PgSTAC."""
2- import atexit
2+
3+ import contextlib
34import logging
45import time
6+ from dataclasses import dataclass , field
57from types import TracebackType
68from typing import Any , Generator , List , Optional , Tuple , Type , Union
79
@@ -52,37 +54,24 @@ class Settings(BaseSettings):
5254settings = Settings ()
5355
5456
57+ @dataclass
5558class PgstacDB :
5659 """Base class for interacting with PgSTAC Database."""
5760
58- def __init__ (
59- self ,
60- dsn : Optional [str ] = "" ,
61- pool : Optional [ConnectionPool ] = None ,
62- connection : Optional [Connection ] = None ,
63- commit_on_exit : bool = True ,
64- debug : bool = False ,
65- use_queue : bool = False ,
66- ) -> None :
67- """Initialize Database."""
68- self .dsn : str
69- if dsn is not None :
70- self .dsn = dsn
71- else :
72- self .dsn = ""
73- self .pool = pool
74- self .connection = connection
75- self .commit_on_exit = commit_on_exit
76- self .initial_version = "0.1.9"
77- self .debug = debug
78- self .use_queue = use_queue
79- if self .debug :
80- logging .basicConfig (level = logging .DEBUG )
61+ dsn : str
62+ commit_on_exit : bool = True
63+ debug : bool = False
64+ use_queue : bool = False
8165
82- def get_pool (self ) -> ConnectionPool :
83- """Get Database Pool."""
84- if self .pool is None :
85- self .pool = ConnectionPool (
66+ pool : ConnectionPool = field (default = None )
67+
68+ initial_version : str = field (init = False , default = "0.1.9" )
69+
70+ _pool : ConnectionPool = field (init = False )
71+
72+ def __post_init__ (self ):
73+ if not self .pool :
74+ self ._pool = ConnectionPool (
8675 conninfo = self .dsn ,
8776 min_size = settings .db_min_conn_size ,
8877 max_size = settings .db_max_conn_size ,
@@ -91,36 +80,49 @@ def get_pool(self) -> ConnectionPool:
9180 num_workers = settings .db_num_workers ,
9281 open = True ,
9382 )
94- return self .pool
9583
96- def open (self ) -> None :
97- """Open database pool connection ."""
98- self .get_pool ()
84+ def get_pool (self ) -> ConnectionPool :
85+ """Get Database Pool ."""
86+ return self .pool or self . _pool
9987
10088 def close (self ) -> None :
10189 """Close database pool connection."""
102- if self .pool is not None :
103- self .pool .close ()
90+ if self ._pool is not None :
91+ self ._pool .close ()
92+
93+ def __enter__ (self ) -> Any :
94+ """Enter used for context."""
95+ return self
96+
97+ def __exit__ (
98+ self ,
99+ exc_type : Optional [Type [BaseException ]],
100+ exc : Optional [BaseException ],
101+ traceback : Optional [TracebackType ],
102+ ) -> None :
103+ """Exit used for context."""
104+ self .close ()
104105
106+ @contextlib .contextmanager
105107 def connect (self ) -> Connection :
106108 """Return database connection."""
107109 pool = self .get_pool ()
108- if self . connection is None :
109- self . connection = pool .getconn ()
110- self . connection .autocommit = True
110+ try :
111+ conn = pool .getconn ()
112+ conn .autocommit = True
111113 if self .debug :
112- self . connection .add_notice_handler (pg_notice_handler )
113- self . connection .execute (
114+ conn .add_notice_handler (pg_notice_handler )
115+ conn .execute (
114116 "SET CLIENT_MIN_MESSAGES TO NOTICE;" ,
115117 prepare = False ,
116118 )
117119 if self .use_queue :
118- self . connection .execute (
120+ conn .execute (
119121 "SET pgstac.use_queue TO TRUE;" ,
120122 prepare = False ,
121123 )
122- atexit . register ( self . disconnect )
123- self . connection .execute (
124+
125+ conn .execute (
124126 """
125127 SELECT
126128 CASE
@@ -138,54 +140,24 @@ def connect(self) -> Connection:
138140 """ ,
139141 prepare = False ,
140142 )
141- return self .connection
143+ with conn :
144+ yield conn
145+
146+ finally :
147+ pool .putconn (conn )
142148
143149 def wait (self ) -> None :
144150 """Block until database connection is ready."""
145151 cnt : int = 0
146152 while cnt < 60 :
147153 try :
148- self .connect ()
149154 self .query ("SELECT 1;" )
150155 return None
151156 except psycopg .errors .OperationalError :
152157 time .sleep (1 )
153158 cnt += 1
154159 raise psycopg .errors .CannotConnectNow
155160
156- def disconnect (self ) -> None :
157- """Disconnect from database."""
158- try :
159- if self .connection is not None :
160- if self .commit_on_exit :
161- self .connection .commit ()
162- else :
163- self .connection .rollback ()
164- except Exception :
165- pass
166- try :
167- if self .pool is not None and self .connection is not None :
168- self .pool .putconn (self .connection )
169- except Exception :
170- pass
171-
172- self .connection = None
173- self .pool = None
174-
175- def __enter__ (self ) -> Any :
176- """Enter used for context."""
177- self .connect ()
178- return self
179-
180- def __exit__ (
181- self ,
182- exc_type : Optional [Type [BaseException ]],
183- exc : Optional [BaseException ],
184- traceback : Optional [TracebackType ],
185- ) -> None :
186- """Exit used for context."""
187- self .disconnect ()
188-
189161 @retry (
190162 stop = stop_after_attempt (settings .db_retries ),
191163 retry = retry_if_exception_type (psycopg .errors .OperationalError ),
@@ -198,30 +170,27 @@ def query(
198170 row_factory : psycopg .rows .BaseRowFactory = psycopg .rows .tuple_row ,
199171 ) -> Generator :
200172 """Query the database with parameters."""
201- conn = self .connect ()
202- try :
203- with conn .cursor (row_factory = row_factory ) as cursor :
204- if args is None :
205- rows = cursor .execute (query , prepare = False )
206- else :
207- rows = cursor .execute (query , args )
208- if rows :
209- for row in rows :
210- yield row
211- else :
212- yield None
213- except psycopg .errors .OperationalError as e :
214- # If we get an operational error check the pool and retry
215- logger .warning (f"OPERATIONAL ERROR: { e } " )
216- if self .pool is None :
217- self .get_pool ()
218- else :
219- self .pool .check ()
220- raise e
221- except psycopg .errors .DatabaseError as e :
222- if conn is not None :
223- conn .rollback ()
224- raise e
173+ with self .connect () as conn :
174+ try :
175+ with conn .cursor (row_factory = row_factory ) as cursor :
176+ if args is None :
177+ rows = cursor .execute (query , prepare = False )
178+ else :
179+ rows = cursor .execute (query , args )
180+ if rows :
181+ for row in rows :
182+ yield row
183+ else :
184+ yield None
185+ except psycopg .errors .OperationalError as e :
186+ # If we get an operational error check the pool and retry
187+ logger .warning (f"OPERATIONAL ERROR: { e } " )
188+ self ._pool .check ()
189+ raise e
190+ except psycopg .errors .DatabaseError as e :
191+ if conn is not None :
192+ conn .rollback ()
193+ raise e
225194
226195 def query_one (self , * args : Any , ** kwargs : Any ) -> Union [Tuple , str , None ]:
227196 """Return results from a query that returns a single row."""
@@ -238,10 +207,9 @@ def query_one(self, *args: Any, **kwargs: Any) -> Union[Tuple, str, None]:
238207
239208 def run_queued (self ) -> str :
240209 try :
241- self .connect ().execute ("""
242- CALL run_queued_queries();
243- """ )
244- return "Ran Queued Queries"
210+ with self .connect () as conn :
211+ conn .execute ("CALL run_queued_queries();" )
212+ return "Ran Queued Queries"
245213 except Exception as e :
246214 return f"Error Running Queued Queries: { e } "
247215
@@ -262,8 +230,6 @@ def version(self) -> Optional[str]:
262230 return version
263231 except psycopg .errors .UndefinedTable :
264232 logger .debug ("PgSTAC is not installed." )
265- if self .connection is not None :
266- self .connection .rollback ()
267233 return None
268234
269235 @property
@@ -280,13 +246,13 @@ def pg_version(self) -> str:
280246 if isinstance (version , str ):
281247 if int (version ) < 130000 :
282248 major , minor , patch = tuple (
283- map (int , [version [i : i + 2 ] for i in range (0 , len (version ), 2 )]),
249+ map (int , [version [i : i + 2 ] for i in range (0 , len (version ), 2 )]),
284250 )
285- raise Exception (f"PgSTAC requires PostgreSQL 13+, current version is: { major } .{ minor } .{ patch } " ) # noqa: E501
251+ raise Exception (
252+ f"PgSTAC requires PostgreSQL 13+, current version is: { major } .{ minor } .{ patch } " ,
253+ ) # noqa: E501
286254 return version
287255 else :
288- if self .connection is not None :
289- self .connection .rollback ()
290256 raise Exception ("Could not find PG version." )
291257
292258 def func (self , function_name : str , * args : Any ) -> Generator :
0 commit comments