Skip to content

Commit 6f50a50

Browse files
Make the changes backward compatible
1 parent 72e0e12 commit 6f50a50

File tree

4 files changed

+44
-32
lines changed

4 files changed

+44
-32
lines changed

src/snowflake/connector/auth/_http_server.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ class AuthHttpServer:
6969
def __init__(
7070
self,
7171
uri: str,
72-
redirect_uri: str,
7372
buf_size: int = 16384,
73+
redirect_uri: str | None = None,
7474
) -> None:
7575
parsed_uri = urllib.parse.urlparse(uri)
76-
parsed_redirect = urllib.parse.urlparse(redirect_uri)
76+
parsed_redirect = urllib.parse.urlparse(redirect_uri) if redirect_uri else None
7777
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
7878
self.buf_size = buf_size
7979
if os.getenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "False").lower() == "true":
@@ -84,10 +84,11 @@ def __init__(
8484
else:
8585
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
8686

87-
if parsed_redirect.hostname in ("localhost", "127.0.0.1"):
87+
if parsed_redirect and self._is_local_uri(parsed_redirect):
8888
port = parsed_redirect.port or 0
8989
else:
90-
port = parsed_uri.port or 0
90+
port = parsed_uri.port if parsed_uri and parsed_uri.port else 0
91+
9192
for attempt in range(1, self.DEFAULT_MAX_ATTEMPTS + 1):
9293
try:
9394
self._socket.bind(
@@ -128,27 +129,30 @@ def __init__(
128129
query=parsed_uri.query,
129130
fragment=parsed_uri.fragment,
130131
)
131-
if (
132-
parsed_redirect.hostname in ("localhost", "127.0.0.1")
133-
and port != parsed_redirect.port
134-
):
135-
logger.debug(
136-
f"Updating redirect port {parsed_redirect.port} to match the server port {port}."
137-
)
138-
self._redirect_uri = urllib.parse.ParseResult(
139-
scheme=parsed_redirect.scheme,
140-
netloc=parsed_redirect.hostname + ":" + str(port),
141-
path=parsed_redirect.path,
142-
params=parsed_redirect.params,
143-
query=parsed_redirect.query,
144-
fragment=parsed_redirect.fragment,
145-
)
146-
else:
147-
self._redirect_uri = parsed_redirect
132+
if parsed_redirect:
133+
if self._is_local_uri(parsed_redirect) and port != parsed_redirect.port:
134+
logger.debug(
135+
f"Updating redirect port {parsed_redirect.port} to match the server port {port}."
136+
)
137+
self._redirect_uri = urllib.parse.ParseResult(
138+
scheme=parsed_redirect.scheme,
139+
netloc=parsed_redirect.hostname + ":" + str(port),
140+
path=parsed_redirect.path,
141+
params=parsed_redirect.params,
142+
query=parsed_redirect.query,
143+
fragment=parsed_redirect.fragment,
144+
)
145+
else:
146+
self._redirect_uri = parsed_redirect
147+
148+
def _is_local_uri(self, parsed_redirect):
149+
return parsed_redirect.hostname in ("localhost", "127.0.0.1")
148150

149151
@property
150-
def redirect_uri(self) -> str:
151-
return self._redirect_uri.geturl()
152+
def redirect_uri(self) -> str | None:
153+
if self._redirect_uri:
154+
return self._redirect_uri.geturl()
155+
return self.url
152156

153157
@property
154158
def url(self) -> str:

src/snowflake/connector/auth/oauth_code.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
from typing import TYPE_CHECKING, Any
1818

1919
from ..compat import parse_qs, urlparse, urlsplit
20-
from ..constants import OAUTH_TYPE_AUTHORIZATION_CODE
20+
from ..constants import (
21+
ENV_VAR_OAUTH_SOCKET_ADDRESS,
22+
ENV_VAR_OAUTH_SOCKET_PORT,
23+
OAUTH_TYPE_AUTHORIZATION_CODE,
24+
)
2125
from ..errorcode import (
2226
ER_INVALID_VALUE,
2327
ER_OAUTH_CALLBACK_ERROR,
@@ -119,14 +123,19 @@ def _request_tokens(
119123
"""Web Browser based Authentication."""
120124
logger.debug("authenticating with OAuth authorization code flow")
121125
with AuthHttpServer(
122-
uri=os.environ.get("SNOWFLAKE_OAUTH_SOCKET_ADDRESS", "http://localhost")
123-
+ ":"
124-
+ os.environ.get("SNOWFLAKE_OAUTH_SOCKET_PORT", "0"),
125126
redirect_uri=self._redirect_uri,
127+
uri=self._read_uri_from_env(),
126128
) as callback_server:
127129
code = self._do_authorization_request(callback_server, conn)
128130
return self._do_token_request(code, callback_server, conn)
129131

132+
def _read_uri_from_env(self) -> str:
133+
oauth_socket_address = os.getenv(
134+
ENV_VAR_OAUTH_SOCKET_ADDRESS, "http://localhost"
135+
)
136+
oauth_socket_port = os.getenv(ENV_VAR_OAUTH_SOCKET_PORT, "0")
137+
return f"{oauth_socket_address}:{oauth_socket_port}"
138+
130139
def _check_post_requested(
131140
self, data: list[str]
132141
) -> tuple[str, str] | tuple[None, None]:

src/snowflake/connector/constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,8 @@ class IterUnit(Enum):
441441
# TODO: all env variables definitions should be here
442442
ENV_VAR_PARTNER = "SF_PARTNER"
443443
ENV_VAR_TEST_MODE = "SNOWFLAKE_TEST_MODE"
444-
444+
ENV_VAR_OAUTH_SOCKET_ADDRESS = "SNOWFLAKE_OAUTH_SOCKET_ADDRESS"
445+
ENV_VAR_OAUTH_SOCKET_PORT = "SNOWFLAKE_OAUTH_SOCKET_PORT"
445446

446447
_DOMAIN_NAME_MAP = {_DEFAULT_HOSTNAME_TLD: "GLOBAL", _CHINA_HOSTNAME_TLD: "CHINA"}
447448

test/unit/test_auth_callback_server.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ def test_auth_callback_success(monkeypatch, dontwait, timeout, reuse_port) -> No
2525
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
2626
test_response: requests.Response | None = None
2727
with AuthHttpServer(
28-
uri="http://127.0.0.1/test_request",
29-
redirect_uri="http://127.0.0.1/test_request",
28+
"http://127.0.0.1/test_request",
3029
) as callback_server:
3130

3231
def request_callback():
@@ -61,8 +60,7 @@ def test_auth_callback_timeout(monkeypatch, dontwait, timeout, reuse_port) -> No
6160
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
6261
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
6362
with AuthHttpServer(
64-
uri="http://127.0.0.1/test_request",
65-
redirect_uri="http://127.0.0.1/test_request",
63+
"http://127.0.0.1/test_request",
6664
) as callback_server:
6765
block, client_socket = callback_server.receive_block(timeout=timeout)
6866
assert block is None

0 commit comments

Comments
 (0)