Skip to content

Commit 8404bbf

Browse files
[async] Apply #2469; enhance OAUTH async tests
1 parent e2cdea3 commit 8404bbf

11 files changed

+614
-50
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
OAUTH_AUTHENTICATOR,
6262
OAUTH_AUTHORIZATION_CODE,
6363
OAUTH_CLIENT_CREDENTIALS,
64+
PAT_WITH_EXTERNAL_SESSION,
6465
PROGRAMMATIC_ACCESS_TOKEN,
6566
REQUEST_ID,
6667
USR_PWD_MFA_AUTHENTICATOR,
@@ -247,7 +248,7 @@ async def __open_connection(self):
247248
self._validate_client_prefetch_threads()
248249
)
249250

250-
# Setup authenticator
251+
# Setup authenticator - validation happens in __config
251252
auth = Auth(self.rest)
252253

253254
if self._session_token and self._master_token:
@@ -380,6 +381,12 @@ async def __open_connection(self):
380381
)
381382
elif self._authenticator == PROGRAMMATIC_ACCESS_TOKEN:
382383
self.auth_class = AuthByPAT(self._token)
384+
elif self._authenticator == PAT_WITH_EXTERNAL_SESSION:
385+
# TODO: SNOW-2344581: add support for PAT with external session ID for async connection
386+
raise ProgrammingError(
387+
msg="PAT with external session ID is not supported for async connection.",
388+
errno=ER_INVALID_VALUE,
389+
)
383390
elif self._authenticator == USR_PWD_MFA_AUTHENTICATOR:
384391
self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN] = (
385392
self._client_request_mfa_token if IS_LINUX else True

src/snowflake/connector/aio/_wif_util.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,23 @@ async def create_gcp_attestation(
8888
8989
If the application isn't running on GCP or no credentials were found, raises an error.
9090
"""
91-
res = await session_manager.request(
92-
method="GET",
93-
url=f"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity?audience={SNOWFLAKE_AUDIENCE}",
94-
headers={
95-
"Metadata-Flavor": "Google",
96-
},
97-
)
91+
try:
92+
res = await session_manager.request(
93+
method="GET",
94+
url=f"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity?audience={SNOWFLAKE_AUDIENCE}",
95+
headers={
96+
"Metadata-Flavor": "Google",
97+
},
98+
)
99+
100+
content = await res.content.read()
101+
jwt_str = content.decode("utf-8")
102+
except Exception as e:
103+
raise ProgrammingError(
104+
msg=f"Error fetching GCP metadata: {e}. Ensure the application is running on GCP.",
105+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
106+
)
98107

99-
content = await res.content.read()
100-
jwt_str = content.decode("utf-8")
101108
_, subject = extract_iss_and_sub_without_signature_verification(jwt_str)
102109
return WorkloadIdentityAttestation(
103110
AttestationProvider.GCP, jwt_str, {"sub": subject}
@@ -139,15 +146,22 @@ async def create_azure_attestation(
139146
if managed_identity_client_id:
140147
query_params += f"&client_id={managed_identity_client_id}"
141148

142-
res = await session_manager.request(
143-
method="GET",
144-
url=f"{url_without_query_string}?{query_params}",
145-
headers=headers,
146-
)
149+
try:
150+
res = await session_manager.request(
151+
method="GET",
152+
url=f"{url_without_query_string}?{query_params}",
153+
headers=headers,
154+
)
155+
156+
content = await res.content.read()
157+
response_text = content.decode("utf-8")
158+
response_data = json.loads(response_text)
159+
except Exception as e:
160+
raise ProgrammingError(
161+
msg=f"Error fetching Azure metadata: {e}. Ensure the application is running on Azure.",
162+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
163+
)
147164

148-
content = await res.content.read()
149-
response_text = content.decode("utf-8")
150-
response_data = json.loads(response_text)
151165
jwt_str = response_data.get("access_token")
152166
if not jwt_str:
153167
raise ProgrammingError(

test/unit/aio/test_auth_keypair_async.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from test.unit.aio.mock_utils import mock_connection
99
from unittest.mock import Mock, PropertyMock, patch
1010

11+
import pytest
1112
from cryptography.hazmat.backends import default_backend
1213
from cryptography.hazmat.primitives import serialization
1314
from cryptography.hazmat.primitives.asymmetric import rsa
@@ -34,7 +35,8 @@ async def _mock_auth_key_pair_rest_response(url, headers, body, **kwargs):
3435
return _mock_auth_key_pair_rest_response
3536

3637

37-
async def test_auth_keypair():
38+
@pytest.mark.parametrize("authenticator", ["SNOWFLAKE_JWT", "snowflake_jwt"])
39+
async def test_auth_keypair(authenticator):
3840
"""Simple Key Pair test."""
3941
private_key_der, public_key_der_encoded = generate_key_pair(2048)
4042
application = "testapplication"
@@ -43,7 +45,7 @@ async def test_auth_keypair():
4345
auth_instance = AuthByKeyPair(private_key=private_key_der)
4446
auth_instance._retry_ctx.set_start_time()
4547
await auth_instance.handle_timeout(
46-
authenticator="SNOWFLAKE_JWT",
48+
authenticator=authenticator,
4749
service_name=None,
4850
account=account,
4951
user=user,

test/unit/aio/test_auth_mfa_async.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,15 @@
44

55
from unittest import mock
66

7+
import pytest
8+
79
from snowflake.connector.aio import SnowflakeConnection
810

911

10-
async def test_mfa_token_cache():
12+
@pytest.mark.parametrize(
13+
"authenticator", ["USERNAME_PASSWORD_MFA", "username_password_mfa"]
14+
)
15+
async def test_mfa_token_cache(authenticator):
1116
with mock.patch(
1217
"snowflake.connector.aio._network.SnowflakeRestful.fetch",
1318
):
@@ -18,7 +23,7 @@ async def test_mfa_token_cache():
1823
account="account",
1924
user="user",
2025
password="password",
21-
authenticator="username_password_mfa",
26+
authenticator=authenticator,
2227
client_store_temporary_credential=True,
2328
client_request_mfa_token=True,
2429
):
@@ -44,7 +49,7 @@ async def test_mfa_token_cache():
4449
account="account",
4550
user="user",
4651
password="password",
47-
authenticator="username_password_mfa",
52+
authenticator=authenticator,
4853
client_store_temporary_credential=True,
4954
client_request_mfa_token=True,
5055
):

test/unit/aio/test_auth_oauth_async.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from __future__ import annotations
77

8+
import pytest
9+
810
from snowflake.connector.aio.auth import AuthByOAuth
911

1012

@@ -18,6 +20,44 @@ async def test_auth_oauth():
1820
assert body["data"]["AUTHENTICATOR"] == "OAUTH", body
1921

2022

23+
@pytest.mark.parametrize("authenticator", ["oauth", "OAUTH"])
24+
async def test_oauth_authenticator_is_case_insensitive(monkeypatch, authenticator):
25+
"""Test that oauth authenticator is case insensitive."""
26+
import snowflake.connector.aio
27+
28+
async def mock_post_request(self, url, headers, json_body, **kwargs):
29+
return {
30+
"success": True,
31+
"message": None,
32+
"data": {
33+
"token": "TOKEN",
34+
"masterToken": "MASTER_TOKEN",
35+
"idToken": None,
36+
"parameters": [{"name": "SERVICE_NAME", "value": "FAKE_SERVICE_NAME"}],
37+
},
38+
}
39+
40+
monkeypatch.setattr(
41+
snowflake.connector.aio._network.SnowflakeRestful,
42+
"_post_request",
43+
mock_post_request,
44+
)
45+
46+
# Create connection with oauth authenticator - OAuth requires a token parameter
47+
conn = snowflake.connector.aio.SnowflakeConnection(
48+
user="testuser",
49+
account="testaccount",
50+
authenticator=authenticator,
51+
token="test_oauth_token", # OAuth authentication requires a token
52+
)
53+
await conn.connect()
54+
55+
# Verify that the auth_class is an instance of AuthByOAuth
56+
assert isinstance(conn.auth_class, AuthByOAuth)
57+
58+
await conn.close()
59+
60+
2161
def test_mro():
2262
"""Ensure that methods from AuthByPluginAsync override those from AuthByPlugin."""
2363
from snowflake.connector.aio.auth import AuthByPlugin as AuthByPluginAsync

0 commit comments

Comments
 (0)