@@ -1765,9 +1765,6 @@ def test_remove_diacritics(self):
17651765 "SNOWFLAKE_ACCOUNT" : "test_account" ,
17661766 "SNOWFLAKE_USER" : "test_user" ,
17671767 "SNOWFLAKE_ROLE" : "test_role" ,
1768- "SNOWFLAKE_WAREHOUSE" : "test_wh" ,
1769- "SNOWFLAKE_DATABASE" : "test_db" ,
1770- "SNOWFLAKE_SCHEMA" : "PUBLIC" ,
17711768 "SNOWFLAKE_PAT" : "fake_password" ,
17721769 },
17731770 )
@@ -1783,7 +1780,11 @@ def test_execute_snowflake_query_real_connection(self, mock_snowflake_connect):
17831780 )
17841781 # Assert the function was called correctly
17851782 mock_snowflake_connect .assert_called_once ()
1786- mock_cursor .execute .assert_called_once ()
1783+ # Verify password authentication is used
1784+ call_args = mock_snowflake_connect .call_args [1 ]
1785+ self .assertEqual (call_args ["password" ], os .getenv ("SNOWFLAKE_PAT" ))
1786+ self .assertNotIn ("authenticator" , call_args )
1787+ self .assertNotIn ("private_key_file" , call_args )
17871788 mock_cursor .fetchall .assert_called_once ()
17881789 mock_cursor .close .assert_called_once ()
17891790 # Assert the result
@@ -1796,23 +1797,87 @@ def test_execute_snowflake_query_real_connection(self, mock_snowflake_connect):
17961797 "SNOWFLAKE_ACCOUNT" : "test_account" ,
17971798 "SNOWFLAKE_USER" : "test_user" ,
17981799 "SNOWFLAKE_ROLE" : "test_role" ,
1799- "SNOWFLAKE_WAREHOUSE" : "test_wh" ,
1800- "SNOWFLAKE_DATABASE" : "test_db" ,
1801- "SNOWFLAKE_SCHEMA" : "PUBLIC" ,
18021800 "SNOWFLAKE_PRIVATE_KEY_FILE" : "test_key.pem" ,
18031801 },
18041802 )
18051803 @mock .patch ("os.path.exists" )
1806- def test_get_snowflake_conn_jwt_auth (self , mock_exists , mock_connect ):
1807- """Test get_snowflake_conn with JWT authentication."""
1804+ def test_execute_snowflake_query_with_jwt_auth (self , mock_exists , mock_snowflake_connect ):
1805+ """Test execute_snowflake_query with JWT authentication."""
18081806 mock_exists .return_value = True
1809- mock_connect .return_value = MagicMock ()
1810-
1811- conn = d .get_snowflake_conn ()
1807+ # Create a mock issue
1808+ mock_issue = MagicMock ()
1809+ mock_issue .url = "https://github.com/test/repo/issues/1"
1810+ # Call the function
1811+ result = d .execute_snowflake_query (mock_issue )
1812+ mock_cursor = (
1813+ mock_snowflake_connect .return_value .__enter__ .return_value .cursor .return_value
1814+ )
1815+ # Assert the function was called correctly
1816+ mock_snowflake_connect .assert_called_once ()
1817+ # Verify JWT authentication is used
1818+ call_args = mock_snowflake_connect .call_args [1 ]
1819+ self .assertEqual (call_args ["authenticator" ], "SNOWFLAKE_JWT" )
1820+ self .assertEqual (call_args ["private_key_file" ], os .getenv ("SNOWFLAKE_PRIVATE_KEY_FILE" ))
1821+ self .assertNotIn ("password" , call_args )
1822+ self .assertNotIn ("private_key_file_pwd" , call_args )
1823+ mock_cursor .execute .assert_called_once ()
1824+ mock_cursor .fetchall .assert_called_once ()
1825+ mock_cursor .close .assert_called_once ()
1826+ # Assert the result
1827+ self .assertEqual (result , mock_cursor .fetchall .return_value )
18121828
1813- # Verify connect was called with JWT authenticator
1814- mock_connect .assert_called_once ()
1815- call_args = mock_connect .call_args [1 ]
1829+ @mock .patch (PATH + "snowflake.connector.connect" )
1830+ @mock .patch .dict (
1831+ os .environ ,
1832+ {
1833+ "SNOWFLAKE_ACCOUNT" : "test_account" ,
1834+ "SNOWFLAKE_USER" : "test_user" ,
1835+ "SNOWFLAKE_ROLE" : "test_role" ,
1836+ "SNOWFLAKE_PRIVATE_KEY_FILE" : "test_key.pem" ,
1837+ "SNOWFLAKE_PRIVATE_KEY_FILE_PWD" : "key_password" ,
1838+ },
1839+ )
1840+ @mock .patch ("os.path.exists" )
1841+ def test_execute_snowflake_query_with_jwt_auth_and_password (
1842+ self , mock_exists , mock_snowflake_connect
1843+ ):
1844+ """Test execute_snowflake_query with JWT authentication and key password."""
1845+ mock_exists .return_value = True
1846+ # Create a mock issue
1847+ mock_issue = MagicMock ()
1848+ mock_issue .url = "https://github.com/test/repo/issues/1"
1849+ # Call the function
1850+ result = d .execute_snowflake_query (mock_issue )
1851+ mock_cursor = (
1852+ mock_snowflake_connect .return_value .__enter__ .return_value .cursor .return_value
1853+ )
1854+ # Assert the function was called correctly
1855+ mock_snowflake_connect .assert_called_once ()
1856+ # Verify JWT authentication with password is used
1857+ call_args = mock_snowflake_connect .call_args [1 ]
18161858 self .assertEqual (call_args ["authenticator" ], "SNOWFLAKE_JWT" )
1817- self .assertEqual (call_args ["private_key_file" ], "test_key.pem" )
1859+ self .assertEqual (call_args ["private_key_file" ], os .getenv ("SNOWFLAKE_PRIVATE_KEY_FILE" ))
1860+ self .assertEqual (
1861+ call_args ["private_key_file_pwd" ], os .getenv ("SNOWFLAKE_PRIVATE_KEY_FILE_PWD" )
1862+ )
18181863 self .assertNotIn ("password" , call_args )
1864+ mock_cursor .execute .assert_called_once ()
1865+ mock_cursor .fetchall .assert_called_once ()
1866+ mock_cursor .close .assert_called_once ()
1867+ # Assert the result
1868+ self .assertEqual (result , mock_cursor .fetchall .return_value )
1869+
1870+ @mock .patch .dict (os .environ , {}, clear = True )
1871+ def test_execute_snowflake_query_no_credentials (self ):
1872+ """Test execute_snowflake_query raises error when no credentials are set."""
1873+ # Create a mock issue
1874+ mock_issue = MagicMock ()
1875+ mock_issue .url = "https://github.com/test/repo/issues/1"
1876+
1877+ with self .assertRaises (ValueError ) as context :
1878+ d .execute_snowflake_query (mock_issue )
1879+
1880+ self .assertIn (
1881+ "Either SNOWFLAKE_PRIVATE_KEY_FILE or SNOWFLAKE_PAT must be set" ,
1882+ str (context .exception ),
1883+ )
0 commit comments