
Commit c689b3f

refactor pgstacDb class
1 parent 12bcdd5 commit c689b3f

11 files changed: +548 −545 lines

src/pypgstac/examples/load_queryables_example.py

Lines changed: 11 additions & 3 deletions
@@ -16,7 +16,10 @@


 def load_for_specific_collections(
-    cli, sample_file, collection_ids, delete_missing=False,
+    cli,
+    sample_file,
+    collection_ids,
+    delete_missing=False,
 ):
     """Load queryables for specific collections.

@@ -27,7 +30,9 @@ def load_for_specific_collections(
         delete_missing: If True, delete properties not present in the file
     """
     cli.load_queryables(
-        str(sample_file), collection_ids=collection_ids, delete_missing=delete_missing,
+        str(sample_file),
+        collection_ids=collection_ids,
+        delete_missing=delete_missing,
     )


@@ -57,7 +62,10 @@ def main():
     # Example of loading for specific collections with delete_missing=True
     # This will delete properties not present in the file, but only for the specified collections
     load_for_specific_collections(
-        cli, sample_file, ["landsat-8", "sentinel-2"], delete_missing=True,
+        cli,
+        sample_file,
+        ["landsat-8", "sentinel-2"],
+        delete_missing=True,
     )

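For reference, the call this example now makes, written out as a minimal sketch; `cli` is assumed to be the pypgstac CLI object the example builds elsewhere and `sample_file` a path to a queryables JSON file:

    cli.load_queryables(
        str(sample_file),
        collection_ids=["landsat-8", "sentinel-2"],  # only these collections are updated
        delete_missing=True,  # drop properties missing from the file, scoped to those collections
    )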
src/pypgstac/pyproject.toml

Lines changed: 1 addition & 2 deletions
@@ -99,10 +99,9 @@ select = [
     "PLE",
     # "PLR",
     "PLW",
-    "COM", # flake8-commas
 ]
 ignore = [
-    # "E501", # line too long, handled by black
+    "E501", # line too long, handled by black
     "B008", # do not perform function calls in argument defaults
     "C901", # too complex
     "B905",
Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 """pyPgSTAC Version."""
+
 from pypgstac.version import __version__

 __all__ = ["__version__"]

src/pypgstac/src/pypgstac/db.py

Lines changed: 77 additions & 111 deletions
@@ -1,7 +1,9 @@
 """Base library for database interaction with PgSTAC."""
-import atexit
+
+import contextlib
 import logging
 import time
+from dataclasses import dataclass, field
 from types import TracebackType
 from typing import Any, Generator, List, Optional, Tuple, Type, Union

@@ -52,37 +54,24 @@ class Settings(BaseSettings):
 settings = Settings()


+@dataclass
 class PgstacDB:
     """Base class for interacting with PgSTAC Database."""

-    def __init__(
-        self,
-        dsn: Optional[str] = "",
-        pool: Optional[ConnectionPool] = None,
-        connection: Optional[Connection] = None,
-        commit_on_exit: bool = True,
-        debug: bool = False,
-        use_queue: bool = False,
-    ) -> None:
-        """Initialize Database."""
-        self.dsn: str
-        if dsn is not None:
-            self.dsn = dsn
-        else:
-            self.dsn = ""
-        self.pool = pool
-        self.connection = connection
-        self.commit_on_exit = commit_on_exit
-        self.initial_version = "0.1.9"
-        self.debug = debug
-        self.use_queue = use_queue
-        if self.debug:
-            logging.basicConfig(level=logging.DEBUG)
+    dsn: str
+    commit_on_exit: bool = True
+    debug: bool = False
+    use_queue: bool = False

-    def get_pool(self) -> ConnectionPool:
-        """Get Database Pool."""
-        if self.pool is None:
-            self.pool = ConnectionPool(
+    pool: ConnectionPool = field(default=None)
+
+    initial_version: str = field(init=False, default="0.1.9")
+
+    _pool: ConnectionPool = field(init=False)
+
+    def __post_init__(self):
+        if not self.pool:
+            self._pool = ConnectionPool(
                 conninfo=self.dsn,
                 min_size=settings.db_min_conn_size,
                 max_size=settings.db_max_conn_size,
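With PgstacDB declared as a dataclass, the generated constructor takes the fields above directly; a minimal construction sketch (the DSN is a placeholder):

    # dsn is now a required field; a ConnectionPool can still be injected via pool,
    # otherwise __post_init__ builds one from the Settings values.
    db = PgstacDB(dsn="postgresql://user:password@localhost:5432/pgstac", debug=True)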
@@ -91,36 +80,49 @@ def get_pool(self) -> ConnectionPool:
                 num_workers=settings.db_num_workers,
                 open=True,
             )
-        return self.pool

-    def open(self) -> None:
-        """Open database pool connection."""
-        self.get_pool()
+    def get_pool(self) -> ConnectionPool:
+        """Get Database Pool."""
+        return self.pool or self._pool

     def close(self) -> None:
         """Close database pool connection."""
-        if self.pool is not None:
-            self.pool.close()
+        if self._pool is not None:
+            self._pool.close()
+
+    def __enter__(self) -> Any:
+        """Enter used for context."""
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
+        """Exit used for context."""
+        self.close()

+    @contextlib.contextmanager
     def connect(self) -> Connection:
         """Return database connection."""
         pool = self.get_pool()
-        if self.connection is None:
-            self.connection = pool.getconn()
-            self.connection.autocommit = True
+        try:
+            conn = pool.getconn()
+            conn.autocommit = True
             if self.debug:
-                self.connection.add_notice_handler(pg_notice_handler)
-                self.connection.execute(
+                conn.add_notice_handler(pg_notice_handler)
+                conn.execute(
                     "SET CLIENT_MIN_MESSAGES TO NOTICE;",
                     prepare=False,
                 )
             if self.use_queue:
-                self.connection.execute(
+                conn.execute(
                     "SET pgstac.use_queue TO TRUE;",
                     prepare=False,
                 )
-            atexit.register(self.disconnect)
-            self.connection.execute(
+
+            conn.execute(
                 """
                 SELECT
                   CASE
@@ -138,54 +140,24 @@ def connect(self) -> Connection:
                 """,
                 prepare=False,
             )
-        return self.connection
+            with conn:
+                yield conn
+
+        finally:
+            pool.putconn(conn)

     def wait(self) -> None:
         """Block until database connection is ready."""
         cnt: int = 0
         while cnt < 60:
             try:
-                self.connect()
                 self.query("SELECT 1;")
                 return None
             except psycopg.errors.OperationalError:
                 time.sleep(1)
                 cnt += 1
         raise psycopg.errors.CannotConnectNow

-    def disconnect(self) -> None:
-        """Disconnect from database."""
-        try:
-            if self.connection is not None:
-                if self.commit_on_exit:
-                    self.connection.commit()
-                else:
-                    self.connection.rollback()
-        except Exception:
-            pass
-        try:
-            if self.pool is not None and self.connection is not None:
-                self.pool.putconn(self.connection)
-        except Exception:
-            pass
-
-        self.connection = None
-        self.pool = None
-
-    def __enter__(self) -> Any:
-        """Enter used for context."""
-        self.connect()
-        return self
-
-    def __exit__(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc: Optional[BaseException],
-        traceback: Optional[TracebackType],
-    ) -> None:
-        """Exit used for context."""
-        self.disconnect()
-
     @retry(
         stop=stop_after_attempt(settings.db_retries),
         retry=retry_if_exception_type(psycopg.errors.OperationalError),
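Because connect() is now a contextlib.contextmanager generator rather than a method caching self.connection, callers check a connection out per block; a hedged usage sketch (the DSN is a placeholder):

    with PgstacDB(dsn="postgresql://user:password@localhost:5432/pgstac") as db:
        # the connection is taken from the pool and handed back by the
        # finally: pool.putconn(conn) branch when the block exits
        with db.connect() as conn:
            conn.execute("SELECT 1;")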
@@ -198,30 +170,27 @@ def query(
         row_factory: psycopg.rows.BaseRowFactory = psycopg.rows.tuple_row,
     ) -> Generator:
         """Query the database with parameters."""
-        conn = self.connect()
-        try:
-            with conn.cursor(row_factory=row_factory) as cursor:
-                if args is None:
-                    rows = cursor.execute(query, prepare=False)
-                else:
-                    rows = cursor.execute(query, args)
-                if rows:
-                    for row in rows:
-                        yield row
-                else:
-                    yield None
-        except psycopg.errors.OperationalError as e:
-            # If we get an operational error check the pool and retry
-            logger.warning(f"OPERATIONAL ERROR: {e}")
-            if self.pool is None:
-                self.get_pool()
-            else:
-                self.pool.check()
-            raise e
-        except psycopg.errors.DatabaseError as e:
-            if conn is not None:
-                conn.rollback()
-            raise e
+        with self.connect() as conn:
+            try:
+                with conn.cursor(row_factory=row_factory) as cursor:
+                    if args is None:
+                        rows = cursor.execute(query, prepare=False)
+                    else:
+                        rows = cursor.execute(query, args)
+                    if rows:
+                        for row in rows:
+                            yield row
+                    else:
+                        yield None
+            except psycopg.errors.OperationalError as e:
+                # If we get an operational error check the pool and retry
+                logger.warning(f"OPERATIONAL ERROR: {e}")
+                self._pool.check()
+                raise e
+            except psycopg.errors.DatabaseError as e:
+                if conn is not None:
+                    conn.rollback()
+                raise e

     def query_one(self, *args: Any, **kwargs: Any) -> Union[Tuple, str, None]:
         """Return results from a query that returns a single row."""
@@ -238,10 +207,9 @@ def query_one(self, *args: Any, **kwargs: Any) -> Union[Tuple, str, None]:

     def run_queued(self) -> str:
         try:
-            self.connect().execute("""
-                CALL run_queued_queries();
-            """)
-            return "Ran Queued Queries"
+            with self.connect() as conn:
+                conn.execute("CALL run_queued_queries();")
+                return "Ran Queued Queries"
         except Exception as e:
             return f"Error Running Queued Queries: {e}"

@@ -262,8 +230,6 @@ def version(self) -> Optional[str]:
                 return version
         except psycopg.errors.UndefinedTable:
             logger.debug("PgSTAC is not installed.")
-            if self.connection is not None:
-                self.connection.rollback()
         return None

     @property
@@ -280,13 +246,13 @@ def pg_version(self) -> str:
         if isinstance(version, str):
             if int(version) < 130000:
                 major, minor, patch = tuple(
-                    map(int, [version[i:i + 2] for i in range(0, len(version), 2)]),
+                    map(int, [version[i : i + 2] for i in range(0, len(version), 2)]),
                 )
-                raise Exception(f"PgSTAC requires PostgreSQL 13+, current version is: {major}.{minor}.{patch}")  # noqa: E501
+                raise Exception(
+                    f"PgSTAC requires PostgreSQL 13+, current version is: {major}.{minor}.{patch}",
+                )  # noqa: E501
             return version
         else:
-            if self.connection is not None:
-                self.connection.rollback()
             raise Exception("Could not find PG version.")

     def func(self, function_name: str, *args: Any) -> Generator:
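Taken together, a short sketch of the refactored query path (DSN is a placeholder): each query() call now borrows a pooled connection for the lifetime of the generator instead of reusing one long-lived self.connection, and close() shuts the pool down:

    db = PgstacDB(dsn="postgresql://user:password@localhost:5432/pgstac")
    try:
        for row in db.query("SELECT 1;"):  # yields tuples via the default row factory
            print(row)
    finally:
        db.close()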
