27
27
)
28
28
29
29
from .._query_context_cache import QueryContextCache
30
- from ..auth import AuthByIdToken
31
- from ..compat import quote , urlencode
30
+ from ..compat import IS_LINUX , quote , urlencode
32
31
from ..config_manager import CONFIG_MANAGER , _get_default_connection_params
33
32
from ..connection import DEFAULT_CONFIGURATION
34
33
from ..connection import SnowflakeConnection as SnowflakeConnectionSync
34
+ from ..connection import _get_private_bytes_from_file
35
35
from ..connection_diagnostic import ConnectionDiagnostic
36
36
from ..constants import (
37
37
ENV_VAR_PARTNER ,
38
38
PARAMETER_AUTOCOMMIT ,
39
39
PARAMETER_CLIENT_PREFETCH_THREADS ,
40
+ PARAMETER_CLIENT_REQUEST_MFA_TOKEN ,
40
41
PARAMETER_CLIENT_SESSION_KEEP_ALIVE ,
41
42
PARAMETER_CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY ,
43
+ PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL ,
42
44
PARAMETER_CLIENT_TELEMETRY_ENABLED ,
43
45
PARAMETER_CLIENT_VALIDATE_DEFAULT_PARAMETERS ,
44
46
PARAMETER_ENABLE_STAGE_S3_PRIVATELINK_FOR_US_EAST_1 ,
53
55
ER_FAILED_TO_CONNECT_TO_DB ,
54
56
ER_INVALID_VALUE ,
55
57
)
56
- from ..network import DEFAULT_AUTHENTICATOR , REQUEST_ID , ReauthenticationRequest
58
+ from ..network import (
59
+ DEFAULT_AUTHENTICATOR ,
60
+ EXTERNAL_BROWSER_AUTHENTICATOR ,
61
+ KEY_PAIR_AUTHENTICATOR ,
62
+ OAUTH_AUTHENTICATOR ,
63
+ REQUEST_ID ,
64
+ USR_PWD_MFA_AUTHENTICATOR ,
65
+ ReauthenticationRequest ,
66
+ )
57
67
from ..sqlstate import SQLSTATE_CONNECTION_NOT_EXISTS , SQLSTATE_FEATURE_NOT_SUPPORTED
58
68
from ..telemetry import TelemetryData , TelemetryField
59
69
from ..time_util import get_time_millis
60
70
from ..util_text import split_statements
61
71
from ._cursor import SnowflakeCursor
62
72
from ._network import SnowflakeRestful
63
73
from ._time_util import HeartBeatTimer
64
- from .auth import Auth , AuthByDefault , AuthByPlugin
74
+ from .auth import (
75
+ FIRST_PARTY_AUTHENTICATORS ,
76
+ Auth ,
77
+ AuthByDefault ,
78
+ AuthByIdToken ,
79
+ AuthByKeyPair ,
80
+ AuthByOAuth ,
81
+ AuthByOkta ,
82
+ AuthByPlugin ,
83
+ AuthByUsrPwdMfa ,
84
+ AuthByWebBrowser ,
85
+ )
65
86
66
87
logger = getLogger (__name__ )
67
88
@@ -196,7 +217,6 @@ async def __open_connection(self):
196
217
heartbeat_ret = await auth ._rest ._heartbeat ()
197
218
logger .debug (heartbeat_ret )
198
219
if not heartbeat_ret or not heartbeat_ret .get ("success" ):
199
- # TODO: errorhandler could be async?
200
220
Error .errorhandler_wrapper (
201
221
self ,
202
222
None ,
@@ -211,20 +231,94 @@ async def __open_connection(self):
211
231
212
232
else :
213
233
if self .auth_class is not None :
214
- raise NotImplementedError (
215
- "asyncio support for auth_class is not supported"
216
- )
234
+ if type (
235
+ self .auth_class
236
+ ) not in FIRST_PARTY_AUTHENTICATORS and not issubclass (
237
+ type (self .auth_class ), AuthByKeyPair
238
+ ):
239
+ raise TypeError ("auth_class must be a child class of AuthByKeyPair" )
240
+ self .auth_class = self .auth_class
217
241
elif self ._authenticator == DEFAULT_AUTHENTICATOR :
218
242
self .auth_class = AuthByDefault (
219
243
password = self ._password ,
220
244
timeout = self .login_timeout ,
221
245
backoff_generator = self ._backoff_generator ,
222
246
)
247
+ elif self ._authenticator == EXTERNAL_BROWSER_AUTHENTICATOR :
248
+ self ._session_parameters [
249
+ PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL
250
+ ] = (self ._client_store_temporary_credential if IS_LINUX else True )
251
+ auth .read_temporary_credentials (
252
+ self .host ,
253
+ self .user ,
254
+ self ._session_parameters ,
255
+ )
256
+ # Depending on whether self._rest.id_token is available we do different
257
+ # auth_instance
258
+ if self ._rest .id_token is None :
259
+ self .auth_class = AuthByWebBrowser (
260
+ application = self .application ,
261
+ protocol = self ._protocol ,
262
+ host = self .host ,
263
+ port = self .port ,
264
+ timeout = self .login_timeout ,
265
+ backoff_generator = self ._backoff_generator ,
266
+ )
267
+ else :
268
+ self .auth_class = AuthByIdToken (
269
+ id_token = self ._rest .id_token ,
270
+ application = self .application ,
271
+ protocol = self ._protocol ,
272
+ host = self .host ,
273
+ port = self .port ,
274
+ timeout = self .login_timeout ,
275
+ backoff_generator = self ._backoff_generator ,
276
+ )
277
+
278
+ elif self ._authenticator == KEY_PAIR_AUTHENTICATOR :
279
+ private_key = self ._private_key
280
+
281
+ if self ._private_key_file :
282
+ private_key = _get_private_bytes_from_file (
283
+ self ._private_key_file ,
284
+ self ._private_key_file_pwd ,
285
+ )
286
+
287
+ self .auth_class = AuthByKeyPair (
288
+ private_key = private_key ,
289
+ timeout = self .login_timeout ,
290
+ backoff_generator = self ._backoff_generator ,
291
+ )
292
+ elif self ._authenticator == OAUTH_AUTHENTICATOR :
293
+ self .auth_class = AuthByOAuth (
294
+ oauth_token = self ._token ,
295
+ timeout = self .login_timeout ,
296
+ backoff_generator = self ._backoff_generator ,
297
+ )
298
+ elif self ._authenticator == USR_PWD_MFA_AUTHENTICATOR :
299
+ self ._session_parameters [PARAMETER_CLIENT_REQUEST_MFA_TOKEN ] = (
300
+ self ._client_request_mfa_token if IS_LINUX else True
301
+ )
302
+ if self ._session_parameters [PARAMETER_CLIENT_REQUEST_MFA_TOKEN ]:
303
+ auth .read_temporary_credentials (
304
+ self .host ,
305
+ self .user ,
306
+ self ._session_parameters ,
307
+ )
308
+ self .auth_class = AuthByUsrPwdMfa (
309
+ password = self ._password ,
310
+ mfa_token = self .rest .mfa_token ,
311
+ timeout = self .login_timeout ,
312
+ backoff_generator = self ._backoff_generator ,
313
+ )
223
314
else :
224
- raise NotImplementedError (
225
- f"asyncio support for authenticator is not supported { self ._authenticator } "
315
+ # okta URL, e.g., https://<account>.okta.com/
316
+ self .auth_class = AuthByOkta (
317
+ application = self .application ,
318
+ timeout = self .login_timeout ,
319
+ backoff_generator = self ._backoff_generator ,
226
320
)
227
- # TODO: asyncio support for other authenticators
321
+
228
322
await self .authenticate_with_retry (self .auth_class )
229
323
230
324
self ._password = None # ensure password won't persist
0 commit comments