Skip to content

Commit 9a85c62

Browse files
committed
SNOW-1572226: implement all authentication methods (#2064)
1 parent e687be4 commit 9a85c62

22 files changed

+3256
-80
lines changed

src/snowflake/connector/aio/_connection.py

Lines changed: 105 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,20 @@
2727
)
2828

2929
from .._query_context_cache import QueryContextCache
30-
from ..auth import AuthByIdToken
31-
from ..compat import quote, urlencode
30+
from ..compat import IS_LINUX, quote, urlencode
3231
from ..config_manager import CONFIG_MANAGER, _get_default_connection_params
3332
from ..connection import DEFAULT_CONFIGURATION
3433
from ..connection import SnowflakeConnection as SnowflakeConnectionSync
34+
from ..connection import _get_private_bytes_from_file
3535
from ..connection_diagnostic import ConnectionDiagnostic
3636
from ..constants import (
3737
ENV_VAR_PARTNER,
3838
PARAMETER_AUTOCOMMIT,
3939
PARAMETER_CLIENT_PREFETCH_THREADS,
40+
PARAMETER_CLIENT_REQUEST_MFA_TOKEN,
4041
PARAMETER_CLIENT_SESSION_KEEP_ALIVE,
4142
PARAMETER_CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY,
43+
PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL,
4244
PARAMETER_CLIENT_TELEMETRY_ENABLED,
4345
PARAMETER_CLIENT_VALIDATE_DEFAULT_PARAMETERS,
4446
PARAMETER_ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1,
@@ -53,15 +55,34 @@
5355
ER_FAILED_TO_CONNECT_TO_DB,
5456
ER_INVALID_VALUE,
5557
)
56-
from ..network import DEFAULT_AUTHENTICATOR, REQUEST_ID, ReauthenticationRequest
58+
from ..network import (
59+
DEFAULT_AUTHENTICATOR,
60+
EXTERNAL_BROWSER_AUTHENTICATOR,
61+
KEY_PAIR_AUTHENTICATOR,
62+
OAUTH_AUTHENTICATOR,
63+
REQUEST_ID,
64+
USR_PWD_MFA_AUTHENTICATOR,
65+
ReauthenticationRequest,
66+
)
5767
from ..sqlstate import SQLSTATE_CONNECTION_NOT_EXISTS, SQLSTATE_FEATURE_NOT_SUPPORTED
5868
from ..telemetry import TelemetryData, TelemetryField
5969
from ..time_util import get_time_millis
6070
from ..util_text import split_statements
6171
from ._cursor import SnowflakeCursor
6272
from ._network import SnowflakeRestful
6373
from ._time_util import HeartBeatTimer
64-
from .auth import Auth, AuthByDefault, AuthByPlugin
74+
from .auth import (
75+
FIRST_PARTY_AUTHENTICATORS,
76+
Auth,
77+
AuthByDefault,
78+
AuthByIdToken,
79+
AuthByKeyPair,
80+
AuthByOAuth,
81+
AuthByOkta,
82+
AuthByPlugin,
83+
AuthByUsrPwdMfa,
84+
AuthByWebBrowser,
85+
)
6586

6687
logger = getLogger(__name__)
6788

