Skip to content

Commit 58931a7

Browse files
committed
add unit tests
1 parent 2e4cb39 commit 58931a7

File tree

3 files changed

+193
-3
lines changed

3 files changed

+193
-3
lines changed

test/unit/test_auth_callback_server.py

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ def test_auth_callback_success(monkeypatch, dontwait, timeout, reuse_port) -> No
2424
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
2525
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
2626
test_response: requests.Response | None = None
27-
with AuthHttpServer("http://127.0.0.1/test_request") as callback_server:
27+
with AuthHttpServer(
28+
uri="http://127.0.0.1/test_request",
29+
redirect_uri="http://127.0.0.1/test_request",
30+
) as callback_server:
2831

2932
def request_callback():
3033
nonlocal test_response
@@ -57,7 +60,103 @@ def request_callback():
5760
def test_auth_callback_timeout(monkeypatch, dontwait, timeout, reuse_port) -> None:
5861
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
5962
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
60-
with AuthHttpServer("http://127.0.0.1/test_request") as callback_server:
63+
with AuthHttpServer(
64+
uri="http://127.0.0.1/test_request",
65+
redirect_uri="http://127.0.0.1/test_request",
66+
) as callback_server:
6167
block, client_socket = callback_server.receive_block(timeout=timeout)
6268
assert block is None
6369
assert client_socket is None
70+
71+
72+
@pytest.mark.parametrize(
73+
"socket_host",
74+
[
75+
"127.0.0.1",
76+
"localhost",
77+
],
78+
)
79+
@pytest.mark.parametrize(
80+
"socket_port",
81+
[
82+
"",
83+
":0",
84+
":12345",
85+
],
86+
)
87+
@pytest.mark.parametrize(
88+
"redirect_host",
89+
[
90+
"127.0.0.1",
91+
"localhost",
92+
],
93+
)
94+
@pytest.mark.parametrize(
95+
"redirect_port",
96+
[
97+
"",
98+
":0",
99+
":12345",
100+
],
101+
)
102+
@pytest.mark.parametrize(
103+
"dontwait",
104+
["false", "true"],
105+
)
106+
@pytest.mark.parametrize("reuse_port", ["true", "false"])
107+
def test_auth_callback_server_updates_localhost_redirect_uri_port_to_match_socket_port(
108+
monkeypatch,
109+
socket_host,
110+
socket_port,
111+
redirect_host,
112+
redirect_port,
113+
dontwait,
114+
reuse_port,
115+
) -> None:
116+
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
117+
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
118+
with AuthHttpServer(
119+
uri=f"http://{socket_host}{socket_port}/test_request",
120+
redirect_uri=f"http://{redirect_host}{redirect_port}/test_request",
121+
) as callback_server:
122+
assert callback_server._redirect_uri.port == callback_server.port
123+
124+
125+
@pytest.mark.parametrize(
126+
"socket_host",
127+
[
128+
"127.0.0.1",
129+
"localhost",
130+
],
131+
)
132+
@pytest.mark.parametrize(
133+
"socket_port",
134+
[
135+
"",
136+
":0",
137+
":12345",
138+
],
139+
)
140+
@pytest.mark.parametrize(
141+
"redirect_port",
142+
[
143+
"",
144+
":0",
145+
":12345",
146+
],
147+
)
148+
@pytest.mark.parametrize(
149+
"dontwait",
150+
["false", "true"],
151+
)
152+
@pytest.mark.parametrize("reuse_port", ["true", "false"])
153+
def test_auth_callback_server_does_not_updates_nonlocalhost_redirect_uri_port_to_match_socket_port(
154+
monkeypatch, socket_host, socket_port, redirect_port, dontwait, reuse_port
155+
) -> None:
156+
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", reuse_port)
157+
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_MSG_DONTWAIT", dontwait)
158+
redirect_uri = f"http://not_localhost{redirect_port}/test_request"
159+
with AuthHttpServer(
160+
uri=f"http://{socket_host}{socket_port}/test_request", redirect_uri=redirect_uri
161+
) as callback_server:
162+
assert callback_server.redirect_uri == redirect_uri

