diff --git a/src/snowflake/connector/auth/_http_server.py b/src/snowflake/connector/auth/_http_server.py index a11662f25..5c9427cbd 100644 --- a/src/snowflake/connector/auth/_http_server.py +++ b/src/snowflake/connector/auth/_http_server.py @@ -69,9 +69,11 @@ class AuthHttpServer: def __init__( self, uri: str, + redirect_uri: str, buf_size: int = 16384, ) -> None: parsed_uri = urllib.parse.urlparse(uri) + parsed_redirect = urllib.parse.urlparse(redirect_uri) self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.buf_size = buf_size if os.getenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "False").lower() == "true": @@ -82,7 +84,10 @@ def __init__( else: self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - port = parsed_uri.port or 0 + if parsed_redirect.hostname in ("localhost", "127.0.0.1"): + port = parsed_redirect.port or 0 + else: + port = parsed_uri.port or 0 for attempt in range(1, self.DEFAULT_MAX_ATTEMPTS + 1): try: self._socket.bind( @@ -123,6 +128,27 @@ def __init__( query=parsed_uri.query, fragment=parsed_uri.fragment, ) + if ( + parsed_redirect.hostname in ("localhost", "127.0.0.1") + and port != parsed_redirect.port + ): + logger.debug( + f"Updating redirect port {parsed_redirect.port} to match the server port {port}." + ) + self._redirect_uri = urllib.parse.ParseResult( + scheme=parsed_redirect.scheme, + netloc=parsed_redirect.hostname + ":" + str(port), + path=parsed_redirect.path, + params=parsed_redirect.params, + query=parsed_redirect.query, + fragment=parsed_redirect.fragment, + ) + else: + self._redirect_uri = parsed_redirect + + @property + def redirect_uri(self) -> str: + return self._redirect_uri.geturl() @property def url(self) -> str: diff --git a/src/snowflake/connector/auth/oauth_code.py b/src/snowflake/connector/auth/oauth_code.py index 1c0c41eb6..7735dce8f 100644 --- a/src/snowflake/connector/auth/oauth_code.py +++ b/src/snowflake/connector/auth/oauth_code.py @@ -8,6 +8,7 @@ import hashlib import json import logging +import os import secrets import socket import time @@ -117,7 +118,12 @@ def _request_tokens( ) -> (str | None, str | None): """Web Browser based Authentication.""" logger.debug("authenticating with OAuth authorization code flow") - with AuthHttpServer(self._redirect_uri) as callback_server: + with AuthHttpServer( + uri=os.environ.get("SNOWFLAKE_OAUTH_SOCKET_ADDRESS", "http://localhost") + + ":" + + os.environ.get("SNOWFLAKE_OAUTH_SOCKET_PORT", "0"), + redirect_uri=self._redirect_uri, + ) as callback_server: code = self._do_authorization_request(callback_server, conn) return self._do_token_request(code, callback_server, conn) @@ -260,7 +266,7 @@ def _do_authorization_request( connection: SnowflakeConnection, ) -> str | None: authorization_request = self._construct_authorization_request( - callback_server.url + callback_server.redirect_uri ) logger.debug("step 1: going to open authorization URL") print( @@ -314,7 +320,7 @@ def _do_token_request( fields = { "grant_type": "authorization_code", "code": code, - "redirect_uri": callback_server.url, + "redirect_uri": callback_server.redirect_uri, } if self._enable_single_use_refresh_tokens: fields["enable_single_use_refresh_tokens"] = "true" diff --git a/test/unit/test_auth_callback_server.py b/test/unit/test_auth_callback_server.py index bf03a8d5f..17b8444b9 100644 --- a/test/unit/test_auth_callback_server.py +++ b/test/unit/test_auth_callback_server.py @@ -24,7 +24,10 @@ def test_auth_callback_success(monkeypatch, dontwait, timeout, reuse_port) -> No monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port) monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait) test_response: requests.Response | None = None - with AuthHttpServer("http://127.0.0.1/test_request") as callback_server: + with AuthHttpServer( + uri="http://127.0.0.1/test_request", + redirect_uri="http://127.0.0.1/test_request", + ) as callback_server: def request_callback(): nonlocal test_response @@ -57,7 +60,156 @@ def request_callback(): def test_auth_callback_timeout(monkeypatch, dontwait, timeout, reuse_port) -> None: monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port) monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait) - with AuthHttpServer("http://127.0.0.1/test_request") as callback_server: + with AuthHttpServer( + uri="http://127.0.0.1/test_request", + redirect_uri="http://127.0.0.1/test_request", + ) as callback_server: block, client_socket = callback_server.receive_block(timeout=timeout) assert block is None assert client_socket is None + + +@pytest.mark.parametrize( + "socket_host", + [ + "127.0.0.1", + "localhost", + ], +) +@pytest.mark.parametrize( + "socket_port", + [ + "", + ":0", + ":12345", + ], +) +@pytest.mark.parametrize( + "redirect_host", + [ + "127.0.0.1", + "localhost", + ], +) +@pytest.mark.parametrize( + "redirect_port", + [ + "", + ":0", + ":12345", + ], +) +@pytest.mark.parametrize( + "dontwait", + ["false", "true"], +) +@pytest.mark.parametrize("reuse_port", ["true", "false"]) +def test_auth_callback_server_updates_localhost_redirect_uri_port_to_match_socket_port( + monkeypatch, + socket_host, + socket_port, + redirect_host, + redirect_port, + dontwait, + reuse_port, +) -> None: + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port) + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait) + with AuthHttpServer( + uri=f"http://{socket_host}{socket_port}/test_request", + redirect_uri=f"http://{redirect_host}{redirect_port}/test_request", + ) as callback_server: + assert callback_server._redirect_uri.port == callback_server.port + + +@pytest.mark.parametrize( + "socket_host", + [ + "127.0.0.1", + "localhost", + ], +) +@pytest.mark.parametrize( + "socket_port", + [ + "", + ":0", + ":12345", + ], +) +@pytest.mark.parametrize( + "redirect_host", + [ + "127.0.0.1", + "localhost", + ], +) +@pytest.mark.parametrize( + "redirect_port", + [ + 54321, + 54320, + ], +) +@pytest.mark.parametrize( + "dontwait", + ["false", "true"], +) +@pytest.mark.parametrize("reuse_port", ["true", "false"]) +def test_auth_callback_server_uses_redirect_uri_port_when_specified( + monkeypatch, + socket_host, + socket_port, + redirect_host, + redirect_port, + dontwait, + reuse_port, +) -> None: + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port) + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait) + with AuthHttpServer( + uri=f"http://{socket_host}{socket_port}/test_request", + redirect_uri=f"http://{redirect_host}:{redirect_port}/test_request", + ) as callback_server: + assert callback_server.port == redirect_port + assert callback_server._redirect_uri.port == redirect_port + + +@pytest.mark.parametrize( + "socket_host", + [ + "127.0.0.1", + "localhost", + ], +) +@pytest.mark.parametrize( + "socket_port", + [ + "", + ":0", + ":12345", + ], +) +@pytest.mark.parametrize( + "redirect_port", + [ + "", + ":0", + ":12345", + ], +) +@pytest.mark.parametrize( + "dontwait", + ["false", "true"], +) +@pytest.mark.parametrize("reuse_port", ["true", "false"]) +def test_auth_callback_server_does_not_updates_nonlocalhost_redirect_uri_port_to_match_socket_port( + monkeypatch, socket_host, socket_port, redirect_port, dontwait, reuse_port +) -> None: + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port) + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait) + redirect_uri = f"http://not_localhost{redirect_port}/test_request" + with AuthHttpServer( + uri=f"http://{socket_host}{socket_port}/test_request", redirect_uri=redirect_uri + ) as callback_server: + assert callback_server.redirect_uri == redirect_uri diff --git a/test/unit/test_auth_oauth_auth_code.py b/test/unit/test_auth_oauth_auth_code.py index 25e8b6939..540d24e0c 100644 --- a/test/unit/test_auth_oauth_auth_code.py +++ b/test/unit/test_auth_oauth_auth_code.py @@ -4,7 +4,7 @@ # import unittest.mock as mock -from unittest.mock import patch +from unittest.mock import PropertyMock, patch import pytest @@ -209,3 +209,69 @@ def assert_initialized_correctly() -> None: assert_initialized_correctly() else: assert_initialized_correctly() + + +@pytest.mark.parametrize("redirect_uri", ["https://redirect/uri"]) +@pytest.mark.parametrize("rtr_enabled", [True, False]) +def test_auth_oauth_auth_code_uses_redirect_uri( + redirect_uri, rtr_enabled: bool, omit_oauth_urls_check +): + """Test that the redirect URI is used correctly in the OAuth authorization code flow.""" + auth = AuthByOauthCode( + "app", + "clientId", + "clientSecret", + "auth_url", + "tokenRequestUrl", + redirect_uri, + "scope", + "host", + pkce_enabled=False, + enable_single_use_refresh_tokens=rtr_enabled, + ) + + def fake_get_request_token_response(_, fields: dict[str, str]): + if rtr_enabled: + assert fields.get("enable_single_use_refresh_tokens") == "true" + else: + assert "enable_single_use_refresh_tokens" not in fields + return ("access_token", "refresh_token") + + with patch( + "snowflake.connector.auth.AuthByOauthCode._construct_authorization_request", + return_value="authorization_request", + ) as mock_construct_authorization_request: + with patch( + "snowflake.connector.auth.AuthByOauthCode._receive_authorization_callback", + return_value=("code", auth._state), + ): + with patch( + "snowflake.connector.auth.AuthByOauthCode._ask_authorization_callback_from_user", + return_value=("code", auth._state), + ): + with patch( + "snowflake.connector.auth.AuthByOauthCode._get_request_token_response", + side_effect=fake_get_request_token_response, + ) as mock_get_request_token_response: + with patch( + "snowflake.connector.auth._http_server.AuthHttpServer.redirect_uri", + return_value=redirect_uri, + new_callable=PropertyMock, + ): + auth.prepare( + conn=None, + authenticator=OAUTH_AUTHORIZATION_CODE, + service_name=None, + account="acc", + user="user", + ) + mock_construct_authorization_request.assert_called_once_with( + redirect_uri + ) + assert mock_get_request_token_response.call_count == 1 + assert ( + mock_get_request_token_response.call_args[0][1][ + "redirect_uri" + ] + == redirect_uri + ) diff --git a/test/unit/test_oauth_token.py b/test/unit/test_oauth_token.py index cae246545..a2b433937 100644 --- a/test/unit/test_oauth_token.py +++ b/test/unit/test_oauth_token.py @@ -142,6 +142,7 @@ def test_oauth_code_successful_flow( omit_oauth_urls_check, ) -> None: monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009") wiremock_client.import_mapping( wiremock_oauth_authorization_code_dir / "successful_flow.json" @@ -184,6 +185,7 @@ def test_oauth_code_invalid_state( omit_oauth_urls_check, ) -> None: monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009") wiremock_client.import_mapping( wiremock_oauth_authorization_code_dir / "invalid_state_error.json" @@ -219,6 +221,7 @@ def test_oauth_code_scope_error( monkeypatch, ) -> None: monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009") wiremock_client.import_mapping( wiremock_oauth_authorization_code_dir / "invalid_scope_error.json" @@ -255,6 +258,7 @@ def test_oauth_code_token_request_error( omit_oauth_urls_check, ) -> None: monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009") with WiremockClient() as wiremock_client: wiremock_client.import_mapping( @@ -293,6 +297,7 @@ def test_oauth_code_browser_timeout( omit_oauth_urls_check, ) -> None: monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009") wiremock_client.import_mapping( wiremock_oauth_authorization_code_dir @@ -334,6 +339,7 @@ def test_oauth_code_custom_urls( omit_oauth_urls_check, ) -> None: monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009") wiremock_client.import_mapping( wiremock_oauth_authorization_code_dir / "external_idp_custom_urls.json" @@ -377,6 +383,7 @@ def test_oauth_code_local_application_custom_urls_successful_flow( omit_oauth_urls_check, ) -> None: monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009") wiremock_client.import_mapping( wiremock_oauth_authorization_code_dir @@ -421,6 +428,7 @@ def test_oauth_code_successful_refresh_token_flow( omit_oauth_urls_check, ) -> None: monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009") wiremock_client.import_mapping( wiremock_generic_mappings_dir / "snowflake_login_failed.json" @@ -481,6 +489,7 @@ def test_oauth_code_expired_refresh_token_flow( omit_oauth_urls_check, ) -> None: monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true") + monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009") wiremock_client.import_mapping( wiremock_generic_mappings_dir / "snowflake_login_failed.json" @@ -562,6 +571,10 @@ def test_client_creds_successful_flow( monkeypatch, temp_cache, ) -> None: + monkeypatch.setenv( + "SNOWFLAKE_OAUTH_SOCKET_PORT", wiremock_client.wiremock_http_port + ) + wiremock_client.import_mapping( wiremock_oauth_client_creds_dir / "successful_flow.json" ) @@ -612,6 +625,10 @@ def test_client_creds_token_request_error( wiremock_generic_mappings_dir, monkeypatch, ) -> None: + monkeypatch.setenv( + "SNOWFLAKE_OAUTH_SOCKET_PORT", wiremock_client.wiremock_http_port + ) + wiremock_client.import_mapping( wiremock_oauth_client_creds_dir / "token_request_error.json" ) @@ -653,6 +670,10 @@ def test_client_creds_expired_refresh_token_flow( monkeypatch, temp_cache, ) -> None: + monkeypatch.setenv( + "SNOWFLAKE_OAUTH_SOCKET_PORT", wiremock_client.wiremock_http_port + ) + wiremock_client.import_mapping( wiremock_generic_mappings_dir / "snowflake_login_failed.json" )