Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/snowflake/connector/auth/_http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ def __init__(
self,
uri: str,
buf_size: int = 16384,
redirect_uri: str | None = None,
) -> None:
parsed_uri = urllib.parse.urlparse(uri)
parsed_redirect = urllib.parse.urlparse(redirect_uri) if redirect_uri else None
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.buf_size = buf_size
if os.getenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "False").lower() == "true":
Expand All @@ -82,7 +84,11 @@ def __init__(
else:
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)

port = parsed_uri.port or 0
if parsed_redirect and self._is_local_uri(parsed_redirect):
port = parsed_redirect.port or 0
else:
port = parsed_uri.port if parsed_uri and parsed_uri.port else 0

for attempt in range(1, self.DEFAULT_MAX_ATTEMPTS + 1):
try:
self._socket.bind(
Expand Down Expand Up @@ -123,6 +129,30 @@ def __init__(
query=parsed_uri.query,
fragment=parsed_uri.fragment,
)
if parsed_redirect:
if self._is_local_uri(parsed_redirect) and port != parsed_redirect.port:
logger.debug(
f"Updating redirect port {parsed_redirect.port} to match the server port {port}."
)
self._redirect_uri = urllib.parse.ParseResult(
scheme=parsed_redirect.scheme,
netloc=parsed_redirect.hostname + ":" + str(port),
path=parsed_redirect.path,
params=parsed_redirect.params,
query=parsed_redirect.query,
fragment=parsed_redirect.fragment,
)
else:
self._redirect_uri = parsed_redirect

def _is_local_uri(self, parsed_redirect):
return parsed_redirect.hostname in ("localhost", "127.0.0.1")

@property
def redirect_uri(self) -> str | None:
if self._redirect_uri:
return self._redirect_uri.geturl()
return self.url

@property
def url(self) -> str:
Expand Down
23 changes: 19 additions & 4 deletions src/snowflake/connector/auth/oauth_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import hashlib
import json
import logging
import os
import secrets
import socket
import time
Expand All @@ -16,7 +17,11 @@
from typing import TYPE_CHECKING, Any

from ..compat import parse_qs, urlparse, urlsplit
from ..constants import OAUTH_TYPE_AUTHORIZATION_CODE
from ..constants import (
ENV_VAR_OAUTH_SOCKET_ADDRESS,
ENV_VAR_OAUTH_SOCKET_PORT,
OAUTH_TYPE_AUTHORIZATION_CODE,
)
from ..errorcode import (
ER_INVALID_VALUE,
ER_OAUTH_CALLBACK_ERROR,
Expand Down Expand Up @@ -117,10 +122,20 @@ def _request_tokens(
) -> (str | None, str | None):
"""Web Browser based Authentication."""
logger.debug("authenticating with OAuth authorization code flow")
with AuthHttpServer(self._redirect_uri) as callback_server:
with AuthHttpServer(
redirect_uri=self._redirect_uri,
uri=self._read_uri_from_env(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we intend to expose this only through this env variable? Or should customer be able to pass oauth_socket_address through connection params? Is it known at the start time of the application and should be only possible to override with this env var?

Anyway we may need to create a separate Jira ticket to update the docs to include this new parameter in OAuth documentation for Python Driver.

) as callback_server:
code = self._do_authorization_request(callback_server, conn)
return self._do_token_request(code, callback_server, conn)

def _read_uri_from_env(self) -> str:
oauth_socket_address = os.getenv(
ENV_VAR_OAUTH_SOCKET_ADDRESS, "http://localhost"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

localhost, not http://localhost ?

Generally, this separation of host and port in this way rather than having the address specified as a hostport seems weird to me. Can you ensure the description explains why were are doing it this way?

)
oauth_socket_port = os.getenv(ENV_VAR_OAUTH_SOCKET_PORT, "0")
return f"{oauth_socket_address}:{oauth_socket_port}"

def _check_post_requested(
self, data: list[str]
) -> tuple[str, str] | tuple[None, None]:
Expand Down Expand Up @@ -260,7 +275,7 @@ def _do_authorization_request(
connection: SnowflakeConnection,
) -> str | None:
authorization_request = self._construct_authorization_request(
callback_server.url
callback_server.redirect_uri
)
logger.debug("step 1: going to open authorization URL")
print(
Expand Down Expand Up @@ -315,7 +330,7 @@ def _do_token_request(
fields = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": callback_server.url,
"redirect_uri": callback_server.redirect_uri,
}
if self._enable_single_use_refresh_tokens:
fields["enable_single_use_refresh_tokens"] = "true"
Expand Down
3 changes: 2 additions & 1 deletion src/snowflake/connector/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,8 @@ class IterUnit(Enum):
# TODO: all env variables definitions should be here
ENV_VAR_PARTNER = "SF_PARTNER"
ENV_VAR_TEST_MODE = "SNOWFLAKE_TEST_MODE"

ENV_VAR_OAUTH_SOCKET_ADDRESS = "SNOWFLAKE_OAUTH_SOCKET_ADDRESS"
ENV_VAR_OAUTH_SOCKET_PORT = "SNOWFLAKE_OAUTH_SOCKET_PORT"

_DOMAIN_NAME_MAP = {_DEFAULT_HOSTNAME_TLD: "GLOBAL", _CHINA_HOSTNAME_TLD: "CHINA"}

Expand Down
154 changes: 152 additions & 2 deletions test/unit/test_auth_callback_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def test_auth_callback_success(monkeypatch, dontwait, timeout, reuse_port) -> No
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
test_response: requests.Response | None = None
with AuthHttpServer("http://127.0.0.1/test_request") as callback_server:
with AuthHttpServer(
"http://127.0.0.1/test_request",
) as callback_server:

def request_callback():
nonlocal test_response
Expand Down Expand Up @@ -57,7 +59,155 @@ def request_callback():
def test_auth_callback_timeout(monkeypatch, dontwait, timeout, reuse_port) -> None:
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
with AuthHttpServer("http://127.0.0.1/test_request") as callback_server:
with AuthHttpServer(
"http://127.0.0.1/test_request",
) as callback_server:
block, client_socket = callback_server.receive_block(timeout=timeout)
assert block is None
assert client_socket is None


@pytest.mark.parametrize(
"socket_host",
[
"127.0.0.1",
"localhost",
],
)
@pytest.mark.parametrize(
"socket_port",
[
"",
":0",
":12345",
],
)
@pytest.mark.parametrize(
"redirect_host",
[
"127.0.0.1",
"localhost",
],
)
@pytest.mark.parametrize(
"redirect_port",
[
"",
":0",
":12345",
],
)
@pytest.mark.parametrize(
"dontwait",
["false", "true"],
)
@pytest.mark.parametrize("reuse_port", ["true", "false"])
def test_auth_callback_server_updates_localhost_redirect_uri_port_to_match_socket_port(
monkeypatch,
socket_host,
socket_port,
redirect_host,
redirect_port,
dontwait,
reuse_port,
) -> None:
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
with AuthHttpServer(
uri=f"http://{socket_host}{socket_port}/test_request",
redirect_uri=f"http://{redirect_host}{redirect_port}/test_request",
) as callback_server:
assert callback_server._redirect_uri.port == callback_server.port


@pytest.mark.parametrize(
"socket_host",
[
"127.0.0.1",
"localhost",
],
)
@pytest.mark.parametrize(
"socket_port",
[
"",
":0",
":12345",
],
)
@pytest.mark.parametrize(
"redirect_host",
[
"127.0.0.1",
"localhost",
],
)
@pytest.mark.parametrize(
"redirect_port",
[
54321,
54320,
],
)
@pytest.mark.parametrize(
"dontwait",
["false", "true"],
)
@pytest.mark.parametrize("reuse_port", ["true", "false"])
def test_auth_callback_server_uses_redirect_uri_port_when_specified(
monkeypatch,
socket_host,
socket_port,
redirect_host,
redirect_port,
dontwait,
reuse_port,
) -> None:
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
with AuthHttpServer(
uri=f"http://{socket_host}{socket_port}/test_request",
redirect_uri=f"http://{redirect_host}:{redirect_port}/test_request",
) as callback_server:
assert callback_server.port == redirect_port
assert callback_server._redirect_uri.port == redirect_port


@pytest.mark.parametrize(
"socket_host",
[
"127.0.0.1",
"localhost",
],
)
@pytest.mark.parametrize(
"socket_port",
[
"",
":0",
":12345",
],
)
@pytest.mark.parametrize(
"redirect_port",
[
"",
":0",
":12345",
],
)
@pytest.mark.parametrize(
"dontwait",
["false", "true"],
)
@pytest.mark.parametrize("reuse_port", ["true", "false"])
def test_auth_callback_server_does_not_updates_nonlocalhost_redirect_uri_port_to_match_socket_port(
monkeypatch, socket_host, socket_port, redirect_port, dontwait, reuse_port
) -> None:
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
redirect_uri = f"http://not_localhost{redirect_port}/test_request"
with AuthHttpServer(
uri=f"http://{socket_host}{socket_port}/test_request", redirect_uri=redirect_uri
) as callback_server:
assert callback_server.redirect_uri == redirect_uri
68 changes: 67 additions & 1 deletion test/unit/test_auth_oauth_auth_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import unittest.mock as mock
from test.helpers import apply_auth_class_update_body, create_mock_auth_body
from unittest.mock import patch
from unittest.mock import PropertyMock, patch

import pytest

Expand Down Expand Up @@ -285,3 +285,69 @@ def mock_request_tokens(self, **kwargs):
assert isinstance(conn.auth_class, AuthByOauthCode)

conn.close()


@pytest.mark.parametrize("redirect_uri", ["https://redirect/uri"])
@pytest.mark.parametrize("rtr_enabled", [True, False])
def test_auth_oauth_auth_code_uses_redirect_uri(
redirect_uri, rtr_enabled: bool, omit_oauth_urls_check
):
"""Test that the redirect URI is used correctly in the OAuth authorization code flow."""
auth = AuthByOauthCode(
"app",
"clientId",
"clientSecret",
"auth_url",
"tokenRequestUrl",
redirect_uri,
"scope",
"host",
pkce_enabled=False,
enable_single_use_refresh_tokens=rtr_enabled,
)

def fake_get_request_token_response(_, fields: dict[str, str]):
if rtr_enabled:
assert fields.get("enable_single_use_refresh_tokens") == "true"
else:
assert "enable_single_use_refresh_tokens" not in fields
return ("access_token", "refresh_token")

with patch(
"snowflake.connector.auth.AuthByOauthCode._construct_authorization_request",
return_value="authorization_request",
) as mock_construct_authorization_request:
with patch(
"snowflake.connector.auth.AuthByOauthCode._receive_authorization_callback",
return_value=("code", auth._state),
):
with patch(
"snowflake.connector.auth.AuthByOauthCode._ask_authorization_callback_from_user",
return_value=("code", auth._state),
):
with patch(
"snowflake.connector.auth.AuthByOauthCode._get_request_token_response",
side_effect=fake_get_request_token_response,
) as mock_get_request_token_response:
with patch(
"snowflake.connector.auth._http_server.AuthHttpServer.redirect_uri",
return_value=redirect_uri,
new_callable=PropertyMock,
):
auth.prepare(
conn=None,
authenticator=OAUTH_AUTHORIZATION_CODE,
service_name=None,
account="acc",
user="user",
)
mock_construct_authorization_request.assert_called_once_with(
redirect_uri
)
assert mock_get_request_token_response.call_count == 1
assert (
mock_get_request_token_response.call_args[0][1][
"redirect_uri"
]
== redirect_uri
)
Loading