test/unit/test_auth_oauth_auth_code.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#
55

66
import unittest.mock as mock
7-
from unittest.mock import patch
7+
from unittest.mock import PropertyMock, patch
88

99
import pytest
1010

@@ -209,3 +209,69 @@ def assert_initialized_correctly() -> None:
209209
assert_initialized_correctly()
210210
else:
211211
assert_initialized_correctly()
212+
213+
214+
@pytest.mark.parametrize("redirect_uri", ["https://redirect/uri"])
215+
@pytest.mark.parametrize("rtr_enabled", [True, False])
216+
def test_auth_oauth_auth_code_uses_redirect_uri(
217+
redirect_uri, rtr_enabled: bool, omit_oauth_urls_check
218+
):
219+
"""Test that the redirect URI is used correctly in the OAuth authorization code flow."""
220+
auth = AuthByOauthCode(
221+
"app",
222+
"clientId",
223+
"clientSecret",
224+
"auth_url",
225+
"tokenRequestUrl",
226+
redirect_uri,
227+
"scope",
228+
"host",
229+
pkce_enabled=False,
230+
enable_single_use_refresh_tokens=rtr_enabled,
231+
)
232+
233+
def fake_get_request_token_response(_, fields: dict[str, str]):
234+
if rtr_enabled:
235+
assert fields.get("enable_single_use_refresh_tokens") == "true"
236+
else:
237+
assert "enable_single_use_refresh_tokens" not in fields
238+
return ("access_token", "refresh_token")
239+
240+
with patch(
241+
"snowflake.connector.auth.AuthByOauthCode._construct_authorization_request",
242+
return_value="authorization_request",
243+
) as mock_construct_authorization_request:
244+
with patch(
245+
"snowflake.connector.auth.AuthByOauthCode._receive_authorization_callback",
246+
return_value=("code", auth._state),
247+
):
248+
with patch(
249+
"snowflake.connector.auth.AuthByOauthCode._ask_authorization_callback_from_user",
250+
return_value=("code", auth._state),
251+
):
252+
with patch(
253+
"snowflake.connector.auth.AuthByOauthCode._get_request_token_response",
254+
side_effect=fake_get_request_token_response,
255+
) as mock_get_request_token_response:
256+
with patch(
257+
"snowflake.connector.auth._http_server.AuthHttpServer.redirect_uri",
258+
return_value=redirect_uri,
259+
new_callable=PropertyMock,
260+
):
261+
auth.prepare(
262+
conn=None,
263+
authenticator=OAUTH_AUTHORIZATION_CODE,
264+
service_name=None,
265+
account="acc",
266+
user="user",
267+
)
268+
mock_construct_authorization_request.assert_called_once_with(
269+
redirect_uri
270+
)
271+
assert mock_get_request_token_response.call_count == 1
272+
assert (
273+
mock_get_request_token_response.call_args[0][1][
274+
"redirect_uri"
275+
]
276+
== redirect_uri
277+
)

test/unit/test_oauth_token.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def test_oauth_code_successful_flow(
142142
omit_oauth_urls_check,
143143
) -> None:
144144
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
145+
monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009")
145146

