| 
1 | 1 | import asyncio  | 
2 |  | -from io import BufferedIOBase  | 
3 | 2 | import os  | 
4 | 3 | import time  | 
5 | 4 | from typing import Any, AsyncGenerator, Dict, Iterable, Optional, TypeVar  | 
@@ -109,25 +108,24 @@ async def run_migration(dsn: Optional[str] = None) -> str:  | 
109 | 108 |             f"from {oldversion} to {version} ({migration_file})"  | 
110 | 109 |         )  | 
111 | 110 | 
 
  | 
112 |  | -    open_migration_file = open(migration_file)  | 
113 |  | -    if isinstance(open_migration_file, BufferedIOBase):  | 
114 |  | -        with open_migration_file as f:  | 
115 |  | -            migration_sql = f.read()  | 
116 |  | -            logging.debug(migration_sql)  | 
117 |  | -            async with conn.transaction():  | 
118 |  | -                conn.add_log_listener(pglogger)  | 
119 |  | -                await conn.execute(migration_sql)  | 
120 |  | -                await conn.execute(  | 
121 |  | -                    """  | 
122 |  | -                    INSERT INTO pgstac.migrations (version)  | 
123 |  | -                    VALUES ($1);  | 
124 |  | -                    """,  | 
125 |  | -                    version,  | 
126 |  | -                )  | 
127 |  | - | 
128 |  | -        await conn.close()  | 
129 |  | -    else:  | 
130 |  | -        raise IOError(f"Unable to open {migration_file}")  | 
 | 111 | +    open_migration_file: Any = open(migration_file)  | 
 | 112 | + | 
 | 113 | +    with open_migration_file as f:  | 
 | 114 | +        migration_sql = f.read()  | 
 | 115 | +        logging.debug(migration_sql)  | 
 | 116 | +        async with conn.transaction():  | 
 | 117 | +            conn.add_log_listener(pglogger)  | 
 | 118 | +            await conn.execute(migration_sql)  | 
 | 119 | +            await conn.execute(  | 
 | 120 | +                """  | 
 | 121 | +                INSERT INTO pgstac.migrations (version)  | 
 | 122 | +                VALUES ($1);  | 
 | 123 | +                """,  | 
 | 124 | +                version,  | 
 | 125 | +            )  | 
 | 126 | + | 
 | 127 | +    await conn.close()  | 
 | 128 | + | 
131 | 129 |     return version  | 
132 | 130 | 
 
  | 
133 | 131 | 
 
  | 
@@ -293,13 +291,11 @@ async def load_ndjson(  | 
293 | 291 |     file: str, table: tables, method: loadopt = loadopt.insert, dsn: str = None  | 
294 | 292 | ) -> None:  | 
295 | 293 |     print(f"loading {file} into {table} using {method}")  | 
296 |  | -    open_file = open(file, "rb")  | 
297 |  | -    if isinstance(open_file, BufferedIOBase):  | 
298 |  | -        with open_file as f:  | 
299 |  | -            async with DB(dsn) as conn:  | 
300 |  | -                await load_iterator(f, table, conn, method)  | 
301 |  | -    else:  | 
302 |  | -        raise IOError(f"Cannot read {file}")  | 
 | 294 | +    open_file: Any = open(file, "rb")  | 
 | 295 | + | 
 | 296 | +    with open_file as f:  | 
 | 297 | +        async with DB(dsn) as conn:  | 
 | 298 | +            await load_iterator(f, table, conn, method)  | 
303 | 299 | 
 
  | 
304 | 300 | 
 
  | 
305 | 301 | @app.command()  | 
 | 
0 commit comments