Skip to content

Commit b12ba84

Browse files
Clarify error messages detected during WIF training (#2469)
1 parent f890f48 commit b12ba84

13 files changed

+405
-36
lines changed

DESCRIPTION.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
2323
- Moved `OAUTH_TYPE` to `CLIENT_ENVIROMENT`.
2424
- Fix bug where PAT with external session authenticator was used while `external_session_id` was not provided in `SnowflakeRestful.fetch`
2525
- Added support for parameter `use_vectorized_scanner` in function `write_pandas`.
26+
- Fix unclear error messages in case of incorrect `authenticator` values.
27+
- Fix case-sensitivity of `Oauth` and `programmatic_access_token` authenticator values.
2628
- Relaxed `pyarrow` version constraint, versions >= 19 can now be used.
2729
- Populate type_code in ResultMetadata for interval types.
2830

src/snowflake/connector/connection.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,6 +1214,7 @@ def __open_connection(self):
12141214
raise TypeError("auth_class must be a child class of AuthByKeyPair")
12151215
# TODO: add telemetry for custom auth
12161216
self.auth_class = self.auth_class
1217+
# match authentivator - validation happens in __config
12171218
elif self._authenticator == DEFAULT_AUTHENTICATOR:
12181219
self.auth_class = AuthByDefault(
12191220
password=self._password,
@@ -1468,20 +1469,30 @@ def __config(self, **kwargs):
14681469
# type to be the same as the custom auth class
14691470
if self._auth_class:
14701471
self._authenticator = self._auth_class.type_.value
1471-
1472-
if self._authenticator:
1473-
# Only upper self._authenticator if it is a non-okta link
1472+
elif self._authenticator:
1473+
# Validate authenticator and convert it to uppercase if it is a non-okta link
14741474
auth_tmp = self._authenticator.upper()
1475-
if auth_tmp in [ # Non-okta authenticators
1475+
if auth_tmp in [
14761476
DEFAULT_AUTHENTICATOR,
14771477
EXTERNAL_BROWSER_AUTHENTICATOR,
14781478
KEY_PAIR_AUTHENTICATOR,
14791479
OAUTH_AUTHENTICATOR,
1480+
OAUTH_AUTHORIZATION_CODE,
1481+
OAUTH_CLIENT_CREDENTIALS,
14801482
USR_PWD_MFA_AUTHENTICATOR,
14811483
WORKLOAD_IDENTITY_AUTHENTICATOR,
1484+
PROGRAMMATIC_ACCESS_TOKEN,
14821485
PAT_WITH_EXTERNAL_SESSION,
14831486
]:
14841487
self._authenticator = auth_tmp
1488+
elif auth_tmp.startswith("HTTPS://"):
1489+
# okta authenticator link
1490+
pass
1491+
else:
1492+
raise ProgrammingError(
1493+
msg=f"Unknown authenticator: {self._authenticator}",
1494+
errno=ER_INVALID_VALUE,
1495+
)
14851496

14861497
# read OAuth token from
14871498
token_file_path = kwargs.get("token_file_path")

src/snowflake/connector/wif_util.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from botocore.awsrequest import AWSRequest
1414
from botocore.utils import InstanceMetadataRegionFetcher
1515

16-
from .errorcode import ER_WIF_CREDENTIALS_NOT_FOUND
16+
from .errorcode import ER_INVALID_WIF_SETTINGS, ER_WIF_CREDENTIALS_NOT_FOUND
1717
from .errors import ProgrammingError
1818
from .session_manager import SessionManager
1919

@@ -38,7 +38,13 @@ class AttestationProvider(Enum):
3838
@staticmethod
3939
def from_string(provider: str) -> AttestationProvider:
4040
"""Converts a string to a strongly-typed enum value of AttestationProvider."""
41-
return AttestationProvider[provider.upper()]
41+
try:
42+
return AttestationProvider[provider.upper()]
43+
except KeyError:
44+
raise ProgrammingError(
45+
msg=f"Unknown workload_identity_provider: '{provider}'. Expected one of: {', '.join(AttestationProvider.all_string_values())}",
46+
errno=ER_INVALID_WIF_SETTINGS,
47+
)
4248

4349
@staticmethod
4450
def all_string_values() -> list[str]:
@@ -65,7 +71,13 @@ def extract_iss_and_sub_without_signature_verification(jwt_str: str) -> tuple[st
6571
6672
Any errors during token parsing will be bubbled up. Missing 'iss' or 'sub' claims will also raise an error.
6773
"""
68-
claims = jwt.decode(jwt_str, options={"verify_signature": False})
74+
try:
75+
claims = jwt.decode(jwt_str, options={"verify_signature": False})
76+
except jwt.InvalidTokenError as e:
77+
raise ProgrammingError(
78+
msg=f"Invalid JWT token: {e}",
79+
errno=ER_INVALID_WIF_SETTINGS,
80+
)
6981

7082
if not ("iss" in claims and "sub" in claims):
7183
raise ProgrammingError(
@@ -179,14 +191,20 @@ def create_gcp_attestation(
179191
180192
If the application isn't running on GCP or no credentials were found, raises an error.
181193
"""
182-
res = session_manager.request(
183-
method="GET",
184-
url=f"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity?audience={SNOWFLAKE_AUDIENCE}",
185-
headers={
186-
"Metadata-Flavor": "Google",
187-
},
188-
)
189-
res.raise_for_status()
194+
try:
195+
res = session_manager.request(
196+
method="GET",
197+
url=f"http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/identity?audience={SNOWFLAKE_AUDIENCE}",
198+
headers={
199+
"Metadata-Flavor": "Google",
200+
},
201+
)
202+
res.raise_for_status()
203+
except Exception as e:
204+
raise ProgrammingError(
205+
msg=f"Error fetching GCP metadata: {e}. Ensure the application is running on GCP.",
206+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
207+
)
190208

191209
jwt_str = res.content.decode("utf-8")
192210
_, subject = extract_iss_and_sub_without_signature_verification(jwt_str)
@@ -230,12 +248,18 @@ def create_azure_attestation(
230248
if managed_identity_client_id:
231249
query_params += f"&client_id={managed_identity_client_id}"
232250

233-
res = session_manager.request(
234-
method="GET",
235-
url=f"{url_without_query_string}?{query_params}",
236-
headers=headers,
237-
)
238-
res.raise_for_status()
251+
try:
252+
res = session_manager.request(
253+
method="GET",
254+
url=f"{url_without_query_string}?{query_params}",
255+
headers=headers,
256+
)
257+
res.raise_for_status()
258+
except Exception as e:
259+
raise ProgrammingError(
260+
msg=f"Error fetching Azure metadata: {e}. Ensure the application is running on Azure.",
261+
errno=ER_WIF_CREDENTIALS_NOT_FOUND,
262+
)
239263

240264
jwt_str = res.json().get("access_token")
241265
if not jwt_str:

test/auth/authorization_parameters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def get_base_connection_parameters(self) -> dict[str, Union[str, bool, int]]:
7979

8080
def get_key_pair_connection_parameters(self):
8181
config = self.basic_config.copy()
82-
config["authenticator"] = "KEY_PAIR_AUTHENTICATOR"
82+
config["authenticator"] = "SNOWFLAKE_JWT"
8383
config["user"] = _get_env_variable("SNOWFLAKE_AUTH_TEST_BROWSER_USER")
8484

8585
return config

test/unit/test_auth_keypair.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from unittest.mock import Mock, PropertyMock, patch
55

6+
import pytest
67
from cryptography.hazmat.backends import default_backend
78
from cryptography.hazmat.primitives import serialization
89
from cryptography.hazmat.primitives.asymmetric import rsa
@@ -36,7 +37,8 @@ def _mock_auth_key_pair_rest_response(url, headers, body, **kwargs):
3637
return _mock_auth_key_pair_rest_response
3738

3839

39-
def test_auth_keypair():
40+
@pytest.mark.parametrize("authenticator", ["SNOWFLAKE_JWT", "snowflake_jwt"])
41+
def test_auth_keypair(authenticator):
4042
"""Simple Key Pair test."""
4143
private_key_der, public_key_der_encoded = generate_key_pair(2048)
4244
application = "testapplication"
@@ -45,7 +47,7 @@ def test_auth_keypair():
4547
auth_instance = AuthByKeyPair(private_key=private_key_der)
4648
auth_instance._retry_ctx.set_start_time()
4749
auth_instance.handle_timeout(
48-
authenticator="SNOWFLAKE_JWT",
50+
authenticator=authenticator,
4951
service_name=None,
5052
account=account,
5153
user=user,

test/unit/test_auth_mfa.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
from unittest import mock
22

3+
import pytest
4+
35
from snowflake.connector import connect
46

57

6-
def test_mfa_token_cache():
8+
@pytest.mark.parametrize(
9+
"authenticator", ["USERNAME_PASSWORD_MFA", "username_password_mfa"]
10+
)
11+
def test_mfa_token_cache(authenticator):
712
with mock.patch(
813
"snowflake.connector.network.SnowflakeRestful.fetch",
914
):
@@ -14,7 +19,7 @@ def test_mfa_token_cache():
1419
account="account",
1520
user="user",
1621
password="password",
17-
authenticator="username_password_mfa",
22+
authenticator=authenticator,
1823
client_store_temporary_credential=True,
1924
client_request_mfa_token=True,
2025
):
@@ -40,7 +45,7 @@ def test_mfa_token_cache():
4045
account="account",
4146
user="user",
4247
password="password",
43-
authenticator="username_password_mfa",
48+
authenticator=authenticator,
4449
client_store_temporary_credential=True,
4550
client_request_mfa_token=True,
4651
):

test/unit/test_auth_oauth.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from snowflake.connector.auth import AuthByOAuth
66
except ImportError:
77
from snowflake.connector.auth_oauth import AuthByOAuth
8+
import pytest
89

910

1011
def test_auth_oauth():
@@ -15,3 +16,38 @@ def test_auth_oauth():
1516
auth.update_body(body)
1617
assert body["data"]["TOKEN"] == token, body
1718
assert body["data"]["AUTHENTICATOR"] == "OAUTH", body
19+
20+
21+
@pytest.mark.parametrize("authenticator", ["oauth", "OAUTH"])
22+
def test_oauth_authenticator_is_case_insensitive(monkeypatch, authenticator):
23+
"""Test that oauth authenticator is case insensitive."""
24+
import snowflake.connector
25+
26+
def mock_post_request(self, url, headers, json_body, **kwargs):
27+
return {
28+
"success": True,
29+
"message": None,
30+
"data": {
31+
"token": "TOKEN",
32+
"masterToken": "MASTER_TOKEN",
33+
"idToken": None,
34+
"parameters": [{"name": "SERVICE_NAME", "value": "FAKE_SERVICE_NAME"}],
35+
},
36+
}
37+
38+
monkeypatch.setattr(
39+
snowflake.connector.network.SnowflakeRestful, "_post_request", mock_post_request
40+
)
41+
42+
# Create connection with oauth authenticator - OAuth requires a token parameter
43+
conn = snowflake.connector.connect(
44+
user="testuser",
45+
account="testaccount",
46+
authenticator=authenticator,
47+
token="test_oauth_token", # OAuth authentication requires a token
48+
)
49+
50+
# Verify that the auth_class is an instance of AuthByOAuth
51+
assert isinstance(conn.auth_class, AuthByOAuth)
52+
53+
conn.close()

test/unit/test_auth_oauth_auth_code.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,50 @@ def assert_initialized_correctly() -> None:
211211
assert_initialized_correctly()
212212
else:
213213
assert_initialized_correctly()
214+
215+
216+
@pytest.mark.parametrize(
217+
"authenticator", ["OAUTH_AUTHORIZATION_CODE", "oauth_authorization_code"]
218+
)
219+
def test_oauth_authorization_code_authenticator_is_case_insensitive(
220+
monkeypatch, authenticator
221+
):
222+
"""Test that OAuth authorization code authenticator is case insensitive."""
223+
import snowflake.connector
224+
225+
def mock_post_request(self, url, headers, json_body, **kwargs):
226+
return {
227+
"success": True,
228+
"message": None,
229+
"data": {
230+
"token": "TOKEN",
231+
"masterToken": "MASTER_TOKEN",
232+
"idToken": None,
233+
"parameters": [{"name": "SERVICE_NAME", "value": "FAKE_SERVICE_NAME"}],
234+
},
235+
}
236+
237+
monkeypatch.setattr(
238+
snowflake.connector.network.SnowflakeRestful, "_post_request", mock_post_request
239+
)
240+
241+
# Mock the OAuth authorization flow to avoid opening browser and starting HTTP server
242+
def mock_request_tokens(self, **kwargs):
243+
# Simulate successful token retrieval
244+
return ("mock_access_token", "mock_refresh_token")
245+
246+
monkeypatch.setattr(AuthByOauthCode, "_request_tokens", mock_request_tokens)
247+
248+
# Create connection with OAuth authorization code authenticator
249+
conn = snowflake.connector.connect(
250+
user="testuser",
251+
account="testaccount",
252+
authenticator=authenticator,
253+
oauth_client_id="test_client_id",
254+
oauth_client_secret="test_client_secret",
255+
)
256+
257+
# Verify that the auth_class is an instance of AuthByOauthCode
258+
assert isinstance(conn.auth_class, AuthByOauthCode)
259+
260+
conn.close()

0 commit comments

Comments
 (0)