Skip to content

Commit 12f7d03

Browse files
SNOW-733835 fix AuthByIdToken expired token auth (#1444)
1 parent 28fa609 commit 12f7d03

File tree

5 files changed

+100
-9
lines changed

5 files changed

+100
-9
lines changed

DESCRIPTION.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
1212

1313
- Improved the robustness of OCSP response caching to handle errors in cases of serialization and deserialization.
1414
- Fixed a bug where `AuthByKeyPair.handle_timeout` should pass keyword arguments instead of positional arguments when calling `AuthByKeyPair.prepare`. PR #1440 (@emilhe)
15+
- Fixed a bug where MFA token caching would refuse to work until restarted instead of reauthenticating
1516

1617
- v3.0.0(January 26, 2023)
1718

src/snowflake/connector/auth/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ._auth import Auth, get_public_key_fingerprint, get_token_from_private_key
88
from .by_plugin import AuthByPlugin, AuthType
99
from .default import AuthByDefault
10+
from .idtoken import AuthByIdToken
1011
from .keypair import AuthByKeyPair
1112
from .oauth import AuthByOAuth
1213
from .okta import AuthByOkta
@@ -21,6 +22,7 @@
2122
AuthByOkta,
2223
AuthByUsrPwdMfa,
2324
AuthByWebBrowser,
25+
AuthByIdToken,
2426
)
2527
)
2628

src/snowflake/connector/auth/idtoken.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55

66
from __future__ import annotations
77

8-
from typing import Any
8+
from typing import TYPE_CHECKING, Any
99

1010
from ..network import ID_TOKEN_AUTHENTICATOR
1111
from .by_plugin import AuthByPlugin, AuthType
12+
from .webbrowser import AuthByWebBrowser
13+
14+
if TYPE_CHECKING:
15+
from ..connection import SnowflakeConnection
1216

1317

1418
class AuthByIdToken(AuthByPlugin):
@@ -25,19 +29,43 @@ def type_(self) -> AuthType:
2529
def assertion_content(self) -> str:
2630
return self._id_token
2731

28-
def __init__(self, id_token: str) -> None:
32+
def __init__(
33+
self,
34+
id_token: str,
35+
application: str,
36+
protocol: str | None,
37+
host: str | None,
38+
port: str | None,
39+
) -> None:
2940
"""Initialized an instance with an IdToken."""
3041
super().__init__()
3142
self._id_token: str | None = id_token
43+
self._application = application
44+
self._protocol = protocol
45+
self._host = host
46+
self._port = port
3247

3348
def reset_secrets(self) -> None:
3449
self._id_token = None
3550

3651
def prepare(self, **kwargs: Any) -> None:
3752
pass
3853

39-
def reauthenticate(self, **kwargs: Any) -> dict[str, bool]:
40-
return {"success": False}
54+
def reauthenticate(
55+
self,
56+
*,
57+
conn: SnowflakeConnection,
58+
**kwargs: Any,
59+
) -> dict[str, bool]:
60+
conn.auth_class = AuthByWebBrowser(
61+
application=self._application,
62+
protocol=self._protocol,
63+
host=self._host,
64+
port=self._port,
65+
)
66+
conn._authenticate(conn.auth_class)
67+
conn._auth_class.reset_secrets()
68+
return {"success": True}
4169

4270
def update_body(self, body: dict[Any, Any]) -> None:
4371
"""Idtoken needs the authenticator and token attributes set."""

src/snowflake/connector/connection.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,8 @@ def region(self) -> str | None:
351351
warnings.warn(
352352
"Region has been deprecated and will be removed in the near future",
353353
PendingDeprecationWarning,
354+
# Raise warning from where this property was called from
355+
stacklevel=2,
354356
)
355357
return self._region
356358

@@ -804,7 +806,13 @@ def __open_connection(self):
804806
port=self.port,
805807
)
806808
else:
807-
self.auth_class = AuthByIdToken(id_token=self._rest.id_token)
809+
self.auth_class = AuthByIdToken(
810+
id_token=self._rest.id_token,
811+
application=self.application,
812+
protocol=self._protocol,
813+
host=self.host,
814+
port=self.port,
815+
)
808816

