Skip to content

Commit 81b05a0

Browse files
update code for connecting with snowflake account using private key. (#389)
* snowflake connection fix * updation in test_case * pipeline fix increasing the coverage * adding test cases * fixing black error
1 parent 2377324 commit 81b05a0

File tree

2 files changed

+157
-12
lines changed

2 files changed

+157
-12
lines changed

sync2jira/downstream_issue.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,50 @@ def validate_github_url(url):
7676

7777

7878
def get_snowflake_conn():
79-
"""Get Snowflake connection - lazy initialization"""
80-
81-
return snowflake.connector.connect(
82-
account=os.getenv("SNOWFLAKE_ACCOUNT"),
83-
user=os.getenv("SNOWFLAKE_USER"),
84-
password=os.getenv("SNOWFLAKE_PAT"),
85-
role=os.getenv("SNOWFLAKE_ROLE"),
86-
warehouse=os.getenv("SNOWFLAKE_WAREHOUSE", "DEFAULT"),
87-
database=os.getenv("SNOWFLAKE_DATABASE", "JIRA_DB"),
88-
schema=os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC"),
89-
paramstyle="qmark",
90-
)
79+
"""Get Snowflake connection - lazy initialization
80+
81+
Supports two authentication methods:
82+
1. JWT authentication with private key file (if SNOWFLAKE_PRIVATE_KEY_FILE is set)
83+
2. Password authentication with PAT (if SNOWFLAKE_PAT is set)
84+
"""
85+
account = os.getenv("SNOWFLAKE_ACCOUNT")
86+
user = os.getenv("SNOWFLAKE_USER")
87+
role = os.getenv("SNOWFLAKE_ROLE")
88+
warehouse = os.getenv("SNOWFLAKE_WAREHOUSE", "DEFAULT")
89+
database = os.getenv("SNOWFLAKE_DATABASE", "JIRA_DB")
90+
schema = os.getenv("SNOWFLAKE_SCHEMA", "PUBLIC")
91+
92+
# Build base connection parameters
93+
conn_params = {
94+
"account": account,
95+
"user": user,
96+
"role": role,
97+
"warehouse": warehouse,
98+
"database": database,
99+
"schema": schema,
100+
"paramstyle": "qmark",
101+
}
102+
103+
# Check for private key file (JWT authentication)
104+
private_key_file = os.getenv("SNOWFLAKE_PRIVATE_KEY_FILE")
105+
if private_key_file:
106+
conn_params["authenticator"] = "SNOWFLAKE_JWT"
107+
conn_params["private_key_file"] = private_key_file
108+
109+
# Add private key file password if specified
110+
private_key_file_pwd = os.getenv("SNOWFLAKE_PRIVATE_KEY_FILE_PWD")
111+
if private_key_file_pwd:
112+
conn_params["private_key_file_pwd"] = private_key_file_pwd
113+
else:
114+
# Fall back to password authentication
115+
password = os.getenv("SNOWFLAKE_PAT")
116+
if not password:
117+
raise ValueError(
118+
"Either SNOWFLAKE_PRIVATE_KEY_FILE or SNOWFLAKE_PAT must be set"
119+
)
120+
conn_params["password"] = password
121+
122+
return snowflake.connector.connect(**conn_params)
91123

92124

93125
def execute_snowflake_query(issue):

tests/test_downstream_issue.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from datetime import datetime, timezone
2+
import os
23
from typing import Any, Optional
34
import unittest
45
import unittest.mock as mock
@@ -1758,6 +1759,15 @@ def test_remove_diacritics(self):
17581759
self.assertEqual(actual, expected)
17591760

17601761
@mock.patch(PATH + "snowflake.connector.connect")
1762+
@mock.patch.dict(
1763+
os.environ,
1764+
{
1765+
"SNOWFLAKE_ACCOUNT": "test_account",
1766+
"SNOWFLAKE_USER": "test_user",
1767+
"SNOWFLAKE_ROLE": "test_role",
1768+
"SNOWFLAKE_PAT": "fake_password",
1769+
},
1770+
)
17611771
def test_execute_snowflake_query_real_connection(self, mock_snowflake_connect):
17621772
"""Test execute_snowflake_query function."""
17631773
# Create a mock issue
@@ -1770,8 +1780,111 @@ def test_execute_snowflake_query_real_connection(self, mock_snowflake_connect):
17701780
)
17711781
# Assert the function was called correctly
17721782
mock_snowflake_connect.assert_called_once()
1783+
# Verify password authentication is used
1784+
call_args = mock_snowflake_connect.call_args[1]
1785+
self.assertEqual(call_args["password"], os.getenv("SNOWFLAKE_PAT"))
1786+
self.assertNotIn("authenticator", call_args)
1787+
self.assertNotIn("private_key_file", call_args)
1788+
mock_cursor.fetchall.assert_called_once()
1789+
mock_cursor.close.assert_called_once()
1790+
# Assert the result
1791+
self.assertEqual(result, mock_cursor.fetchall.return_value)
1792+
1793+
@mock.patch(PATH + "snowflake.connector.connect")
1794+
@mock.patch.dict(
1795+
os.environ,
1796+
{
1797+
"SNOWFLAKE_ACCOUNT": "test_account",
1798+
"SNOWFLAKE_USER": "test_user",
1799+
"SNOWFLAKE_ROLE": "test_role",
1800+
"SNOWFLAKE_PRIVATE_KEY_FILE": "test_key.pem",
1801+
},
1802+
)
1803+
@mock.patch("os.path.exists")
1804+
def test_execute_snowflake_query_with_jwt_auth(
1805+
self, mock_exists, mock_snowflake_connect
1806+
):
1807+
"""Test execute_snowflake_query with JWT authentication."""
1808+
mock_exists.return_value = True
1809+
# Create a mock issue
1810+
mock_issue = MagicMock()
1811+
mock_issue.url = "https://github.com/test/repo/issues/1"
1812+
# Call the function
1813+
result = d.execute_snowflake_query(mock_issue)
1814+
mock_cursor = (
1815+
mock_snowflake_connect.return_value.__enter__.return_value.cursor.return_value
1816+
)
1817+
# Assert the function was called correctly
1818+
mock_snowflake_connect.assert_called_once()
1819+
# Verify JWT authentication is used
1820+
call_args = mock_snowflake_connect.call_args[1]
1821+
self.assertEqual(call_args["authenticator"], "SNOWFLAKE_JWT")
1822+
self.assertEqual(
1823+
call_args["private_key_file"], os.getenv("SNOWFLAKE_PRIVATE_KEY_FILE")
1824+
)
1825+
self.assertNotIn("password", call_args)
1826+
self.assertNotIn("private_key_file_pwd", call_args)
1827+
mock_cursor.execute.assert_called_once()
1828+
mock_cursor.fetchall.assert_called_once()
1829+
mock_cursor.close.assert_called_once()
1830+
# Assert the result
1831+
self.assertEqual(result, mock_cursor.fetchall.return_value)
1832+
1833+
@mock.patch(PATH + "snowflake.connector.connect")
1834+
@mock.patch.dict(
1835+
os.environ,
1836+
{
1837+
"SNOWFLAKE_ACCOUNT": "test_account",
1838+
"SNOWFLAKE_USER": "test_user",
1839+
"SNOWFLAKE_ROLE": "test_role",
1840+
"SNOWFLAKE_PRIVATE_KEY_FILE": "test_key.pem",
1841+
"SNOWFLAKE_PRIVATE_KEY_FILE_PWD": "key_password",
1842+
},
1843+
)
1844+
@mock.patch("os.path.exists")
1845+
def test_execute_snowflake_query_with_jwt_auth_and_password(
1846+
self, mock_exists, mock_snowflake_connect
1847+
):
1848+
"""Test execute_snowflake_query with JWT authentication and key password."""
1849+
mock_exists.return_value = True
1850+
# Create a mock issue
1851+
mock_issue = MagicMock()
1852+
mock_issue.url = "https://github.com/test/repo/issues/1"
1853+
# Call the function
1854+
result = d.execute_snowflake_query(mock_issue)
1855+
mock_cursor = (
1856+
mock_snowflake_connect.return_value.__enter__.return_value.cursor.return_value
1857+
)
1858+
# Assert the function was called correctly
1859+
mock_snowflake_connect.assert_called_once()
1860+
# Verify JWT authentication with password is used
1861+
call_args = mock_snowflake_connect.call_args[1]
1862+
self.assertEqual(call_args["authenticator"], "SNOWFLAKE_JWT")
1863+
self.assertEqual(
1864+
call_args["private_key_file"], os.getenv("SNOWFLAKE_PRIVATE_KEY_FILE")
1865+
)
1866+
self.assertEqual(
1867+
call_args["private_key_file_pwd"],
1868+
os.getenv("SNOWFLAKE_PRIVATE_KEY_FILE_PWD"),
1869+
)
1870+
self.assertNotIn("password", call_args)
17731871
mock_cursor.execute.assert_called_once()
17741872
mock_cursor.fetchall.assert_called_once()
17751873
mock_cursor.close.assert_called_once()
17761874
# Assert the result
17771875
self.assertEqual(result, mock_cursor.fetchall.return_value)
1876+
1877+
@mock.patch.dict(os.environ, {}, clear=True)
1878+
def test_execute_snowflake_query_no_credentials(self):
1879+
"""Test execute_snowflake_query raises error when no credentials are set."""
1880+
# Create a mock issue
1881+
mock_issue = MagicMock()
1882+
mock_issue.url = "https://github.com/test/repo/issues/1"
1883+
1884+
with self.assertRaises(ValueError) as context:
1885+
d.execute_snowflake_query(mock_issue)
1886+
1887+
self.assertIn(
1888+
"Either SNOWFLAKE_PRIVATE_KEY_FILE or SNOWFLAKE_PAT must be set",
1889+
str(context.exception),
1890+
)

0 commit comments

Comments
 (0)