Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 44 additions & 12 deletions sync2jira/downstream_issue.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,50 @@ def validate_github_url(url):


def get_snowflake_conn():
"""Get Snowflake connection - lazy initialization"""

return snowflake.connector.connect(
account=os.getenv("SNOWFLAKE_ACCOUNT"),
user=os.getenv("SNOWFLAKE_USER"),
password=os.getenv("SNOWFLAKE_PAT"),
role=os.getenv("SNOWFLAKE_ROLE"),
warehouse=os.getenv("SNOWFLAKE_WAREHOUSE", "DEFAULT"),
database=os.getenv("SNOWFLAKE_DATABASE", "JIRA_DB"),
schema=os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC"),
paramstyle="qmark",
)
"""Get Snowflake connection - lazy initialization

Supports two authentication methods:
1. JWT authentication with private key file (if SNOWFLAKE_PRIVATE_KEY_FILE is set)
2. Password authentication with PAT (if SNOWFLAKE_PAT is set)
"""
account = os.getenv("SNOWFLAKE_ACCOUNT")
user = os.getenv("SNOWFLAKE_USER")
role = os.getenv("SNOWFLAKE_ROLE")
warehouse = os.getenv("SNOWFLAKE_WAREHOUSE", "DEFAULT")
database = os.getenv("SNOWFLAKE_DATABASE", "JIRA_DB")
schema = os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC")

# Build base connection parameters
conn_params = {
"account": account,
"user": user,
"role": role,
"warehouse": warehouse,
"database": database,
"schema": schema,
"paramstyle": "qmark",
}

# Check for private key file (JWT authentication)
private_key_file = os.getenv("SNOWFLAKE_PRIVATE_KEY_FILE")
if private_key_file:
conn_params["authenticator"] = "SNOWFLAKE_JWT"
conn_params["private_key_file"] = private_key_file

# Add private key file password if specified
private_key_file_pwd = os.getenv("SNOWFLAKE_PRIVATE_KEY_FILE_PWD")
if private_key_file_pwd:
conn_params["private_key_file_pwd"] = private_key_file_pwd
else:
# Fall back to password authentication
password = os.getenv("SNOWFLAKE_PAT")
if not password:
raise ValueError(
"Either SNOWFLAKE_PRIVATE_KEY_FILE or SNOWFLAKE_PAT must be set"
)
conn_params["password"] = password

return snowflake.connector.connect(**conn_params)


def execute_snowflake_query(issue):
Expand Down
113 changes: 113 additions & 0 deletions tests/test_downstream_issue.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime, timezone
import os
from typing import Any, Optional
import unittest
import unittest.mock as mock
Expand Down Expand Up @@ -1758,6 +1759,15 @@ def test_remove_diacritics(self):
self.assertEqual(actual, expected)

@mock.patch(PATH + "snowflake.connector.connect")
@mock.patch.dict(
os.environ,
{
"SNOWFLAKE_ACCOUNT": "test_account",
"SNOWFLAKE_USER": "test_user",
"SNOWFLAKE_ROLE": "test_role",
"SNOWFLAKE_PAT": "fake_password",
},
)
def test_execute_snowflake_query_real_connection(self, mock_snowflake_connect):
"""Test execute_snowflake_query function."""
# Create a mock issue
Expand All @@ -1770,8 +1780,111 @@ def test_execute_snowflake_query_real_connection(self, mock_snowflake_connect):
)
# Assert the function was called correctly
mock_snowflake_connect.assert_called_once()
# Verify password authentication is used
call_args = mock_snowflake_connect.call_args[1]
self.assertEqual(call_args["password"], os.getenv("SNOWFLAKE_PAT"))
self.assertNotIn("authenticator", call_args)
self.assertNotIn("private_key_file", call_args)
mock_cursor.fetchall.assert_called_once()
mock_cursor.close.assert_called_once()
# Assert the result
self.assertEqual(result, mock_cursor.fetchall.return_value)

