3434
3535def lenient_connection (settings : Settings , retries = 5 ):
3636 try :
37- return psycopg2 .connect (
38- password = settings .pg_password ,
39- host = settings .pg_host ,
40- port = settings .pg_port ,
41- user = settings .pg_user ,
42- )
37+ return psycopg2 .connect (password = settings .pg_password , dsn = settings .pg_dsn ,)
4338 except psycopg2 .Error as e :
4439 if retries <= 0 :
4540 raise
@@ -73,14 +68,14 @@ def populate_db(engine):
7368"""
7469
7570
76- def prepare_database (delete_existing : Union [bool , callable ]) -> bool :
71+ def prepare_database (delete_existing : Union [bool , callable ], settings : Settings = None ) -> bool :
7772 """
7873 (Re)create a fresh database and run migrations.
7974
8075 :param delete_existing: whether or not to drop an existing database if it exists
8176 :return: whether or not a database as (re)created
8277 """
83- settings = Settings ()
78+ settings = settings or Settings ()
8479
8580 with psycopg2_cursor (settings ) as cur :
8681 cur .execute ('SELECT EXISTS (SELECT datname FROM pg_catalog.pg_database WHERE datname=%s)' , (settings .pg_name ,))
@@ -96,13 +91,12 @@ def prepare_database(delete_existing: Union[bool, callable]) -> bool:
9691 else :
9792 print (f'dropping existing connections to "{ settings .pg_name } "...' )
9893 cur .execute (DROP_CONNECTIONS , (settings .pg_name ,))
99- print (f'dropping database "{ settings .pg_name } " as it already exists...' )
100- cur .execute (f'DROP DATABASE { settings .pg_name } ' )
101- else :
102- print (f'database "{ settings .pg_name } " does not yet exist' )
10394
104- print (f'creating database "{ settings .pg_name } "...' )
105- cur .execute (f'CREATE DATABASE { settings .pg_name } ' )
95+ logger .debug ('dropping and re-creating the schema...' )
96+ cur .execute ('drop schema public cascade;\n create schema public;' )
97+ else :
98+ print (f'database "{ settings .pg_name } " does not yet exist, creating' )
99+ cur .execute (f'CREATE DATABASE { settings .pg_name } ' )
106100
107101 engine = create_engine (settings .pg_dsn )
108102 print ('creating tables from model definition...' )
@@ -122,9 +116,11 @@ def patch(func):
122116
123117def run_patch (live , patch_name ):
124118 if patch_name is None :
125- print ('available patches:\n {}' .format (
126- '\n ' .join (' {}: {}' .format (p .__name__ , p .__doc__ .strip ('\n ' )) for p in patches )
127- ))
119+ print (
120+ 'available patches:\n {}' .format (
121+ '\n ' .join (' {}: {}' .format (p .__name__ , p .__doc__ .strip ('\n ' )) for p in patches )
122+ )
123+ )
128124 return
129125 patch_lookup = {p .__name__ : p for p in patches }
130126 try :
@@ -168,8 +164,11 @@ def print_tables(conn):
168164 'float8' : 'FLOAT' ,
169165 }
170166 for table_name , * _ in result :
171- r = conn .execute ("SELECT column_name, udt_name, character_maximum_length, is_nullable, column_default "
172- "FROM information_schema.columns WHERE table_name=%s" , table_name )
167+ r = conn .execute (
168+ "SELECT column_name, udt_name, character_maximum_length, is_nullable, column_default "
169+ "FROM information_schema.columns WHERE table_name=%s" ,
170+ table_name ,
171+ )
173172 fields = []
174173 for name , col_type , max_chars , nullable , dft in r :
175174 col_type = type_lookup .get (col_type , col_type .upper ())
@@ -204,11 +203,13 @@ def add_labels(conn):
204203 add labels field to contractors
205204 """
206205 conn .execute ('ALTER TABLE contractors ADD labels VARCHAR(255)[]' )
207- conn .execute ("""
206+ conn .execute (
207+ """
208208 CREATE INDEX ix_contractors_labels
209209 ON contractors
210210 USING btree (labels);
211- """ )
211+ """
212+ )
212213
213214
214215@patch
@@ -220,11 +221,9 @@ def add_domains_options(conn):
220221 conn .execute ('ALTER TABLE companies ADD options JSONB' )
221222 updated = 0
222223 for id , domain in conn .execute ('SELECT id, domain FROM companies WHERE domain IS NOT NULL' ):
223- conn .execute ((
224- update (sa_companies )
225- .values ({'domains' : [domain , 'www.' + domain ]})
226- .where (sa_companies .c .id == id )
227- ))
224+ conn .execute (
225+ (update (sa_companies ).values ({'domains' : [domain , 'www.' + domain ]}).where (sa_companies .c .id == id ))
226+ )
228227 updated += 1
229228 print (f'domains updated for { updated } companies' )
230229 conn .execute ('ALTER TABLE companies DROP COLUMN domain' )
0 commit comments