Skip to content

Separate server and redirect URIs in AuthHttpServer #2400

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion src/snowflake/connector/auth/_http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ class AuthHttpServer:
def __init__(
self,
uri: str,
redirect_uri: str,
buf_size: int = 16384,
) -> None:
parsed_uri = urllib.parse.urlparse(uri)
parsed_redirect = urllib.parse.urlparse(redirect_uri)
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.buf_size = buf_size
if os.getenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "False").lower() == "true":
Expand All @@ -82,7 +84,10 @@ def __init__(
else:
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)

port = parsed_uri.port or 0
if parsed_redirect.hostname in ("localhost", "127.0.0.1"):
port = parsed_redirect.port or 0
else:
port = parsed_uri.port or 0
for attempt in range(1, self.DEFAULT_MAX_ATTEMPTS + 1):
try:
self._socket.bind(
Expand Down Expand Up @@ -123,6 +128,27 @@ def __init__(
query=parsed_uri.query,
fragment=parsed_uri.fragment,
)
if (
parsed_redirect.hostname in ("localhost", "127.0.0.1")
and port != parsed_redirect.port
):
logger.debug(
f"Updating redirect port {parsed_redirect.port} to match the server port {port}."
)
self._redirect_uri = urllib.parse.ParseResult(
scheme=parsed_redirect.scheme,
netloc=parsed_redirect.hostname + ":" + str(port),
path=parsed_redirect.path,
params=parsed_redirect.params,
query=parsed_redirect.query,
fragment=parsed_redirect.fragment,
)
else:
self._redirect_uri = parsed_redirect

@property
def redirect_uri(self) -> str:
return self._redirect_uri.geturl()

@property
def url(self) -> str:
Expand Down
12 changes: 9 additions & 3 deletions src/snowflake/connector/auth/oauth_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import hashlib
import json
import logging
import os
import secrets
import socket
import time
Expand Down Expand Up @@ -117,7 +118,12 @@ def _request_tokens(
) -> (str | None, str | None):
"""Web Browser based Authentication."""
logger.debug("authenticating with OAuth authorization code flow")
with AuthHttpServer(self._redirect_uri) as callback_server:
with AuthHttpServer(
uri=os.environ.get("SNOWFLAKE_OAUTH_SOCKET_ADDRESS", "http://localhost")
+ ":"
+ os.environ.get("SNOWFLAKE_OAUTH_SOCKET_PORT", "0"),
Comment on lines +122 to +124

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't we use an f string here?

redirect_uri=self._redirect_uri,
) as callback_server:
code = self._do_authorization_request(callback_server, conn)
return self._do_token_request(code, callback_server, conn)

Expand Down Expand Up @@ -260,7 +266,7 @@ def _do_authorization_request(
connection: SnowflakeConnection,
) -> str | None:
authorization_request = self._construct_authorization_request(
callback_server.url
callback_server.redirect_uri
)
logger.debug("step 1: going to open authorization URL")
print(
Expand Down Expand Up @@ -314,7 +320,7 @@ def _do_token_request(
fields = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": callback_server.url,
"redirect_uri": callback_server.redirect_uri,
}
if self._enable_single_use_refresh_tokens:
fields["enable_single_use_refresh_tokens"] = "true"
Expand Down
156 changes: 154 additions & 2 deletions test/unit/test_auth_callback_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ def test_auth_callback_success(monkeypatch, dontwait, timeout, reuse_port) -> No
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
test_response: requests.Response | None = None
with AuthHttpServer("http://127.0.0.1/test_request") as callback_server:
with AuthHttpServer(
uri="http://127.0.0.1/test_request",
redirect_uri="http://127.0.0.1/test_request",
) as callback_server:

def request_callback():
nonlocal test_response
Expand Down Expand Up @@ -57,7 +60,156 @@ def request_callback():
def test_auth_callback_timeout(monkeypatch, dontwait, timeout, reuse_port) -> None:
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
with AuthHttpServer("http://127.0.0.1/test_request") as callback_server:
with AuthHttpServer(
uri="http://127.0.0.1/test_request",
redirect_uri="http://127.0.0.1/test_request",
) as callback_server:
block, client_socket = callback_server.receive_block(timeout=timeout)
assert block is None
assert client_socket is None


@pytest.mark.parametrize(
"socket_host",
[
"127.0.0.1",
"localhost",
],
)
@pytest.mark.parametrize(
"socket_port",
[
"",
":0",
":12345",
],
)
@pytest.mark.parametrize(
"redirect_host",
[
"127.0.0.1",
"localhost",
],
)
@pytest.mark.parametrize(
"redirect_port",
[
"",
":0",
":12345",
],
)
@pytest.mark.parametrize(
"dontwait",
["false", "true"],
)
@pytest.mark.parametrize("reuse_port", ["true", "false"])
def test_auth_callback_server_updates_localhost_redirect_uri_port_to_match_socket_port(
monkeypatch,
socket_host,
socket_port,
redirect_host,
redirect_port,
dontwait,
reuse_port,
) -> None:
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
with AuthHttpServer(
uri=f"http://{socket_host}{socket_port}/test_request",
redirect_uri=f"http://{redirect_host}{redirect_port}/test_request",
) as callback_server:
assert callback_server._redirect_uri.port == callback_server.port


@pytest.mark.parametrize(
"socket_host",
[
"127.0.0.1",
"localhost",
],
)
@pytest.mark.parametrize(
"socket_port",
[
"",
":0",
":12345",
],
)
@pytest.mark.parametrize(
"redirect_host",
[
"127.0.0.1",
"localhost",
],
)
@pytest.mark.parametrize(
"redirect_port",
[
54321,
54320,
],
)
@pytest.mark.parametrize(
"dontwait",
["false", "true"],
)
@pytest.mark.parametrize("reuse_port", ["true", "false"])
def test_auth_callback_server_uses_redirect_uri_port_when_specified(
monkeypatch,
socket_host,
socket_port,
redirect_host,
redirect_port,
dontwait,
reuse_port,
) -> None:
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
with AuthHttpServer(
uri=f"http://{socket_host}{socket_port}/test_request",
redirect_uri=f"http://{redirect_host}:{redirect_port}/test_request",
) as callback_server:
assert callback_server.port == redirect_port
assert callback_server._redirect_uri.port == redirect_port


@pytest.mark.parametrize(
"socket_host",
[
"127.0.0.1",
"localhost",
],
)
@pytest.mark.parametrize(
"socket_port",
[
"",
":0",
":12345",
],
)
@pytest.mark.parametrize(
"redirect_port",
[
"",
":0",
":12345",
],
)
@pytest.mark.parametrize(
"dontwait",
["false", "true"],
)
@pytest.mark.parametrize("reuse_port", ["true", "false"])
def test_auth_callback_server_does_not_updates_nonlocalhost_redirect_uri_port_to_match_socket_port(
monkeypatch, socket_host, socket_port, redirect_port, dontwait, reuse_port
) -> None:
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
redirect_uri = f"http://not_localhost{redirect_port}/test_request"
with AuthHttpServer(
uri=f"http://{socket_host}{socket_port}/test_request", redirect_uri=redirect_uri
) as callback_server:
assert callback_server.redirect_uri == redirect_uri
68 changes: 67 additions & 1 deletion test/unit/test_auth_oauth_auth_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#

import unittest.mock as mock
from unittest.mock import patch
from unittest.mock import PropertyMock, patch

import pytest

Expand Down Expand Up @@ -209,3 +209,69 @@ def assert_initialized_correctly() -> None:
assert_initialized_correctly()
else:
assert_initialized_correctly()


@pytest.mark.parametrize("redirect_uri", ["https://redirect/uri"])
@pytest.mark.parametrize("rtr_enabled", [True, False])
def test_auth_oauth_auth_code_uses_redirect_uri(
redirect_uri, rtr_enabled: bool, omit_oauth_urls_check
):
"""Test that the redirect URI is used correctly in the OAuth authorization code flow."""
auth = AuthByOauthCode(
"app",
"clientId",
"clientSecret",
"auth_url",
"tokenRequestUrl",
redirect_uri,
"scope",
"host",
pkce_enabled=False,
enable_single_use_refresh_tokens=rtr_enabled,
)

def fake_get_request_token_response(_, fields: dict[str, str]):
if rtr_enabled:
assert fields.get("enable_single_use_refresh_tokens") == "true"
else:
assert "enable_single_use_refresh_tokens" not in fields
return ("access_token", "refresh_token")

with patch(
"snowflake.connector.auth.AuthByOauthCode._construct_authorization_request",
return_value="authorization_request",
) as mock_construct_authorization_request:
with patch(
"snowflake.connector.auth.AuthByOauthCode._receive_authorization_callback",
return_value=("code", auth._state),
):
with patch(
"snowflake.connector.auth.AuthByOauthCode._ask_authorization_callback_from_user",
return_value=("code", auth._state),
):
with patch(
"snowflake.connector.auth.AuthByOauthCode._get_request_token_response",
side_effect=fake_get_request_token_response,
) as mock_get_request_token_response:
with patch(
"snowflake.connector.auth._http_server.AuthHttpServer.redirect_uri",
return_value=redirect_uri,
new_callable=PropertyMock,
):
auth.prepare(
conn=None,
authenticator=OAUTH_AUTHORIZATION_CODE,
service_name=None,
account="acc",
user="user",
)
mock_construct_authorization_request.assert_called_once_with(
redirect_uri
)
assert mock_get_request_token_response.call_count == 1
assert (
mock_get_request_token_response.call_args[0][1][
"redirect_uri"
]
== redirect_uri
)
Loading