@mock.patch(PATH + "snowflake.connector.connect")
@mock.patch.dict(
os.environ,
{
"SNOWFLAKE_ACCOUNT": "test_account",
"SNOWFLAKE_USER": "test_user",
"SNOWFLAKE_ROLE": "test_role",
"SNOWFLAKE_PRIVATE_KEY_FILE": "test_key.pem",
},
)
@mock.patch("os.path.exists")
def test_execute_snowflake_query_with_jwt_auth(
self, mock_exists, mock_snowflake_connect
):
"""Test execute_snowflake_query with JWT authentication."""
mock_exists.return_value = True
# Create a mock issue
mock_issue = MagicMock()
mock_issue.url = "https://github.com/test/repo/issues/1"
# Call the function
result = d.execute_snowflake_query(mock_issue)
mock_cursor = (
mock_snowflake_connect.return_value.__enter__.return_value.cursor.return_value
)
# Assert the function was called correctly
mock_snowflake_connect.assert_called_once()
# Verify JWT authentication is used
call_args = mock_snowflake_connect.call_args[1]
self.assertEqual(call_args["authenticator"], "SNOWFLAKE_JWT")
self.assertEqual(
call_args["private_key_file"], os.getenv("SNOWFLAKE_PRIVATE_KEY_FILE")
)
self.assertNotIn("password", call_args)
self.assertNotIn("private_key_file_pwd", call_args)
mock_cursor.execute.assert_called_once()
mock_cursor.fetchall.assert_called_once()
mock_cursor.close.assert_called_once()
# Assert the result
self.assertEqual(result, mock_cursor.fetchall.return_value)

@mock.patch(PATH + "snowflake.connector.connect")
@mock.patch.dict(
os.environ,
{
"SNOWFLAKE_ACCOUNT": "test_account",
"SNOWFLAKE_USER": "test_user",
"SNOWFLAKE_ROLE": "test_role",
"SNOWFLAKE_PRIVATE_KEY_FILE": "test_key.pem",
"SNOWFLAKE_PRIVATE_KEY_FILE_PWD": "key_password",
},
)
@mock.patch("os.path.exists")
def test_execute_snowflake_query_with_jwt_auth_and_password(
self, mock_exists, mock_snowflake_connect
):
"""Test execute_snowflake_query with JWT authentication and key password."""
mock_exists.return_value = True
# Create a mock issue
mock_issue = MagicMock()
mock_issue.url = "https://github.com/test/repo/issues/1"
# Call the function
result = d.execute_snowflake_query(mock_issue)
mock_cursor = (
mock_snowflake_connect.return_value.__enter__.return_value.cursor.return_value
)
# Assert the function was called correctly
mock_snowflake_connect.assert_called_once()
# Verify JWT authentication with password is used
call_args = mock_snowflake_connect.call_args[1]
self.assertEqual(call_args["authenticator"], "SNOWFLAKE_JWT")
self.assertEqual(
call_args["private_key_file"], os.getenv("SNOWFLAKE_PRIVATE_KEY_FILE")
)
self.assertEqual(
call_args["private_key_file_pwd"],
os.getenv("SNOWFLAKE_PRIVATE_KEY_FILE_PWD"),
)
self.assertNotIn("password", call_args)
mock_cursor.execute.assert_called_once()
mock_cursor.fetchall.assert_called_once()
mock_cursor.close.assert_called_once()
# Assert the result
self.assertEqual(result, mock_cursor.fetchall.return_value)

@mock.patch.dict(os.environ, {}, clear=True)
def test_execute_snowflake_query_no_credentials(self):
"""Test execute_snowflake_query raises error when no credentials are set."""
# Create a mock issue
mock_issue = MagicMock()
mock_issue.url = "https://github.com/test/repo/issues/1"

with self.assertRaises(ValueError) as context:
d.execute_snowflake_query(mock_issue)

self.assertIn(
"Either SNOWFLAKE_PRIVATE_KEY_FILE or SNOWFLAKE_PAT must be set",
str(context.exception),
)