1+ #!/usr/bin/env python
12#
23# Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved.
34#
1516from os import getenv , makedirs , mkdir , path , remove , removedirs , rmdir
1617from os .path import expanduser
1718from threading import Lock , Thread
18- from typing import TYPE_CHECKING , Any , Callable
1919
2020from cryptography .hazmat .backends import default_backend
2121from cryptography .hazmat .primitives .serialization import (
2626 load_pem_private_key ,
2727)
2828
29- from ..compat import IS_LINUX , IS_MACOS , IS_WINDOWS , urlencode
30- from ..constants import (
31- DAY_IN_SECONDS ,
29+ from .auth_keypair import AuthByKeyPair
30+ from .auth_usrpwdmfa import AuthByUsrPwdMfa
31+ from .compat import IS_LINUX , IS_MACOS , IS_WINDOWS , urlencode
32+ from .constants import (
3233 HTTP_HEADER_ACCEPT ,
3334 HTTP_HEADER_CONTENT_TYPE ,
3435 HTTP_HEADER_SERVICE_NAME ,
3536 HTTP_HEADER_USER_AGENT ,
3637 PARAMETER_CLIENT_REQUEST_MFA_TOKEN ,
3738 PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL ,
3839)
39- from .. description import (
40+ from .description import (
4041 COMPILER ,
4142 IMPLEMENTATION ,
4243 OPERATING_SYSTEM ,
4344 PLATFORM ,
4445 PYTHON_VERSION ,
4546)
46- from .. errorcode import ER_FAILED_TO_CONNECT_TO_DB
47- from .. errors import (
47+ from .errorcode import ER_FAILED_TO_CONNECT_TO_DB
48+ from .errors import (
4849 BadGatewayError ,
4950 DatabaseError ,
5051 Error ,
5152 ForbiddenError ,
5253 ProgrammingError ,
5354 ServiceUnavailableError ,
5455)
55- from .. network import (
56+ from .network import (
5657 ACCEPT_TYPE_APPLICATION_SNOWFLAKE ,
5758 CONTENT_TYPE_APPLICATION_JSON ,
5859 ID_TOKEN_INVALID_LOGIN_REQUEST_GS_CODE ,
60+ KEY_PAIR_AUTHENTICATOR ,
5961 PYTHON_CONNECTOR_USER_AGENT ,
6062 ReauthenticationRequest ,
6163)
62- from ..options import installed_keyring , keyring
63- from ..sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
64- from ..version import VERSION
65-
66- if TYPE_CHECKING :
67- from . import AuthByPlugin
64+ from .options import installed_keyring , keyring
65+ from .sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
66+ from .version import VERSION
6867
6968logger = logging .getLogger (__name__ )
7069
@@ -154,19 +153,19 @@ def base_auth_data(
154153
155154 def authenticate (
156155 self ,
157- auth_instance : AuthByPlugin ,
158- account : str ,
159- user : str ,
160- database : str | None = None ,
161- schema : str | None = None ,
162- warehouse : str | None = None ,
163- role : str | None = None ,
164- passcode : str | None = None ,
165- passcode_in_password : bool = False ,
166- mfa_callback : Callable [[], None ] | None = None ,
167- password_callback : Callable [[], str ] | None = None ,
168- session_parameters : dict [ Any , Any ] | None = None ,
169- timeout : int = 120 ,
156+ auth_instance ,
157+ account ,
158+ user ,
159+ database = None ,
160+ schema = None ,
161+ warehouse = None ,
162+ role = None ,
163+ passcode = None ,
164+ passcode_in_password = False ,
165+ mfa_callback = None ,
166+ password_callback = None ,
167+ session_parameters = None ,
168+ timeout = 120 ,
170169 ) -> dict [str , str | int | bool ]:
171170 logger .debug ("authenticate" )
172171
@@ -243,7 +242,15 @@ def authenticate(
243242 # login_timeout comes from user configuration.
244243 # Between login timeout and auth specific
245244 # timeout use whichever value is smaller
246- auth_timeout = min (self ._rest ._connection .login_timeout , auth_instance .timeout )
245+ if hasattr (auth_instance , "get_timeout" ):
246+ logger .debug (
247+ f"Authenticator, { type (auth_instance ).__name__ } , implements get_timeout"
248+ )
249+ auth_timeout = min (
250+ self ._rest ._connection .login_timeout , auth_instance .get_timeout ()
251+ )
252+ else :
253+ auth_timeout = self ._rest ._connection .login_timeout
247254 logger .debug (f"Timeout set to { auth_timeout } " )
248255
249256 try :
@@ -379,19 +386,15 @@ def post_request_wrapper(self, url, headers, body):
379386 )
380387 )
381388
382- from . import AuthByKeyPair
383-
384- if isinstance (auth_instance , AuthByKeyPair ):
389+ if type (auth_instance ) is AuthByKeyPair :
385390 logger .debug (
386391 "JWT Token authentication failed. "
387392 "Token expires at: %s. "
388393 "Current Time: %s" ,
389394 str (auth_instance ._jwt_token_exp ),
390395 str (datetime .utcnow ()),
391396 )
392- from . import AuthByUsrPwdMfa
393-
394- if isinstance (auth_instance , AuthByUsrPwdMfa ):
397+ if type (auth_instance ) is AuthByUsrPwdMfa :
395398 delete_temporary_credential (self ._rest ._host , user , MFA_TOKEN )
396399 Error .errorhandler_wrapper (
397400 self ._rest ._connection ,
@@ -480,33 +483,16 @@ def _read_temporary_credential(self, host, user, cred_type):
480483 logger .debug ("OS not supported for Local Secure Storage" )
481484 return cred
482485
483- def read_temporary_credentials (
484- self ,
485- host : str ,
486- user : str ,
487- session_parameters : dict [str , Any ],
488- ) -> None :
486+ def read_temporary_credentials (self , host , user , session_parameters ):
489487 if session_parameters .get (PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL , False ):
490- self ._rest .id_token = self ._read_temporary_credential (
491- host ,
492- user ,
493- ID_TOKEN ,
494- )
488+ self ._rest .id_token = self ._read_temporary_credential (host , user , ID_TOKEN )
495489
496490 if session_parameters .get (PARAMETER_CLIENT_REQUEST_MFA_TOKEN , False ):
497491 self ._rest .mfa_token = self ._read_temporary_credential (
498- host ,
499- user ,
500- MFA_TOKEN ,
492+ host , user , MFA_TOKEN
501493 )
502494
503- def _write_temporary_credential (
504- self ,
505- host : str ,
506- user : str ,
507- cred_type : str ,
508- cred : str | None ,
509- ) -> None :
495+ def _write_temporary_credential (self , host , user , cred_type , cred ):
510496 if not cred :
511497 logger .debug (
512498 "no credential is given when try to store temporary credential"
@@ -536,18 +522,9 @@ def _write_temporary_credential(
536522 else :
537523 logger .debug ("OS not supported for Local Secure Storage" )
538524
539- def write_temporary_credentials (
540- self ,
541- host : str ,
542- user : str ,
543- session_parameters : dict [str , Any ],
544- response : dict [str , Any ],
545- ) -> None :
546- if (
547- self ._rest ._connection .auth_class .consent_cache_id_token
548- and session_parameters .get (
549- PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL , False
550- )
525+ def write_temporary_credentials (self , host , user , session_parameters , response ):
526+ if self ._rest ._connection .consent_cache_id_token and session_parameters .get (
527+ PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL , False
551528 ):
552529 self ._write_temporary_credential (
553530 host , user , ID_TOKEN , response ["data" ].get ("idToken" )
@@ -557,9 +534,10 @@ def write_temporary_credentials(
557534 self ._write_temporary_credential (
558535 host , user , MFA_TOKEN , response ["data" ].get ("mfaToken" )
559536 )
537+ return
560538
561539
562- def flush_temporary_credentials () -> None :
540+ def flush_temporary_credentials ():
563541 """Flush temporary credentials in memory into disk. Need to hold TEMPORARY_CREDENTIAL_LOCK."""
564542 global TEMPORARY_CREDENTIAL
565543 global TEMPORARY_CREDENTIAL_FILE
@@ -588,7 +566,7 @@ def flush_temporary_credentials() -> None:
588566 unlock_temporary_credential_file ()
589567
590568
591- def write_temporary_credential_file (host , cred_name , cred ) -> None :
569+ def write_temporary_credential_file (host , cred_name , cred ):
592570 """Writes temporary credential file when OS is Linux."""
593571 if not CACHE_DIR :
594572 # no cache is enabled
@@ -603,7 +581,7 @@ def write_temporary_credential_file(host, cred_name, cred) -> None:
603581 flush_temporary_credentials ()
604582
605583
606- def read_temporary_credential_file () -> None :
584+ def read_temporary_credential_file ():
607585 """Reads temporary credential file when OS is Linux."""
608586 if not CACHE_DIR :
609587 # no cache is enabled
@@ -638,9 +616,10 @@ def read_temporary_credential_file() -> None:
638616 )
639617 finally :
640618 unlock_temporary_credential_file ()
619+ return None
641620
642621
643- def lock_temporary_credential_file () -> bool :
622+ def lock_temporary_credential_file ():
644623 global TEMPORARY_CREDENTIAL_FILE_LOCK
645624 try :
646625 mkdir (TEMPORARY_CREDENTIAL_FILE_LOCK )
@@ -653,7 +632,7 @@ def lock_temporary_credential_file() -> bool:
653632 return False
654633
655634
656- def unlock_temporary_credential_file () -> bool :
635+ def unlock_temporary_credential_file ():
657636 global TEMPORARY_CREDENTIAL_FILE_LOCK
658637 try :
659638 rmdir (TEMPORARY_CREDENTIAL_FILE_LOCK )
@@ -730,13 +709,10 @@ def get_token_from_private_key(
730709 format = PrivateFormat .PKCS8 ,
731710 encryption_algorithm = NoEncryption (),
732711 )
733- from . import AuthByKeyPair
734-
735- auth_instance = AuthByKeyPair (
736- private_key ,
737- DAY_IN_SECONDS ,
738- ) # token valid for 24 hours
739- return auth_instance .prepare (account = account , user = user )
712+ auth_instance = AuthByKeyPair (private_key , 1440 * 60 ) # token valid for 24 hours
713+ return auth_instance .authenticate (
714+ KEY_PAIR_AUTHENTICATOR , None , account , user , key_password
715+ )
740716
741717
742718def get_public_key_fingerprint (private_key_file : str , password : str ) -> str :
@@ -753,6 +729,4 @@ def get_public_key_fingerprint(private_key_file: str, password: str) -> str:
753729 private_key = load_der_private_key (
754730 data = private_key , password = None , backend = default_backend ()
755731 )
756- from . import AuthByKeyPair
757-
758732 return AuthByKeyPair .calculate_public_key_fingerprint (private_key )
0 commit comments