809817
elif self._authenticator == KEY_PAIR_AUTHENTICATOR:
810818
self.auth_class = AuthByKeyPair(private_key=self._private_key)
@@ -866,7 +874,9 @@ def __config(self, **kwargs):
866874
warnings.warn(
867875
"'{}' is an unknown connection parameter{}".format(
868876
name, f", did you mean '{guess}'?" if guess else ""
869-
)
877+
),
878+
# Raise warning from where class was initiated
879+
stacklevel=4,
870880
)
871881
elif not isinstance(value, DEFAULT_CONFIGURATION[name][1]):
872882
accepted_types = DEFAULT_CONFIGURATION[name][1]
@@ -879,7 +889,9 @@ def __config(self, **kwargs):
879889
if isinstance(accepted_types, tuple)
880890
else accepted_types.__name__,
881891
type(value).__name__,
882-
)
892+
),
893+
# Raise warning from where class was initiated
894+
stacklevel=4,
883895
)
884896
setattr(self, "_" + name, value)
885897

@@ -1088,7 +1100,12 @@ def authenticate_with_retry(self, auth_instance) -> None:
10881100
except ReauthenticationRequest as ex:
10891101
# cached id_token expiration error, we have cleaned id_token and try to authenticate again
10901102
logger.debug("ID token expired. Reauthenticating...: %s", ex)
1091-
self._authenticate(auth_instance)
1103+
if isinstance(auth_instance, AuthByIdToken):
1104+
# Note: SNOW-733835 IDToken auth needs to authenticate through
1105+
# SSO if it has expired
1106+
self._reauthenticate()
1107+
else:
1108+
self._authenticate(auth_instance)
10921109

10931110
def _authenticate(self, auth_instance: AuthByPlugin):
10941111
auth_instance.prepare(

test/unit/test_auth_webbrowser.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,19 @@
55

66
from __future__ import annotations
77

8+
from unittest import mock
89
from unittest.mock import MagicMock, Mock, PropertyMock, patch
910

1011
import pytest
1112

13+
from snowflake.connector import SnowflakeConnection
1214
from snowflake.connector.constants import OCSPMode
1315
from snowflake.connector.description import CLIENT_NAME, CLIENT_VERSION
14-
from snowflake.connector.network import EXTERNAL_BROWSER_AUTHENTICATOR, SnowflakeRestful
16+
from snowflake.connector.network import (
17+
EXTERNAL_BROWSER_AUTHENTICATOR,
18+
ReauthenticationRequest,
19+
SnowflakeRestful,
20+
)
1521

1622
try: # pragma: no cover
1723
from snowflake.connector.auth import AuthByWebBrowser
@@ -275,3 +281,40 @@ def post_request(url, headers, body, **kwargs):
275281
rest._post_request = post_request
276282
connection._rest = rest
277283
return rest
284+
285+
286+
def test_idtoken_reauth():
287+
"""This test makes sure that AuthByIdToken reverts to AuthByWebBrowser.
288+
289+
This happens when the initial connection fails. Such as when the saved ID
290+
token has expired.
291+
"""
292+
from snowflake.connector.auth.idtoken import AuthByIdToken
293+
294+
auth_inst = AuthByIdToken(
295+
id_token="token",
296+
application="application",
297+
protocol="protocol",
298+
host="host",
299+
port="port",
300+
)
301+
302+
# We'll use this Exception to make sure AuthByWebBrowser authentication
303+
# flow is called as expected
304+
class StopExecuting(Exception):
305+
pass
306+
307+
with mock.patch(
308+
"snowflake.connector.auth.idtoken.AuthByIdToken.prepare",
309+
side_effect=ReauthenticationRequest(Exception()),
310+
):
311+
with mock.patch(
312+
"snowflake.connector.auth.webbrowser.AuthByWebBrowser.prepare",
313+
side_effect=StopExecuting(),
314+
):
315+
with pytest.raises(StopExecuting):
316+
SnowflakeConnection(
317+
user="user",
318+
account="account",
319+
auth_class=auth_inst,
320+
)

0 commit comments

Comments
 (0)