2121
2222import logging
2323
24+ from zenml .cli import secret
25+
2426# Configure logging levels for specific modules
2527logging .getLogger ("pytorch" ).setLevel (logging .CRITICAL )
2628logging .getLogger ("sentence-transformers" ).setLevel (logging .CRITICAL )
@@ -212,48 +214,76 @@ def split_documents(
212214 return chunked_documents
213215
214216
215- def get_local_db_connection_details ( ) -> Dict [ str , str ] :
216- """Returns the connection details for the local database.
217+ def get_db_password ( secret_name : str ) -> str :
218+ """Returns the password for the PostgreSQL database.
217219
218220 Returns:
219- dict: A dictionary containing the connection details for the local
220- database.
221+ str: The password for the PostgreSQL database.
222+ """
223+ password = os .getenv ("ZENML_POSTGRES_DB_PASSWORD" )
224+ if not password :
225+ from zenml .client import Client
221226
222- Raises:
223- RuntimeError: If the environment variables ZENML_POSTGRES_USER, ZENML_POSTGRES_HOST, or ZENML_POSTGRES_PORT are not set.
227+ password = (
228+ Client ()
229+ .get_secret (secret_name )
230+ .secret_values ["password" ]
231+ )
232+ return password
233+
234+
235+ def get_db_user (secret_name : str ) -> str :
236+ """Returns the user for the PostgreSQL database.
237+
238+ Returns:
239+ str: The user for the PostgreSQL database.
224240 """
225241 user = os .getenv ("ZENML_POSTGRES_USER" )
226- host = os . getenv ( "ZENML_POSTGRES_HOST" )
227- port = os . getenv ( "ZENML_POSTGRES_PORT" )
242+ if not user :
243+ from zenml . client import Client
228244
229- if not user or not host or not port :
230- raise RuntimeError (
231- "Please make sure to set the environment variables: ZENML_POSTGRES_USER, ZENML_POSTGRES_HOST, and ZENML_POSTGRES_PORT"
245+ user = (
246+ Client ()
247+ .get_secret (secret_name )
248+ .secret_values ["user" ]
232249 )
250+ return user
233251
234- return {
235- "user" : user ,
236- "host" : host ,
237- "port" : port ,
238- }
239252
253+ def get_db_host (secret_name : str ) -> str :
254+ """Returns the host for the PostgreSQL database.
240255
241- def get_db_password () -> str :
242- """Returns the password for the PostgreSQL database.
256+ Returns:
257+ str: The host for the PostgreSQL database.
258+ """
259+ host = os .getenv ("ZENML_POSTGRES_HOST" )
260+ if not host :
261+ from zenml .client import Client
262+
263+ host = (
264+ Client ()
265+ .get_secret (secret_name )
266+ .secret_values ["host" ]
267+ )
268+ return host
269+
270+
271+ def get_db_port (secret_name : str ) -> str :
272+ """Returns the port for the PostgreSQL database.
243273
244274 Returns:
245- str: The password for the PostgreSQL database.
275+ str: The port for the PostgreSQL database.
246276 """
247- password = os .getenv ("ZENML_POSTGRES_DB_PASSWORD" )
248- if not password :
277+ port = os .getenv ("ZENML_POSTGRES_DB_PASSWORD" )
278+ if not port :
249279 from zenml .client import Client
250280
251- password = (
281+ port = (
252282 Client ()
253283 .get_secret ("supabase_postgres_db" )
254- .secret_values ["password " ]
284+ .secret_values ["port " ]
255285 )
256- return password
286+ return port
257287
258288
259289def get_db_conn () -> connection :
@@ -265,15 +295,19 @@ def get_db_conn() -> connection:
265295 Returns:
266296 connection: A psycopg2 connection object to the PostgreSQL database.
267297 """
268- pg_password = get_db_password ( )
298+ secret_name = os . getenv ( "ZENML_SUPABASE_SECRET_NAME" )
269299
270- local_database_connection = get_local_db_connection_details ()
300+ if not secret_name :
301+ raise RuntimeError (
302+ "Please make sure to set the environment variable: ZENML_SUPABASE_SECRET_NAME to point at the secret that "
303+ "contains your supabase connection details."
304+ )
271305
272306 CONNECTION_DETAILS = {
273- "user" : local_database_connection [ "user" ] ,
274- "password" : pg_password ,
275- "host" : local_database_connection [ "host" ] ,
276- "port" : local_database_connection [ "port" ] ,
307+ "user" : get_db_user ( secret_name ) ,
308+ "password" : get_db_password ( secret_name ) ,
309+ "host" : get_db_host ( secret_name ) ,
310+ "port" : get_db_port ( secret_name ) ,
277311 "dbname" : "postgres" ,
278312 }
279313
0 commit comments