diff --git a/sync2jira/downstream_issue.py b/sync2jira/downstream_issue.py index ed2704ae..a6f7f67f 100644 --- a/sync2jira/downstream_issue.py +++ b/sync2jira/downstream_issue.py @@ -76,18 +76,50 @@ def validate_github_url(url): def get_snowflake_conn(): - """Get Snowflake connection - lazy initialization""" - - return snowflake.connector.connect( - account=os.getenv("SNOWFLAKE_ACCOUNT"), - user=os.getenv("SNOWFLAKE_USER"), - password=os.getenv("SNOWFLAKE_PAT"), - role=os.getenv("SNOWFLAKE_ROLE"), - warehouse=os.getenv("SNOWFLAKE_WAREHOUSE", "DEFAULT"), - database=os.getenv("SNOWFLAKE_DATABASE", "JIRA_DB"), - schema=os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC"), - paramstyle="qmark", - ) + """Get Snowflake connection - lazy initialization + + Supports two authentication methods: + 1. JWT authentication with private key file (if SNOWFLAKE_PRIVATE_KEY_FILE is set) + 2. Password authentication with PAT (if SNOWFLAKE_PAT is set) + """ + account = os.getenv("SNOWFLAKE_ACCOUNT") + user = os.getenv("SNOWFLAKE_USER") + role = os.getenv("SNOWFLAKE_ROLE") + warehouse = os.getenv("SNOWFLAKE_WAREHOUSE", "DEFAULT") + database = os.getenv("SNOWFLAKE_DATABASE", "JIRA_DB") + schema = os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC") + + # Build base connection parameters + conn_params = { + "account": account, + "user": user, + "role": role, + "warehouse": warehouse, + "database": database, + "schema": schema, + "paramstyle": "qmark", + } + + # Check for private key file (JWT authentication) + private_key_file = os.getenv("SNOWFLAKE_PRIVATE_KEY_FILE") + if private_key_file: + conn_params["authenticator"] = "SNOWFLAKE_JWT" + conn_params["private_key_file"] = private_key_file + + # Add private key file password if specified + private_key_file_pwd = os.getenv("SNOWFLAKE_PRIVATE_KEY_FILE_PWD") + if private_key_file_pwd: + conn_params["private_key_file_pwd"] = private_key_file_pwd + else: + # Fall back to password authentication + password = os.getenv("SNOWFLAKE_PAT") + if not password: + raise ValueError( + "Either SNOWFLAKE_PRIVATE_KEY_FILE or SNOWFLAKE_PAT must be set" + ) + conn_params["password"] = password + + return snowflake.connector.connect(**conn_params) def execute_snowflake_query(issue): diff --git a/tests/test_downstream_issue.py b/tests/test_downstream_issue.py index f6343a61..e6bad697 100644 --- a/tests/test_downstream_issue.py +++ b/tests/test_downstream_issue.py @@ -1,4 +1,5 @@ from datetime import datetime, timezone +import os from typing import Any, Optional import unittest import unittest.mock as mock @@ -1758,6 +1759,15 @@ def test_remove_diacritics(self): self.assertEqual(actual, expected) @mock.patch(PATH + "snowflake.connector.connect") + @mock.patch.dict( + os.environ, + { + "SNOWFLAKE_ACCOUNT": "test_account", + "SNOWFLAKE_USER": "test_user", + "SNOWFLAKE_ROLE": "test_role", + "SNOWFLAKE_PAT": "fake_password", + }, + ) def test_execute_snowflake_query_real_connection(self, mock_snowflake_connect): """Test execute_snowflake_query function.""" # Create a mock issue @@ -1770,8 +1780,111 @@ def test_execute_snowflake_query_real_connection(self, mock_snowflake_connect): ) # Assert the function was called correctly mock_snowflake_connect.assert_called_once() + # Verify password authentication is used + call_args = mock_snowflake_connect.call_args[1] + self.assertEqual(call_args["password"], os.getenv("SNOWFLAKE_PAT")) + self.assertNotIn("authenticator", call_args) + self.assertNotIn("private_key_file", call_args) + mock_cursor.fetchall.assert_called_once() + mock_cursor.close.assert_called_once() + # Assert the result + self.assertEqual(result, mock_cursor.fetchall.return_value) + + @mock.patch(PATH + "snowflake.connector.connect") + @mock.patch.dict( + os.environ, + { + "SNOWFLAKE_ACCOUNT": "test_account", + "SNOWFLAKE_USER": "test_user", + "SNOWFLAKE_ROLE": "test_role", + "SNOWFLAKE_PRIVATE_KEY_FILE": "test_key.pem", + }, + ) + @mock.patch("os.path.exists") + def test_execute_snowflake_query_with_jwt_auth( + self, mock_exists, mock_snowflake_connect + ): + """Test execute_snowflake_query with JWT authentication.""" + mock_exists.return_value = True + # Create a mock issue + mock_issue = MagicMock() + mock_issue.url = "https://github.com/test/repo/issues/1" + # Call the function + result = d.execute_snowflake_query(mock_issue) + mock_cursor = ( + mock_snowflake_connect.return_value.__enter__.return_value.cursor.return_value + ) + # Assert the function was called correctly + mock_snowflake_connect.assert_called_once() + # Verify JWT authentication is used + call_args = mock_snowflake_connect.call_args[1] + self.assertEqual(call_args["authenticator"], "SNOWFLAKE_JWT") + self.assertEqual( + call_args["private_key_file"], os.getenv("SNOWFLAKE_PRIVATE_KEY_FILE") + ) + self.assertNotIn("password", call_args) + self.assertNotIn("private_key_file_pwd", call_args) + mock_cursor.execute.assert_called_once() + mock_cursor.fetchall.assert_called_once() + mock_cursor.close.assert_called_once() + # Assert the result + self.assertEqual(result, mock_cursor.fetchall.return_value) + + @mock.patch(PATH + "snowflake.connector.connect") + @mock.patch.dict( + os.environ, + { + "SNOWFLAKE_ACCOUNT": "test_account", + "SNOWFLAKE_USER": "test_user", + "SNOWFLAKE_ROLE": "test_role", + "SNOWFLAKE_PRIVATE_KEY_FILE": "test_key.pem", + "SNOWFLAKE_PRIVATE_KEY_FILE_PWD": "key_password", + }, + ) + @mock.patch("os.path.exists") + def test_execute_snowflake_query_with_jwt_auth_and_password( + self, mock_exists, mock_snowflake_connect + ): + """Test execute_snowflake_query with JWT authentication and key password.""" + mock_exists.return_value = True + # Create a mock issue + mock_issue = MagicMock() + mock_issue.url = "https://github.com/test/repo/issues/1" + # Call the function + result = d.execute_snowflake_query(mock_issue) + mock_cursor = ( + mock_snowflake_connect.return_value.__enter__.return_value.cursor.return_value + ) + # Assert the function was called correctly + mock_snowflake_connect.assert_called_once() + # Verify JWT authentication with password is used + call_args = mock_snowflake_connect.call_args[1] + self.assertEqual(call_args["authenticator"], "SNOWFLAKE_JWT") + self.assertEqual( + call_args["private_key_file"], os.getenv("SNOWFLAKE_PRIVATE_KEY_FILE") + ) + self.assertEqual( + call_args["private_key_file_pwd"], + os.getenv("SNOWFLAKE_PRIVATE_KEY_FILE_PWD"), + ) + self.assertNotIn("password", call_args) mock_cursor.execute.assert_called_once() mock_cursor.fetchall.assert_called_once() mock_cursor.close.assert_called_once() # Assert the result self.assertEqual(result, mock_cursor.fetchall.return_value) + + @mock.patch.dict(os.environ, {}, clear=True) + def test_execute_snowflake_query_no_credentials(self): + """Test execute_snowflake_query raises error when no credentials are set.""" + # Create a mock issue + mock_issue = MagicMock() + mock_issue.url = "https://github.com/test/repo/issues/1" + + with self.assertRaises(ValueError) as context: + d.execute_snowflake_query(mock_issue) + + self.assertIn( + "Either SNOWFLAKE_PRIVATE_KEY_FILE or SNOWFLAKE_PAT must be set", + str(context.exception), + )