diff --git a/src/snowflake/connector/auth/_http_server.py b/src/snowflake/connector/auth/_http_server.py index a11662f25b..58c9e5bf0b 100644 --- a/src/snowflake/connector/auth/_http_server.py +++ b/src/snowflake/connector/auth/_http_server.py @@ -70,8 +70,10 @@ def __init__( self, uri: str, buf_size: int = 16384, + redirect_uri: str | None = None, ) -> None: parsed_uri = urllib.parse.urlparse(uri) + parsed_redirect = urllib.parse.urlparse(redirect_uri) if redirect_uri else None self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.buf_size = buf_size if os.getenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "False").lower() == "true": @@ -82,7 +84,11 @@ def __init__( else: self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - port = parsed_uri.port or 0 + if parsed_redirect and self._is_local_uri(parsed_redirect): + port = parsed_redirect.port or 0 + else: + port = parsed_uri.port if parsed_uri and parsed_uri.port else 0 + for attempt in range(1, self.DEFAULT_MAX_ATTEMPTS + 1): try: self._socket.bind( @@ -123,6 +129,30 @@ def __init__( query=parsed_uri.query, fragment=parsed_uri.fragment, ) + if parsed_redirect: + if self._is_local_uri(parsed_redirect) and port != parsed_redirect.port: + logger.debug( + f"Updating redirect port {parsed_redirect.port} to match the server port {port}." + ) + self._redirect_uri = urllib.parse.ParseResult( + scheme=parsed_redirect.scheme, + netloc=parsed_redirect.hostname + ":" + str(port), + path=parsed_redirect.path, + params=parsed_redirect.params, + query=parsed_redirect.query, + fragment=parsed_redirect.fragment, + ) + else: + self._redirect_uri = parsed_redirect + + def _is_local_uri(self, parsed_redirect): + return parsed_redirect.hostname in ("localhost", "127.0.0.1") + + @property + def redirect_uri(self) -> str | None: + if self._redirect_uri: + return self._redirect_uri.geturl() + return self.url @property def url(self) -> str: diff --git a/src/snowflake/connector/auth/oauth_code.py b/src/snowflake/connector/auth/oauth_code.py index a5aaf31fb9..59a725e233 100644 --- a/src/snowflake/connector/auth/oauth_code.py +++ b/src/snowflake/connector/auth/oauth_code.py @@ -8,6 +8,7 @@ import hashlib import json import logging +import os import secrets import socket import time @@ -16,7 +17,11 @@ from typing import TYPE_CHECKING, Any from ..compat import parse_qs, urlparse, urlsplit -from ..constants import OAUTH_TYPE_AUTHORIZATION_CODE +from ..constants import ( + ENV_VAR_OAUTH_SOCKET_ADDRESS, + ENV_VAR_OAUTH_SOCKET_PORT, + OAUTH_TYPE_AUTHORIZATION_CODE, +) from ..errorcode import ( ER_INVALID_VALUE, ER_OAUTH_CALLBACK_ERROR, @@ -117,10 +122,20 @@ def _request_tokens( ) -> (str | None, str | None): """Web Browser based Authentication.""" logger.debug("authenticating with OAuth authorization code flow") - with AuthHttpServer(self._redirect_uri) as callback_server: + with AuthHttpServer( + redirect_uri=self._redirect_uri, + uri=self._read_uri_from_env(), + ) as callback_server: code = self._do_authorization_request(callback_server, conn) return self._do_token_request(code, callback_server, conn) + def _read_uri_from_env(self) -> str: + oauth_socket_address = os.getenv( + ENV_VAR_OAUTH_SOCKET_ADDRESS, "http://localhost" + ) + oauth_socket_port = os.getenv(ENV_VAR_OAUTH_SOCKET_PORT, "0") + return f"{oauth_socket_address}:{oauth_socket_port}" + def _check_post_requested( self, data: list[str] ) -> tuple[str, str] | tuple[None, None]: @@ -260,7 +275,7 @@ def _do_authorization_request( connection: SnowflakeConnection, ) -> str | None: authorization_request = self._construct_authorization_request( - callback_server.url + callback_server.redirect_uri ) logger.debug("step 1: going to open authorization URL") print( @@ -315,7 +330,7 @@ def _do_token_request( fields = { "grant_type": "authorization_code", "code": code, - "redirect_uri": callback_server.url, + "redirect_uri": callback_server.redirect_uri, } if self._enable_single_use_refresh_tokens: fields["enable_single_use_refresh_tokens"] = "true" diff --git a/src/snowflake/connector/constants.py b/src/snowflake/connector/constants.py index 47f07b9eb9..1f89064c9c 100644 --- a/src/snowflake/connector/constants.py +++ b/src/snowflake/connector/constants.py @@ -441,7 +441,8 @@ class IterUnit(Enum): # TODO: all env variables definitions should be here ENV_VAR_PARTNER = "SF_PARTNER" ENV_VAR_TEST_MODE = "SNOWFLAKE_TEST_MODE" - +ENV_VAR_OAUTH_SOCKET_ADDRESS = "SNOWFLAKE_OAUTH_SOCKET_ADDRESS" +ENV_VAR_OAUTH_SOCKET_PORT = "SNOWFLAKE_OAUTH_SOCKET_PORT" _DOMAIN_NAME_MAP = {_DEFAULT_HOSTNAME_TLD: "GLOBAL", _CHINA_HOSTNAME_TLD: "CHINA"} diff --git a/test/unit/test_auth_callback_server.py b/test/unit/test_auth_callback_server.py index bf03a8d5f6..5a33ec0a80 100644 --- a/test/unit/test_auth_callback_server.py +++ b/test/unit/test_auth_callback_server.py @@ -24,7 +24,9 @@ def test_auth_callback_success(monkeypatch, dontwait, timeout, reuse_port) -> No monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port) monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait) test_response: requests.Response | None = None - with AuthHttpServer("http://127.0.0.1/test_request") as callback_server: + with AuthHttpServer( + "http://127.0.0.1/test_request", + ) as callback_server: def request_callback(): nonlocal test_response @@ -57,7 +59,155 @@ def request_callback(): def test_auth_callback_timeout(monkeypatch, dontwait, timeout, reuse_port) -> None: monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port) monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait) - with AuthHttpServer("http://127.0.0.1/test_request") as callback_server: + with AuthHttpServer( + "http://127.0.0.1/test_request", + ) as callback_server: block, client_socket = callback_server.receive_block(timeout=timeout) assert block is None assert client_socket is None + + +@pytest.mark.parametrize( + "socket_host", + [ + "127.0.0.1", + "localhost", + ], +) +@pytest.mark.parametrize( + "socket_port", + [ + "", + ":0", + ":12345", + ], +) +@pytest.mark.parametrize( + "redirect_host", + [ + "127.0.0.1", + "localhost", + ], +) +@pytest.mark.parametrize( + "redirect_port", + [ + "", + ":0", + ":12345", + ], +) +@pytest.mark.parametrize( + "dontwait", + ["false", "true"], +) +@pytest.mark.parametrize("reuse_port", ["true", "false"]) +def test_auth_callback_server_updates_localhost_redirect_uri_port_to_match_socket_port( + monkeypatch, + socket_host, + socket_port, + redirect_host, + redirect_port, + dontwait, + reuse_port, +) -> None: + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port) + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait) + with AuthHttpServer( + uri=f"http://{socket_host}{socket_port}/test_request", + redirect_uri=f"http://{redirect_host}{redirect_port}/test_request", + ) as callback_server: + assert callback_server._redirect_uri.port == callback_server.port + + +@pytest.mark.parametrize( + "socket_host", + [ + "127.0.0.1", + "localhost", + ], +) +@pytest.mark.parametrize( + "socket_port", + [ + "", + ":0", + ":12345", + ], +) +@pytest.mark.parametrize( + "redirect_host", + [ + "127.0.0.1", + "localhost", + ], +) +@pytest.mark.parametrize( + "redirect_port", + [ + 54321, + 54320, + ], +) +@pytest.mark.parametrize( + "dontwait", + ["false", "true"], +) +@pytest.mark.parametrize("reuse_port", ["true", "false"]) +def test_auth_callback_server_uses_redirect_uri_port_when_specified( + monkeypatch, + socket_host, + socket_port, + redirect_host, + redirect_port, + dontwait, + reuse_port, +) -> None: + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port) + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait) + with AuthHttpServer( + uri=f"http://{socket_host}{socket_port}/test_request", + redirect_uri=f"http://{redirect_host}:{redirect_port}/test_request", + ) as callback_server: + assert callback_server.port == redirect_port + assert callback_server._redirect_uri.port == redirect_port + + +@pytest.mark.parametrize( + "socket_host", + [ + "127.0.0.1", + "localhost", + ], +) +@pytest.mark.parametrize( + "socket_port", + [ + "", + ":0", + ":12345", + ], +) +@pytest.mark.parametrize( + "redirect_port", + [ + "", + ":0", + ":12345", + ], +) +@pytest.mark.parametrize( + "dontwait", + ["false", "true"], +) +@pytest.mark.parametrize("reuse_port", ["true", "false"]) +def test_auth_callback_server_does_not_updates_nonlocalhost_redirect_uri_port_to_match_socket_port( + monkeypatch, socket_host, socket_port, redirect_port, dontwait, reuse_port +) -> None: + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port) + monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait) + redirect_uri = f"http://not_localhost{redirect_port}/test_request" + with AuthHttpServer( + uri=f"http://{socket_host}{socket_port}/test_request", redirect_uri=redirect_uri + ) as callback_server: + assert callback_server.redirect_uri == redirect_uri diff --git a/test/unit/test_auth_oauth_auth_code.py b/test/unit/test_auth_oauth_auth_code.py index 76894791cc..2e52ed5f87 100644 --- a/test/unit/test_auth_oauth_auth_code.py +++ b/test/unit/test_auth_oauth_auth_code.py @@ -5,7 +5,7 @@ import unittest.mock as mock from test.helpers import apply_auth_class_update_body, create_mock_auth_body -from unittest.mock import patch +from unittest.mock import PropertyMock, patch import pytest @@ -285,3 +285,69 @@ def mock_request_tokens(self, **kwargs): assert isinstance(conn.auth_class, AuthByOauthCode) conn.close() + + +@pytest.mark.parametrize("redirect_uri", ["https://redirect/uri"]) +@pytest.mark.parametrize("rtr_enabled", [True, False]) +def test_auth_oauth_auth_code_uses_redirect_uri( + redirect_uri, rtr_enabled: bool, omit_oauth_urls_check +): + """Test that the redirect URI is used correctly in the OAuth authorization code flow.""" + auth = AuthByOauthCode( + "app", + "clientId", + "clientSecret", + "auth_url", + "tokenRequestUrl", + redirect_uri, + "scope", + "host", + pkce_enabled=False, + enable_single_use_refresh_tokens=rtr_enabled, + ) + + def fake_get_request_token_response(_, fields: dict[str, str]): + if rtr_enabled: + assert fields.get("enable_single_use_refresh_tokens") == "true" + else: + assert "enable_single_use_refresh_tokens" not in fields + return ("access_token", "refresh_token") + + with patch( + "snowflake.connector.auth.AuthByOauthCode._construct_authorization_request", + return_value="authorization_request", + ) as mock_construct_authorization_request: + with patch( + "snowflake.connector.auth.AuthByOauthCode._receive_authorization_callback", + return_value=("code", auth._state), + ): + with patch( + "snowflake.connector.auth.AuthByOauthCode._ask_authorization_callback_from_user", + return_value=("code", auth._state), + ): + with patch( + "snowflake.connector.auth.AuthByOauthCode._get_request_token_response", + side_effect=fake_get_request_token_response, + ) as mock_get_request_token_response: + with patch( + "snowflake.connector.auth._http_server.AuthHttpServer.redirect_uri", + return_value=redirect_uri, + new_callable=PropertyMock, + ): + auth.prepare( + conn=None, + authenticator=OAUTH_AUTHORIZATION_CODE, + service_name=None, + account="acc", + user="user", + ) + mock_construct_authorization_request.assert_called_once_with( + redirect_uri + ) + assert mock_get_request_token_response.call_count == 1 + assert ( + mock_get_request_token_response.call_args[0][1][ + "redirect_uri" + ] + == redirect_uri + )