1- #!/usr/bin/env python
21#
32# Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved.
43#
1615from os import getenv , makedirs , mkdir , path , remove , removedirs , rmdir
1716from os .path import expanduser
1817from threading import Lock , Thread
18+ from typing import TYPE_CHECKING , Any , Callable
1919
2020from cryptography .hazmat .backends import default_backend
2121from cryptography .hazmat .primitives .serialization import (
2626 load_pem_private_key ,
2727)
2828
29- from .auth_keypair import AuthByKeyPair
30- from .auth_usrpwdmfa import AuthByUsrPwdMfa
31- from .compat import IS_LINUX , IS_MACOS , IS_WINDOWS , urlencode
32- from .constants import (
29+ from ..compat import IS_LINUX , IS_MACOS , IS_WINDOWS , urlencode
30+ from ..constants import (
31+ DAY_IN_SECONDS ,
3332 HTTP_HEADER_ACCEPT ,
3433 HTTP_HEADER_CONTENT_TYPE ,
3534 HTTP_HEADER_SERVICE_NAME ,
3635 HTTP_HEADER_USER_AGENT ,
3736 PARAMETER_CLIENT_REQUEST_MFA_TOKEN ,
3837 PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL ,
3938)
40- from .description import (
39+ from .. description import (
4140 COMPILER ,
4241 IMPLEMENTATION ,
4342 OPERATING_SYSTEM ,
4443 PLATFORM ,
4544 PYTHON_VERSION ,
4645)
47- from .errorcode import ER_FAILED_TO_CONNECT_TO_DB
48- from .errors import (
46+ from .. errorcode import ER_FAILED_TO_CONNECT_TO_DB
47+ from .. errors import (
4948 BadGatewayError ,
5049 DatabaseError ,
5150 Error ,
5251 ForbiddenError ,
5352 ProgrammingError ,
5453 ServiceUnavailableError ,
5554)
56- from .network import (
55+ from .. network import (
5756 ACCEPT_TYPE_APPLICATION_SNOWFLAKE ,
5857 CONTENT_TYPE_APPLICATION_JSON ,
5958 ID_TOKEN_INVALID_LOGIN_REQUEST_GS_CODE ,
60- KEY_PAIR_AUTHENTICATOR ,
6159 PYTHON_CONNECTOR_USER_AGENT ,
6260 ReauthenticationRequest ,
6361)
64- from .options import installed_keyring , keyring
65- from .sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
66- from .version import VERSION
62+ from ..options import installed_keyring , keyring
63+ from ..sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
64+ from ..version import VERSION
65+
66+ if TYPE_CHECKING :
67+ from . import AuthByPlugin
6768
6869logger = logging .getLogger (__name__ )
6970
@@ -153,19 +154,19 @@ def base_auth_data(
153154
154155 def authenticate (
155156 self ,
156- auth_instance ,
157- account ,
158- user ,
159- database = None ,
160- schema = None ,
161- warehouse = None ,
162- role = None ,
163- passcode = None ,
164- passcode_in_password = False ,
165- mfa_callback = None ,
166- password_callback = None ,
167- session_parameters = None ,
168- timeout = 120 ,
157+ auth_instance : AuthByPlugin ,
158+ account : str ,
159+ user : str ,
160+ database : str | None = None ,
161+ schema : str | None = None ,
162+ warehouse : str | None = None ,
163+ role : str | None = None ,
164+ passcode : str | None = None ,
165+ passcode_in_password : bool = False ,
166+ mfa_callback : Callable [[], None ] | None = None ,
167+ password_callback : Callable [[], str ] | None = None ,
168+ session_parameters : dict [ Any , Any ] | None = None ,
169+ timeout : int = 120 ,
169170 ) -> dict [str , str | int | bool ]:
170171 logger .debug ("authenticate" )
171172
@@ -242,15 +243,7 @@ def authenticate(
242243 # login_timeout comes from user configuration.
243244 # Between login timeout and auth specific
244245 # timeout use whichever value is smaller
245- if hasattr (auth_instance , "get_timeout" ):
246- logger .debug (
247- f"Authenticator, { type (auth_instance ).__name__ } , implements get_timeout"
248- )
249- auth_timeout = min (
250- self ._rest ._connection .login_timeout , auth_instance .get_timeout ()
251- )
252- else :
253- auth_timeout = self ._rest ._connection .login_timeout
246+ auth_timeout = min (self ._rest ._connection .login_timeout , auth_instance .timeout )
254247 logger .debug (f"Timeout set to { auth_timeout } " )
255248
256249 try :
@@ -386,15 +379,19 @@ def post_request_wrapper(self, url, headers, body):
386379 )
387380 )
388381
389- if type (auth_instance ) is AuthByKeyPair :
382+ from . import AuthByKeyPair
383+
384+ if isinstance (auth_instance , AuthByKeyPair ):
390385 logger .debug (
391386 "JWT Token authentication failed. "
392387 "Token expires at: %s. "
393388 "Current Time: %s" ,
394389 str (auth_instance ._jwt_token_exp ),
395390 str (datetime .utcnow ()),
396391 )
397- if type (auth_instance ) is AuthByUsrPwdMfa :
392+ from . import AuthByUsrPwdMfa
393+
394+ if isinstance (auth_instance , AuthByUsrPwdMfa ):
398395 delete_temporary_credential (self ._rest ._host , user , MFA_TOKEN )
399396 Error .errorhandler_wrapper (
400397 self ._rest ._connection ,
@@ -483,16 +480,33 @@ def _read_temporary_credential(self, host, user, cred_type):
483480 logger .debug ("OS not supported for Local Secure Storage" )
484481 return cred
485482
486- def read_temporary_credentials (self , host , user , session_parameters ):
483+ def read_temporary_credentials (
484+ self ,
485+ host : str ,
486+ user : str ,
487+ session_parameters : dict [str , Any ],
488+ ) -> None :
487489 if session_parameters .get (PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL , False ):
488- self ._rest .id_token = self ._read_temporary_credential (host , user , ID_TOKEN )
490+ self ._rest .id_token = self ._read_temporary_credential (
491+ host ,
492+ user ,
493+ ID_TOKEN ,
494+ )
489495
490496 if session_parameters .get (PARAMETER_CLIENT_REQUEST_MFA_TOKEN , False ):
491497 self ._rest .mfa_token = self ._read_temporary_credential (
492- host , user , MFA_TOKEN
498+ host ,
499+ user ,
500+ MFA_TOKEN ,
493501 )
494502
495- def _write_temporary_credential (self , host , user , cred_type , cred ):
503+ def _write_temporary_credential (
504+ self ,
505+ host : str ,
506+ user : str ,
507+ cred_type : str ,
508+ cred : str | None ,
509+ ) -> None :
496510 if not cred :
497511 logger .debug (
498512 "no credential is given when try to store temporary credential"
@@ -522,9 +536,18 @@ def _write_temporary_credential(self, host, user, cred_type, cred):
522536 else :
523537 logger .debug ("OS not supported for Local Secure Storage" )
524538
525- def write_temporary_credentials (self , host , user , session_parameters , response ):
526- if self ._rest ._connection .consent_cache_id_token and session_parameters .get (
527- PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL , False
539+ def write_temporary_credentials (
540+ self ,
541+ host : str ,
542+ user : str ,
543+ session_parameters : dict [str , Any ],
544+ response : dict [str , Any ],
545+ ) -> None :
546+ if (
547+ self ._rest ._connection .auth_class .consent_cache_id_token
548+ and session_parameters .get (
549+ PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL , False
550+ )
528551 ):
529552 self ._write_temporary_credential (
530553 host , user , ID_TOKEN , response ["data" ].get ("idToken" )
@@ -534,10 +557,9 @@ def write_temporary_credentials(self, host, user, session_parameters, response):
534557 self ._write_temporary_credential (
535558 host , user , MFA_TOKEN , response ["data" ].get ("mfaToken" )
536559 )
537- return
538560
539561
540- def flush_temporary_credentials ():
562+ def flush_temporary_credentials () -> None :
541563 """Flush temporary credentials in memory into disk. Need to hold TEMPORARY_CREDENTIAL_LOCK."""
542564 global TEMPORARY_CREDENTIAL
543565 global TEMPORARY_CREDENTIAL_FILE
@@ -566,7 +588,7 @@ def flush_temporary_credentials():
566588 unlock_temporary_credential_file ()
567589
568590
569- def write_temporary_credential_file (host , cred_name , cred ):
591+ def write_temporary_credential_file (host , cred_name , cred ) -> None :
570592 """Writes temporary credential file when OS is Linux."""
571593 if not CACHE_DIR :
572594 # no cache is enabled
@@ -581,7 +603,7 @@ def write_temporary_credential_file(host, cred_name, cred):
581603 flush_temporary_credentials ()
582604
583605
584- def read_temporary_credential_file ():
606+ def read_temporary_credential_file () -> None :
585607 """Reads temporary credential file when OS is Linux."""
586608 if not CACHE_DIR :
587609 # no cache is enabled
@@ -616,10 +638,9 @@ def read_temporary_credential_file():
616638 )
617639 finally :
618640 unlock_temporary_credential_file ()
619- return None
620641
621642
622- def lock_temporary_credential_file ():
643+ def lock_temporary_credential_file () -> bool :
623644 global TEMPORARY_CREDENTIAL_FILE_LOCK
624645 try :
625646 mkdir (TEMPORARY_CREDENTIAL_FILE_LOCK )
@@ -632,7 +653,7 @@ def lock_temporary_credential_file():
632653 return False
633654
634655
635- def unlock_temporary_credential_file ():
656+ def unlock_temporary_credential_file () -> bool :
636657 global TEMPORARY_CREDENTIAL_FILE_LOCK
637658 try :
638659 rmdir (TEMPORARY_CREDENTIAL_FILE_LOCK )
@@ -709,10 +730,13 @@ def get_token_from_private_key(
709730 format = PrivateFormat .PKCS8 ,
710731 encryption_algorithm = NoEncryption (),
711732 )
712- auth_instance = AuthByKeyPair (private_key , 1440 * 60 ) # token valid for 24 hours
713- return auth_instance .authenticate (
714- KEY_PAIR_AUTHENTICATOR , None , account , user , key_password
715- )
733+ from . import AuthByKeyPair
734+
735+ auth_instance = AuthByKeyPair (
736+ private_key ,
737+ DAY_IN_SECONDS ,
738+ ) # token valid for 24 hours
739+ return auth_instance .prepare (account = account , user = user )
716740
717741
718742def get_public_key_fingerprint (private_key_file : str , password : str ) -> str :
@@ -729,4 +753,6 @@ def get_public_key_fingerprint(private_key_file: str, password: str) -> str:
729753 private_key = load_der_private_key (
730754 data = private_key , password = None , backend = default_backend ()
731755 )
756+ from . import AuthByKeyPair
757+
732758 return AuthByKeyPair .calculate_public_key_fingerprint (private_key )
0 commit comments