@@ -196,7 +217,6 @@ async def __open_connection(self):
196217
heartbeat_ret = await auth._rest._heartbeat()
197218
logger.debug(heartbeat_ret)
198219
if not heartbeat_ret or not heartbeat_ret.get("success"):
199-
# TODO: errorhandler could be async?
200220
Error.errorhandler_wrapper(
201221
self,
202222
None,
@@ -211,20 +231,94 @@ async def __open_connection(self):
211231

212232
else:
213233
if self.auth_class is not None:
214-
raise NotImplementedError(
215-
"asyncio support for auth_class is not supported"
216-
)
234+
if type(
235+
self.auth_class
236+
) not in FIRST_PARTY_AUTHENTICATORS and not issubclass(
237+
type(self.auth_class), AuthByKeyPair
238+
):
239+
raise TypeError("auth_class must be a child class of AuthByKeyPair")
240+
self.auth_class = self.auth_class
217241
elif self._authenticator == DEFAULT_AUTHENTICATOR:
218242
self.auth_class = AuthByDefault(
219243
password=self._password,
220244
timeout=self.login_timeout,
221245
backoff_generator=self._backoff_generator,
222246
)
247+
elif self._authenticator == EXTERNAL_BROWSER_AUTHENTICATOR:
248+
self._session_parameters[
249+
PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL
250+
] = (self._client_store_temporary_credential if IS_LINUX else True)
251+
auth.read_temporary_credentials(
252+
self.host,
253+
self.user,
254+
self._session_parameters,
255+
)
256+
# Depending on whether self._rest.id_token is available we do different
257+
# auth_instance
258+
if self._rest.id_token is None:
259+
self.auth_class = AuthByWebBrowser(
260+
application=self.application,
261+
protocol=self._protocol,
262+
host=self.host,
263+
port=self.port,
264+
timeout=self.login_timeout,
265+
backoff_generator=self._backoff_generator,
266+
)
267+
else:
268+
self.auth_class = AuthByIdToken(
269+
id_token=self._rest.id_token,
270+
application=self.application,
271+
protocol=self._protocol,
272+
host=self.host,
273+
port=self.port,
274+
timeout=self.login_timeout,
275+
backoff_generator=self._backoff_generator,
276+
)
277+
278+
elif self._authenticator == KEY_PAIR_AUTHENTICATOR:
279+
private_key = self._private_key
280+
281+
if self._private_key_file:
282+
private_key = _get_private_bytes_from_file(
283+
self._private_key_file,
284+
self._private_key_file_pwd,
285+
)
286+
287+
self.auth_class = AuthByKeyPair(
288+
private_key=private_key,
289+
timeout=self.login_timeout,
290+
backoff_generator=self._backoff_generator,
291+
)
292+
elif self._authenticator == OAUTH_AUTHENTICATOR:
293+
self.auth_class = AuthByOAuth(
294+
oauth_token=self._token,
295+
timeout=self.login_timeout,
296+
backoff_generator=self._backoff_generator,
297+
)
298+
elif self._authenticator == USR_PWD_MFA_AUTHENTICATOR:
299+
self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN] = (
300+
self._client_request_mfa_token if IS_LINUX else True
301+
)
302+
if self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN]:
303+
auth.read_temporary_credentials(
304+
self.host,
305+
self.user,
306+
self._session_parameters,
307+
)
308+
self.auth_class = AuthByUsrPwdMfa(
309+
password=self._password,
310+
mfa_token=self.rest.mfa_token,
311+
timeout=self.login_timeout,
312+
backoff_generator=self._backoff_generator,
313+
)
223314
else:
224-
raise NotImplementedError(
225-
f"asyncio support for authenticator is not supported {self._authenticator}"
315+
# okta URL, e.g., https://<account>.okta.com/
316+
self.auth_class = AuthByOkta(
317+
application=self.application,
318+
timeout=self.login_timeout,
319+
backoff_generator=self._backoff_generator,
226320
)
227-
# TODO: asyncio support for other authenticators
321+
228322
await self.authenticate_with_retry(self.auth_class)
229323

230324
self._password = None # ensure password won't persist

src/snowflake/connector/aio/auth/__init__.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,38 @@
44

55
from __future__ import annotations
66

7+
from ...auth.by_plugin import AuthType
78
from ._auth import Auth
89
from ._by_plugin import AuthByPlugin
910
from ._default import AuthByDefault
11+
from ._idtoken import AuthByIdToken
12+
from ._keypair import AuthByKeyPair
13+
from ._oauth import AuthByOAuth
14+
from ._okta import AuthByOkta
15+
from ._usrpwdmfa import AuthByUsrPwdMfa
16+
from ._webbrowser import AuthByWebBrowser
17+
18+
FIRST_PARTY_AUTHENTICATORS = frozenset(
19+
(
20+
AuthByDefault,
21+
AuthByKeyPair,
22+
AuthByOAuth,
23+
AuthByOkta,
24+
AuthByUsrPwdMfa,
25+
AuthByWebBrowser,
26+
AuthByIdToken,
27+
)
28+
)
1029

1130
__all__ = [
12-
AuthByDefault,
13-
Auth,
14-
AuthByPlugin,
31+
"AuthByPlugin",
32+
"AuthByDefault",
33+
"AuthByKeyPair",
34+
"AuthByOAuth",
35+
"AuthByOkta",
36+
"AuthByUsrPwdMfa",
37+
"AuthByWebBrowser",
38+
"Auth",
39+
"AuthType",
40+
"FIRST_PARTY_AUTHENTICATORS",
1541
]

src/snowflake/connector/aio/auth/_auth.py

Lines changed: 79 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44

55
from __future__ import annotations
66

7+
import asyncio
78
import copy
89
import json
910
import logging
1011
import uuid
12+
from datetime import datetime, timezone
1113
from typing import TYPE_CHECKING, Any, Callable
1214

1315
from ...auth import Auth as AuthSync
14-
from ...auth._auth import ID_TOKEN, delete_temporary_credential
16+
from ...auth._auth import ID_TOKEN, MFA_TOKEN, delete_temporary_credential
1517
from ...compat import urlencode
1618
from ...constants import (
1719
HTTP_HEADER_ACCEPT,
@@ -62,9 +64,10 @@ async def authenticate(
6264
timeout: int | None = None,
6365
) -> dict[str, str | int | bool]:
6466
if mfa_callback or password_callback:
65-
# TODO: what's the usage of callback here and whether callback should be async?
67+
# check SNOW-1707210 for mfa_callback and password_callback support
6668
raise NotImplementedError(
67-
"mfa_callback or password_callback not supported for asyncio"
69+
"mfa_callback or password_callback is not supported in asyncio connector, please open a feature"
70+
" request issue in github: https://github.com/snowflakedb/snowflake-connector-python/issues/new/choose"
6871
)
6972
logger.debug("authenticate")
7073

@@ -148,7 +151,6 @@ async def authenticate(
148151
json.dumps(body),
149152
socket_timeout=auth_instance._socket_timeout,
150153
)
151-
# TODO: encapsulate error handling logic to be shared between sync and async
152154
except ForbiddenError as err:
153155
# HTTP 403
154156
raise err.__class__(
@@ -181,7 +183,65 @@ async def authenticate(
181183
"EXT_AUTHN_DUO_ALL",
182184
"EXT_AUTHN_DUO_PUSH_N_PASSCODE",
183185
):
184-
raise NotImplementedError("asyncio MFA not supported")
186+
body["inFlightCtx"] = ret["data"].get("inFlightCtx")
187+
body["data"]["EXT_AUTHN_DUO_METHOD"] = "push"
188+
self.ret = {"message": "Timeout", "data": {}}
189+
190+
async def post_request_wrapper(self, url, headers, body) -> None:
191+
# get the MFA response
192+
self.ret = await self._rest._post_request(
193+
url,
194+
headers,
195+
body,
196+
socket_timeout=auth_instance._socket_timeout,
197+
)
198+
199+
# send new request to wait until MFA is approved
200+
try:
201+
await asyncio.wait_for(
202+
post_request_wrapper(self, url, headers, json.dumps(body)),
203+
timeout=timeout,
204+
)
205+
except asyncio.TimeoutError:
206+
logger.debug("get the MFA response timed out")
207+
208+
ret = self.ret
209+
if (
210+
ret
211+
and ret["data"]
212+
and ret["data"].get("nextAction") == "EXT_AUTHN_SUCCESS"
213+
):
214+
body = copy.deepcopy(body_template)
215+
body["inFlightCtx"] = ret["data"].get("inFlightCtx")
216+
# final request to get tokens
217+
ret = await self._rest._post_request(
218+
url,
219+
headers,
220+
json.dumps(body),
221+
socket_timeout=auth_instance._socket_timeout,
222+
)
223+
elif not ret or not ret["data"] or not ret["data"].get("token"):
224+
# not token is returned.
225+
Error.errorhandler_wrapper(
226+
self._rest._connection,
227+
None,
228+
DatabaseError,
229+
{
230+
"msg": (
231+
"Failed to connect to DB. MFA "
232+
"authentication failed: {"
233+
"host}:{port}. {message}"
234+
).format(
235+
host=self._rest._host,
236+
port=self._rest._port,
237+
message=ret["message"],
238+
),
239+
"errno": ER_FAILED_TO_CONNECT_TO_DB,
240+
"sqlstate": SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
241+
},
242+
)
243+
return session_parameters # required for unit test
244+
185245
elif ret["data"] and ret["data"].get("nextAction") == "PWD_CHANGE":
186246
if callable(password_callback):
187247
body = copy.deepcopy(body_template)
@@ -216,23 +276,20 @@ async def authenticate(
216276
sqlstate=SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
217277
)
218278
)
219-
# TODO: error handling for AuthByKeyPairAsync and AuthByUsrPwdMfaAsync
220-
# from . import AuthByKeyPair
221-
#
222-
# if isinstance(auth_instance, AuthByKeyPair):
223-
# logger.debug(
224-
# "JWT Token authentication failed. "
225-
# "Token expires at: %s. "
226-
# "Current Time: %s",
227-
# str(auth_instance._jwt_token_exp),
228-
# str(datetime.now(timezone.utc).replace(tzinfo=None)),
229-
# )
230-
# from . import AuthByUsrPwdMfa
231-
#
232-
# if isinstance(auth_instance, AuthByUsrPwdMfa):
233-
# delete_temporary_credential(self._rest._host, user, MFA_TOKEN)
234-
# TODO: can errorhandler of a connection be async? should we support both sync and async
235-
# users could perform async ops in the error handling
279+
from . import AuthByKeyPair
280+
281+
if isinstance(auth_instance, AuthByKeyPair):
282+
logger.debug(
283+
"JWT Token authentication failed. "
284+
"Token expires at: %s. "
285+
"Current Time: %s",
286+
str(auth_instance._jwt_token_exp),
287+
str(datetime.now(timezone.utc).replace(tzinfo=None)),
288+
)
289+
from . import AuthByUsrPwdMfa
290+
291+
if isinstance(auth_instance, AuthByUsrPwdMfa):
292+
delete_temporary_credential(self._rest._host, user, MFA_TOKEN)
236293
Error.errorhandler_wrapper(
237294
self._rest._connection,
238295
None,

src/snowflake/connector/aio/auth/_by_plugin.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,28 @@
77
import asyncio
88
import logging
99
from abc import abstractmethod
10-
from typing import Any
10+
from typing import TYPE_CHECKING, Any, Iterator
1111

12-
from ... import DatabaseError, Error, OperationalError, SnowflakeConnection
12+
from ... import DatabaseError, Error, OperationalError
1313
from ...auth import AuthByPlugin as AuthByPluginSync
1414
from ...errorcode import ER_FAILED_TO_CONNECT_TO_DB
1515
from ...sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
1616

17+
if TYPE_CHECKING:
18+
from .. import SnowflakeConnection
19+
1720
logger = logging.getLogger(__name__)
1821

1922

2023
class AuthByPlugin(AuthByPluginSync):
24+
def __init__(
25+
self,
26+
timeout: int | None = None,
27+
backoff_generator: Iterator | None = None,
28+
**kwargs,
29+
) -> None:
30+
super().__init__(timeout, backoff_generator, **kwargs)
31+
2132
@abstractmethod
2233
async def prepare(
2334
self,

0 commit comments

Comments
 (0)