11import asyncio
2+ from io import TextIOWrapper
23import os
3- from typing import List
4+ import time
5+ from typing import Any , AsyncGenerator , Dict , Iterable , List , Optional , TypeVar , Union
46
57import asyncpg
8+ from asyncpg .connection import Connection
69import typer
710import orjson
811from smart_open import open
2124migrations_dir = os .path .join (dirname , "migrations" )
2225
2326
24- def pglogger (conn , message ) :
27+ def pglogger (message : str ) -> None :
2528 logging .debug (message )
2629
2730
28- async def con_init (conn ) :
31+ async def con_init (conn : Connection ) -> None :
2932 """Use orjson for json returns."""
3033 await conn .set_type_codec (
3134 "json" ,
@@ -42,33 +45,36 @@ async def con_init(conn):
4245
4346
4447class DB :
45- pg_connection_string = None
46- connection = None
48+ pg_connection_string : Optional [ str ] = None
49+ connection : Optional [ Connection ] = None
4750
48- def __init__ (self , pg_connection_string : str = None ):
51+ def __init__ (self , pg_connection_string : Optional [ str ] = None ) -> None :
4952 self .pg_connection_string = pg_connection_string
5053
51- async def create_connection (self ):
52- self . connection = await asyncpg .connect (
54+ async def create_connection (self ) -> Connection :
55+ connection : Connection = await asyncpg .connect (
5356 self .pg_connection_string ,
5457 server_settings = {
5558 "search_path" : "pgstac,public" ,
5659 "application_name" : "pypgstac" ,
5760 },
5861 )
59- await con_init (self .connection )
62+ await con_init (connection )
63+ self .connection = connection
6064 return self .connection
6165
62- async def __aenter__ (self ):
66+ async def __aenter__ (self ) -> Connection :
6367 if self .connection is None :
6468 await self .create_connection ()
69+ assert self .connection is not None
6570 return self .connection
6671
67- async def __aexit__ (self , exc_type , exc_val , exc_tb ):
68- await self .connection .close ()
72+ async def __aexit__ (self , exc_type : Any , exc_val : Any , exc_tb : Any ) -> None :
73+ if self .connection :
74+ await self .connection .close ()
6975
7076
71- async def run_migration (dsn : str = None ):
77+ async def run_migration (dsn : Optional [ str ] = None ) -> str :
7278 conn = await asyncpg .connect (dsn = dsn )
7379 async with conn .transaction ():
7480 try :
@@ -102,27 +108,33 @@ async def run_migration(dsn: str = None):
102108 f"Pypgstac does not have a migration from { oldversion } to { version } ({ migration_file } )"
103109 )
104110
105- with open (migration_file ) as f :
106- migration_sql = f .read ()
107- logging .debug (migration_sql )
108- async with conn .transaction ():
109- conn .add_log_listener (pglogger )
110- await conn .execute (migration_sql )
111- await conn .execute (
112- """
113- INSERT INTO pgstac.migrations (version)
114- VALUES ($1);
115- """ ,
116- version ,
117- )
118-
119- await conn .close ()
111+ open_migration_file = open (migration_file )
112+ if isinstance (open_migration_file , TextIOWrapper ):
113+ with open_migration_file as f :
114+ migration_sql = f .read ()
115+ logging .debug (migration_sql )
116+ async with conn .transaction ():
117+ conn .add_log_listener (pglogger )
118+ await conn .execute (migration_sql )
119+ await conn .execute (
120+ """
121+ INSERT INTO pgstac.migrations (version)
122+ VALUES ($1);
123+ """ ,
124+ version ,
125+ )
126+
127+ await conn .close ()
128+ else :
129+ raise IOError (f"Unable to open { migration_file } " )
120130 return version
121131
122132
123133@app .command ()
124- def migrate (dsn : str = None ):
125- typer .echo (asyncio .run (run_migration (dsn )))
134+ def migrate (dsn : Optional [str ] = None ) -> None :
135+ """Migrate a pgstac database"""
136+ version = asyncio .run (run_migration (dsn ))
137+ typer .echo (f'pgstac version { version } ' )
126138
127139
128140class loadopt (str , Enum ):
@@ -135,35 +147,42 @@ class tables(str, Enum):
135147 items = "items"
136148 collections = "collections"
137149
150+ # Types of iterable that load_iterator can support
151+ T = TypeVar ('T' , Iterable [bytes ], Iterable [Dict [str , Any ]], Iterable [str ])
152+
138153
139- async def aiter (list : List ):
140- for i in list :
141- if isinstance (i , bytes ):
142- i = i .decode ("utf-8" )
143- elif isinstance (i , dict ):
144- i = orjson .dumps (i ).decode ("utf-8" )
145- if isinstance (i , str ):
146- line = "\n " .join (
147- [
148- i .rstrip ()
149- .replace (r"\n" , r"\\n" )
150- .replace (r"\t" , r"\\t" )
151- ]
152- ).encode ("utf-8" )
153- yield line
154+ async def aiter (list : T ) -> AsyncGenerator [bytes , None ]:
155+ for item in list :
156+ item_str : str
157+ if isinstance (item , bytes ):
158+ item_str = item .decode ("utf-8" )
159+ elif isinstance (item , dict ):
160+ item_str = orjson .dumps (item ).decode ("utf-8" )
161+ elif isinstance (item , str ):
162+ item_str = item
154163 else :
155- raise Exception (f"Could not parse { i } " )
164+ raise ValueError (f"Cannot load iterator with values of type { type (item )} (value { item } )" )
165+
156166
167+ line = "\n " .join (
168+ [
169+ item_str .rstrip ()
170+ .replace (r"\n" , r"\\n" )
171+ .replace (r"\t" , r"\\t" )
172+ ]
173+ ).encode ("utf-8" )
174+ yield line
157175
158- async def copy (iter , table : tables , conn : asyncpg .Connection ):
176+
177+ async def copy (iter : T , table : tables , conn : asyncpg .Connection ) -> None :
159178 logger .debug (f"copying to { table } directly" )
160179 logger .debug (f"iter: { iter } " )
161- iter = aiter (iter )
180+ bytes_iter = aiter (iter )
162181 async with conn .transaction ():
163182 logger .debug ("Copying data" )
164183 await conn .copy_to_table (
165184 table ,
166- source = iter ,
185+ source = bytes_iter ,
167186 columns = ["content" ],
168187 format = "csv" ,
169188 quote = chr (27 ),
@@ -179,10 +198,10 @@ async def copy(iter, table: tables, conn: asyncpg.Connection):
179198
180199
181200async def copy_ignore_duplicates (
182- iter , table : tables , conn : asyncpg .Connection
183- ):
201+ iter : T , table : tables , conn : asyncpg .Connection
202+ ) -> None :
184203 logger .debug (f"inserting to { table } ignoring duplicates" )
185- iter = aiter (iter )
204+ bytes_iter = aiter (iter )
186205 async with conn .transaction ():
187206 await conn .execute (
188207 """
@@ -192,7 +211,7 @@ async def copy_ignore_duplicates(
192211 )
193212 await conn .copy_to_table (
194213 "pgstactemp" ,
195- source = iter ,
214+ source = bytes_iter ,
196215 columns = ["content" ],
197216 format = "csv" ,
198217 quote = chr (27 ),
@@ -218,9 +237,9 @@ async def copy_ignore_duplicates(
218237 logger .debug ("Data Inserted" )
219238
220239
221- async def copy_upsert (iter , table : tables , conn : asyncpg .Connection ):
240+ async def copy_upsert (iter : T , table : tables , conn : asyncpg .Connection ) -> None :
222241 logger .debug (f"upserting to { table } " )
223- iter = aiter (iter )
242+ bytes_iter = aiter (iter )
224243 async with conn .transaction ():
225244 await conn .execute (
226245 """
@@ -230,7 +249,7 @@ async def copy_upsert(iter, table: tables, conn: asyncpg.Connection):
230249 )
231250 await conn .copy_to_table (
232251 "pgstactemp" ,
233- source = iter ,
252+ source = bytes_iter ,
234253 columns = ["content" ],
235254 format = "csv" ,
236255 quote = chr (27 ),
@@ -258,24 +277,28 @@ async def copy_upsert(iter, table: tables, conn: asyncpg.Connection):
258277
259278
260279async def load_iterator (
261- iter , table : tables , conn : asyncpg .Connection , method : loadopt = " insert"
280+ iter : T , table : tables , conn : asyncpg .Connection , method : loadopt = loadopt . insert
262281):
263282 logger .debug (f"Load Iterator Connection: { conn } " )
264- if method == " insert" :
283+ if method == loadopt . insert :
265284 await copy (iter , table , conn )
266- elif method == " insert_ignore" :
285+ elif method == loadopt . insert_ignore :
267286 await copy_ignore_duplicates (iter , table , conn )
268287 else :
269288 await copy_upsert (iter , table , conn )
270289
271290
272291async def load_ndjson (
273- file : str , table : tables , method : loadopt = " insert" , dsn : str = None
274- ):
292+ file : str , table : tables , method : loadopt = loadopt . insert , dsn : str = None
293+ ) -> None :
275294 print (f"loading { file } into { table } using { method } " )
276- with open (file , "rb" ) as f :
277- async with DB (dsn ) as conn :
278- await load_iterator (f , table , conn , method )
295+ open_file = open (file , "rb" )
296+ if isinstance (open_file , TextIOWrapper ):
297+ with open_file as f :
298+ async with DB (dsn ) as conn :
299+ await load_iterator (f , table , conn , method )
300+ else :
301+ raise IOError (f"Cannot read { file } " )
279302
280303
281304@app .command ()
@@ -286,13 +309,38 @@ def load(
286309 method : loadopt = typer .Option (
287310 "insert" , prompt = "How to deal conflicting ids"
288311 ),
289- ):
312+ ) -> None :
313+ "Load STAC data into a pgstac database."
290314 typer .echo (
291315 asyncio .run (
292316 load_ndjson (file = file , table = table , dsn = dsn , method = method )
293317 )
294318 )
295319
320+ @app .command ()
321+ def pgready (dsn : Optional [str ] = None ) -> None :
322+ """Wait for a pgstac database to accept connections"""
323+ async def wait_on_connection () -> bool :
324+ cnt = 0
325+
326+ print ("Waiting for pgstac to come online..." , end = "" , flush = True )
327+ while True :
328+ if cnt > 150 :
329+ raise Exception ("Unable to connect to database" )
330+ try :
331+ print ("." , end = "" , flush = True )
332+ conn = await asyncpg .connect ()
333+ await conn .execute ("SELECT 1" )
334+ await conn .close ()
335+ print ("success!" )
336+ return True
337+ except Exception :
338+ time .sleep (0.1 )
339+ cnt += 1
340+
341+
342+ asyncio .run (wait_on_connection ())
343+
296344
297345if __name__ == "__main__" :
298346 app ()
0 commit comments