Skip to content

Commit 4f04b70

Browse files
committed
Modify pypgstac to pass mypy test
1 parent 7269c8f commit 4f04b70

File tree

1 file changed

+113
-65
lines changed

1 file changed

+113
-65
lines changed

pypgstac/pypgstac/pypgstac.py

Lines changed: 113 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import asyncio
2+
from io import TextIOWrapper
23
import os
3-
from typing import List
4+
import time
5+
from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, TypeVar, Union
46

57
import asyncpg
8+
from asyncpg.connection import Connection
69
import typer
710
import orjson
811
from smart_open import open
@@ -21,11 +24,11 @@
2124
migrations_dir = os.path.join(dirname, "migrations")
2225

2326

24-
def pglogger(conn, message):
27+
def pglogger(message: str) -> None:
2528
logging.debug(message)
2629

2730

28-
async def con_init(conn):
31+
async def con_init(conn: Connection) -> None:
2932
"""Use orjson for json returns."""
3033
await conn.set_type_codec(
3134
"json",
@@ -42,33 +45,36 @@ async def con_init(conn):
4245

4346

4447
class DB:
45-
pg_connection_string = None
46-
connection = None
48+
pg_connection_string: Optional[str] = None
49+
connection: Optional[Connection] = None
4750

48-
def __init__(self, pg_connection_string: str = None):
51+
def __init__(self, pg_connection_string: Optional[str] = None) -> None:
4952
self.pg_connection_string = pg_connection_string
5053

51-
async def create_connection(self):
52-
self.connection = await asyncpg.connect(
54+
async def create_connection(self) -> Connection:
55+
connection: Connection = await asyncpg.connect(
5356
self.pg_connection_string,
5457
server_settings={
5558
"search_path": "pgstac,public",
5659
"application_name": "pypgstac",
5760
},
5861
)
59-
await con_init(self.connection)
62+
await con_init(connection)
63+
self.connection = connection
6064
return self.connection
6165

62-
async def __aenter__(self):
66+
async def __aenter__(self) -> Connection:
6367
if self.connection is None:
6468
await self.create_connection()
69+
assert self.connection is not None
6570
return self.connection
6671

67-
async def __aexit__(self, exc_type, exc_val, exc_tb):
68-
await self.connection.close()
72+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
73+
if self.connection:
74+
await self.connection.close()
6975

7076

71-
async def run_migration(dsn: str = None):
77+
async def run_migration(dsn: Optional[str] = None) -> str:
7278
conn = await asyncpg.connect(dsn=dsn)
7379
async with conn.transaction():
7480
try:
@@ -102,27 +108,33 @@ async def run_migration(dsn: str = None):
102108
f"Pypgstac does not have a migration from {oldversion} to {version} ({migration_file})"
103109
)
104110

105-
with open(migration_file) as f:
106-
migration_sql = f.read()
107-
logging.debug(migration_sql)
108-
async with conn.transaction():
109-
conn.add_log_listener(pglogger)
110-
await conn.execute(migration_sql)
111-
await conn.execute(
112-
"""
113-
INSERT INTO pgstac.migrations (version)
114-
VALUES ($1);
115-
""",
116-
version,
117-
)
118-
119-
await conn.close()
111+
open_migration_file = open(migration_file)
112+
if isinstance(open_migration_file, TextIOWrapper):
113+
with open_migration_file as f:
114+
migration_sql = f.read()
115+
logging.debug(migration_sql)
116+
async with conn.transaction():
117+
conn.add_log_listener(pglogger)
118+
await conn.execute(migration_sql)
119+
await conn.execute(
120+
"""
121+
INSERT INTO pgstac.migrations (version)
122+
VALUES ($1);
123+
""",
124+
version,
125+
)
126+
127+
await conn.close()
128+
else:
129+
raise IOError(f"Unable to open {migration_file}")
120130
return version
121131

122132

123133
@app.command()
124-
def migrate(dsn: str = None):
125-
typer.echo(asyncio.run(run_migration(dsn)))
134+
def migrate(dsn: Optional[str] = None) -> None:
135+
"""Migrate a pgstac database"""
136+
version = asyncio.run(run_migration(dsn))
137+
typer.echo(f'pgstac version {version}')
126138

127139

128140
class loadopt(str, Enum):
@@ -135,35 +147,42 @@ class tables(str, Enum):
135147
items = "items"
136148
collections = "collections"
137149

150+
# Types of iterable that load_iterator can support
151+
T = TypeVar('T', Iterable[bytes], Iterable[Dict[str, Any]], Iterable[str])
152+
138153

139-
async def aiter(list: List):
140-
for i in list:
141-
if isinstance(i, bytes):
142-
i = i.decode("utf-8")
143-
elif isinstance(i, dict):
144-
i = orjson.dumps(i).decode("utf-8")
145-
if isinstance(i, str):
146-
line = "\n".join(
147-
[
148-
i.rstrip()
149-
.replace(r"\n", r"\\n")
150-
.replace(r"\t", r"\\t")
151-
]
152-
).encode("utf-8")
153-
yield line
154+
async def aiter(list: T) -> AsyncGenerator[bytes, None]:
155+
for item in list:
156+
item_str: str
157+
if isinstance(item, bytes):
158+
item_str = item.decode("utf-8")
159+
elif isinstance(item, dict):
160+
item_str = orjson.dumps(item).decode("utf-8")
161+
elif isinstance(item, str):
162+
item_str = item
154163
else:
155-
raise Exception(f"Could not parse {i}")
164+
raise ValueError(f"Cannot load iterator with values of type {type(item)} (value {item})")
165+
156166

167+
line = "\n".join(
168+
[
169+
item_str.rstrip()
170+
.replace(r"\n", r"\\n")
171+
.replace(r"\t", r"\\t")
172+
]
173+
).encode("utf-8")
174+
yield line
157175

158-
async def copy(iter, table: tables, conn: asyncpg.Connection):
176+
177+
async def copy(iter: T, table: tables, conn: asyncpg.Connection) -> None:
159178
logger.debug(f"copying to {table} directly")
160179
logger.debug(f"iter: {iter}")
161-
iter = aiter(iter)
180+
bytes_iter = aiter(iter)
162181
async with conn.transaction():
163182
logger.debug("Copying data")
164183
await conn.copy_to_table(
165184
table,
166-
source=iter,
185+
source=bytes_iter,
167186
columns=["content"],
168187
format="csv",
169188
quote=chr(27),
@@ -179,10 +198,10 @@ async def copy(iter, table: tables, conn: asyncpg.Connection):
179198

180199

181200
async def copy_ignore_duplicates(
182-
iter, table: tables, conn: asyncpg.Connection
183-
):
201+
iter: T, table: tables, conn: asyncpg.Connection
202+
) -> None:
184203
logger.debug(f"inserting to {table} ignoring duplicates")
185-
iter = aiter(iter)
204+
bytes_iter = aiter(iter)
186205
async with conn.transaction():
187206
await conn.execute(
188207
"""
@@ -192,7 +211,7 @@ async def copy_ignore_duplicates(
192211
)
193212
await conn.copy_to_table(
194213
"pgstactemp",
195-
source=iter,
214+
source=bytes_iter,
196215
columns=["content"],
197216
format="csv",
198217
quote=chr(27),
@@ -218,9 +237,9 @@ async def copy_ignore_duplicates(
218237
logger.debug("Data Inserted")
219238

220239

221-
async def copy_upsert(iter, table: tables, conn: asyncpg.Connection):
240+
async def copy_upsert(iter: T, table: tables, conn: asyncpg.Connection) -> None:
222241
logger.debug(f"upserting to {table}")
223-
iter = aiter(iter)
242+
bytes_iter = aiter(iter)
224243
async with conn.transaction():
225244
await conn.execute(
226245
"""
@@ -230,7 +249,7 @@ async def copy_upsert(iter, table: tables, conn: asyncpg.Connection):
230249
)
231250
await conn.copy_to_table(
232251
"pgstactemp",
233-
source=iter,
252+
source=bytes_iter,
234253
columns=["content"],
235254
format="csv",
236255
quote=chr(27),
@@ -258,24 +277,28 @@ async def copy_upsert(iter, table: tables, conn: asyncpg.Connection):
258277

259278

260279
async def load_iterator(
261-
iter, table: tables, conn: asyncpg.Connection, method: loadopt = "insert"
280+
iter: T, table: tables, conn: asyncpg.Connection, method: loadopt = loadopt.insert
262281
):
263282
logger.debug(f"Load Iterator Connection: {conn}")
264-
if method == "insert":
283+
if method == loadopt.insert:
265284
await copy(iter, table, conn)
266-
elif method == "insert_ignore":
285+
elif method == loadopt.insert_ignore:
267286
await copy_ignore_duplicates(iter, table, conn)
268287
else:
269288
await copy_upsert(iter, table, conn)
270289

271290

272291
async def load_ndjson(
273-
file: str, table: tables, method: loadopt = "insert", dsn: str = None
274-
):
292+
file: str, table: tables, method: loadopt = loadopt.insert, dsn: str = None
293+
) -> None:
275294
print(f"loading {file} into {table} using {method}")
276-
with open(file, "rb") as f:
277-
async with DB(dsn) as conn:
278-
await load_iterator(f, table, conn, method)
295+
open_file = open(file, "rb")
296+
if isinstance(open_file, TextIOWrapper):
297+
with open_file as f:
298+
async with DB(dsn) as conn:
299+
await load_iterator(f, table, conn, method)
300+
else:
301+
raise IOError(f"Cannot read {file}")
279302

280303

281304
@app.command()
@@ -286,13 +309,38 @@ def load(
286309
method: loadopt = typer.Option(
287310
"insert", prompt="How to deal conflicting ids"
288311
),
289-
):
312+
) -> None:
313+
"Load STAC data into a pgstac database."
290314
typer.echo(
291315
asyncio.run(
292316
load_ndjson(file=file, table=table, dsn=dsn, method=method)
293317
)
294318
)
295319

320+
@app.command()
321+
def pgready(dsn: Optional[str] = None) -> None:
322+
"""Wait for a pgstac database to accept connections"""
323+
async def wait_on_connection() -> bool:
324+
cnt = 0
325+
326+
print("Waiting for pgstac to come online...", end="", flush=True)
327+
while True:
328+
if cnt > 150:
329+
raise Exception("Unable to connect to database")
330+
try:
331+
print(".", end="", flush=True)
332+
conn = await asyncpg.connect()
333+
await conn.execute("SELECT 1")
334+
await conn.close()
335+
print("success!")
336+
return True
337+
except Exception:
338+
time.sleep(0.1)
339+
cnt += 1
340+
341+
342+
asyncio.run(wait_on_connection())
343+
296344

297345
if __name__ == "__main__":
298346
app()

0 commit comments

Comments
 (0)