146147
wiremock_client.import_mapping(
147148
wiremock_oauth_authorization_code_dir / "successful_flow.json"
@@ -184,6 +185,7 @@ def test_oauth_code_invalid_state(
184185
omit_oauth_urls_check,
185186
) -> None:
186187
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
188+
monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009")
187189

188190
wiremock_client.import_mapping(
189191
wiremock_oauth_authorization_code_dir / "invalid_state_error.json"
@@ -219,6 +221,7 @@ def test_oauth_code_scope_error(
219221
monkeypatch,
220222
) -> None:
221223
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
224+
monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009")
222225

223226
wiremock_client.import_mapping(
224227
wiremock_oauth_authorization_code_dir / "invalid_scope_error.json"
@@ -255,6 +258,7 @@ def test_oauth_code_token_request_error(
255258
omit_oauth_urls_check,
256259
) -> None:
257260
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
261+
monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009")
258262

259263
with WiremockClient() as wiremock_client:
260264
wiremock_client.import_mapping(
@@ -293,6 +297,7 @@ def test_oauth_code_browser_timeout(
293297
omit_oauth_urls_check,
294298
) -> None:
295299
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
300+
monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009")
296301

297302
wiremock_client.import_mapping(
298303
wiremock_oauth_authorization_code_dir
@@ -334,6 +339,7 @@ def test_oauth_code_custom_urls(
334339
omit_oauth_urls_check,
335340
) -> None:
336341
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
342+
monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009")
337343

338344
wiremock_client.import_mapping(
339345
wiremock_oauth_authorization_code_dir / "external_idp_custom_urls.json"
@@ -377,6 +383,7 @@ def test_oauth_code_local_application_custom_urls_successful_flow(
377383
omit_oauth_urls_check,
378384
) -> None:
379385
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
386+
monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009")
380387

381388
wiremock_client.import_mapping(
382389
wiremock_oauth_authorization_code_dir
@@ -421,6 +428,7 @@ def test_oauth_code_successful_refresh_token_flow(
421428
omit_oauth_urls_check,
422429
) -> None:
423430
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
431+
monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009")
424432

425433
wiremock_client.import_mapping(
426434
wiremock_generic_mappings_dir / "snowflake_login_failed.json"
@@ -481,6 +489,7 @@ def test_oauth_code_expired_refresh_token_flow(
481489
omit_oauth_urls_check,
482490
) -> None:
483491
monkeypatch.setenv("SNOWFLAKE_AUTH_SOCKET_REUSE_PORT", "true")
492+
monkeypatch.setenv("SNOWFLAKE_OAUTH_SOCKET_PORT", "8009")
484493

485494
wiremock_client.import_mapping(
486495
wiremock_generic_mappings_dir / "snowflake_login_failed.json"
@@ -561,6 +570,10 @@ def test_client_creds_successful_flow(
561570
wiremock_generic_mappings_dir,
562571
monkeypatch,
563572
) -> None:
573+
monkeypatch.setenv(
574+
"SNOWFLAKE_OAUTH_SOCKET_PORT", wiremock_client.wiremock_http_port
575+
)
576+
564577
wiremock_client.import_mapping(
565578
wiremock_oauth_client_creds_dir / "successful_flow.json"
566579
)
@@ -595,6 +608,10 @@ def test_client_creds_token_request_error(
595608
wiremock_generic_mappings_dir,
596609
monkeypatch,
597610
) -> None:
611+
monkeypatch.setenv(
612+
"SNOWFLAKE_OAUTH_SOCKET_PORT", wiremock_client.wiremock_http_port
613+
)
614+
598615
wiremock_client.import_mapping(
599616
wiremock_oauth_client_creds_dir / "token_request_error.json"
600617
)
@@ -634,6 +651,10 @@ def test_client_creds_successful_refresh_token_flow(
634651
monkeypatch,
635652
temp_cache,
636653
) -> None:
654+
monkeypatch.setenv(
655+
"SNOWFLAKE_OAUTH_SOCKET_PORT", wiremock_client.wiremock_http_port
656+
)
657+
637658
wiremock_client.import_mapping(
638659
wiremock_generic_mappings_dir / "snowflake_login_failed.json"
639660
)
@@ -688,6 +709,10 @@ def test_client_creds_expired_refresh_token_flow(
688709
monkeypatch,
689710
temp_cache,
690711
) -> None:
712+
monkeypatch.setenv(
713+
"SNOWFLAKE_OAUTH_SOCKET_PORT", wiremock_client.wiremock_http_port
714+
)
715+
691716
wiremock_client.import_mapping(
692717
wiremock_generic_mappings_dir / "snowflake_login_failed.json"
693718
)

0 commit comments

Comments
 (0)