1- #!/usr/bin/env python 
21# 
32# Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved. 
43# 
1615from  os  import  getenv , makedirs , mkdir , path , remove , removedirs , rmdir 
1716from  os .path  import  expanduser 
1817from  threading  import  Lock , Thread 
18+ from  typing  import  TYPE_CHECKING , Any , Callable 
1919
2020from  cryptography .hazmat .backends  import  default_backend 
2121from  cryptography .hazmat .primitives .serialization  import  (
2626    load_pem_private_key ,
2727)
2828
29- from  .auth_keypair  import  AuthByKeyPair 
30- from  .auth_usrpwdmfa  import  AuthByUsrPwdMfa 
31- from  .compat  import  IS_LINUX , IS_MACOS , IS_WINDOWS , urlencode 
32- from  .constants  import  (
29+ from  ..compat  import  IS_LINUX , IS_MACOS , IS_WINDOWS , urlencode 
30+ from  ..constants  import  (
31+     DAY_IN_SECONDS ,
3332    HTTP_HEADER_ACCEPT ,
3433    HTTP_HEADER_CONTENT_TYPE ,
3534    HTTP_HEADER_SERVICE_NAME ,
3635    HTTP_HEADER_USER_AGENT ,
3736    PARAMETER_CLIENT_REQUEST_MFA_TOKEN ,
3837    PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL ,
3938)
40- from  .description  import  (
39+ from  .. description  import  (
4140    COMPILER ,
4241    IMPLEMENTATION ,
4342    OPERATING_SYSTEM ,
4443    PLATFORM ,
4544    PYTHON_VERSION ,
4645)
47- from  .errorcode  import  ER_FAILED_TO_CONNECT_TO_DB 
48- from  .errors  import  (
46+ from  .. errorcode  import  ER_FAILED_TO_CONNECT_TO_DB 
47+ from  .. errors  import  (
4948    BadGatewayError ,
5049    DatabaseError ,
5150    Error ,
5251    ForbiddenError ,
5352    ProgrammingError ,
5453    ServiceUnavailableError ,
5554)
56- from  .network  import  (
55+ from  .. network  import  (
5756    ACCEPT_TYPE_APPLICATION_SNOWFLAKE ,
5857    CONTENT_TYPE_APPLICATION_JSON ,
5958    ID_TOKEN_INVALID_LOGIN_REQUEST_GS_CODE ,
60-     KEY_PAIR_AUTHENTICATOR ,
6159    PYTHON_CONNECTOR_USER_AGENT ,
6260    ReauthenticationRequest ,
6361)
64- from  .options  import  installed_keyring , keyring 
65- from  .sqlstate  import  SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED 
66- from  .version  import  VERSION 
62+ from  ..options  import  installed_keyring , keyring 
63+ from  ..sqlstate  import  SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED 
64+ from  ..version  import  VERSION 
65+ 
66+ if  TYPE_CHECKING :
67+     from  . import  AuthByPlugin 
6768
6869logger  =  logging .getLogger (__name__ )
6970
@@ -153,19 +154,19 @@ def base_auth_data(
153154
154155    def  authenticate (
155156        self ,
156-         auth_instance ,
157-         account ,
158-         user ,
159-         database = None ,
160-         schema = None ,
161-         warehouse = None ,
162-         role = None ,
163-         passcode = None ,
164-         passcode_in_password = False ,
165-         mfa_callback = None ,
166-         password_callback = None ,
167-         session_parameters = None ,
168-         timeout = 120 ,
157+         auth_instance :  AuthByPlugin ,
158+         account :  str ,
159+         user :  str ,
160+         database :  str   |   None   =   None ,
161+         schema :  str   |   None   =   None ,
162+         warehouse :  str   |   None   =   None ,
163+         role :  str   |   None   =   None ,
164+         passcode :  str   |   None   =   None ,
165+         passcode_in_password :  bool   =   False ,
166+         mfa_callback :  Callable [[],  None ]  |   None   =   None ,
167+         password_callback :  Callable [[],  str ]  |   None   =   None ,
168+         session_parameters :  dict [ Any ,  Any ]  |   None   =   None ,
169+         timeout :  int   =   120 ,
169170    ) ->  dict [str , str  |  int  |  bool ]:
170171        logger .debug ("authenticate" )
171172
@@ -242,15 +243,7 @@ def authenticate(
242243        # login_timeout comes from user configuration. 
243244        # Between login timeout and auth specific 
244245        # timeout use whichever value is smaller 
245-         if  hasattr (auth_instance , "get_timeout" ):
246-             logger .debug (
247-                 f"Authenticator, { type (auth_instance ).__name__ }  , implements get_timeout" 
248-             )
249-             auth_timeout  =  min (
250-                 self ._rest ._connection .login_timeout , auth_instance .get_timeout ()
251-             )
252-         else :
253-             auth_timeout  =  self ._rest ._connection .login_timeout 
246+         auth_timeout  =  min (self ._rest ._connection .login_timeout , auth_instance .timeout )
254247        logger .debug (f"Timeout set to { auth_timeout }  " )
255248
256249        try :
@@ -386,15 +379,19 @@ def post_request_wrapper(self, url, headers, body):
386379                    )
387380                )
388381
389-             if  type (auth_instance ) is  AuthByKeyPair :
382+             from  . import  AuthByKeyPair 
383+ 
384+             if  isinstance (auth_instance , AuthByKeyPair ):
390385                logger .debug (
391386                    "JWT Token authentication failed. " 
392387                    "Token expires at: %s. " 
393388                    "Current Time: %s" ,
394389                    str (auth_instance ._jwt_token_exp ),
395390                    str (datetime .utcnow ()),
396391                )
397-             if  type (auth_instance ) is  AuthByUsrPwdMfa :
392+             from  . import  AuthByUsrPwdMfa 
393+ 
394+             if  isinstance (auth_instance , AuthByUsrPwdMfa ):
398395                delete_temporary_credential (self ._rest ._host , user , MFA_TOKEN )
399396            Error .errorhandler_wrapper (
400397                self ._rest ._connection ,
@@ -483,16 +480,33 @@ def _read_temporary_credential(self, host, user, cred_type):
483480            logger .debug ("OS not supported for Local Secure Storage" )
484481        return  cred 
485482
486-     def  read_temporary_credentials (self , host , user , session_parameters ):
483+     def  read_temporary_credentials (
484+         self ,
485+         host : str ,
486+         user : str ,
487+         session_parameters : dict [str , Any ],
488+     ) ->  None :
487489        if  session_parameters .get (PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL , False ):
488-             self ._rest .id_token  =  self ._read_temporary_credential (host , user , ID_TOKEN )
490+             self ._rest .id_token  =  self ._read_temporary_credential (
491+                 host ,
492+                 user ,
493+                 ID_TOKEN ,
494+             )
489495
490496        if  session_parameters .get (PARAMETER_CLIENT_REQUEST_MFA_TOKEN , False ):
491497            self ._rest .mfa_token  =  self ._read_temporary_credential (
492-                 host , user , MFA_TOKEN 
498+                 host ,
499+                 user ,
500+                 MFA_TOKEN ,
493501            )
494502
495-     def  _write_temporary_credential (self , host , user , cred_type , cred ):
503+     def  _write_temporary_credential (
504+         self ,
505+         host : str ,
506+         user : str ,
507+         cred_type : str ,
508+         cred : str  |  None ,
509+     ) ->  None :
496510        if  not  cred :
497511            logger .debug (
498512                "no credential is given when try to store temporary credential" 
@@ -522,9 +536,18 @@ def _write_temporary_credential(self, host, user, cred_type, cred):
522536        else :
523537            logger .debug ("OS not supported for Local Secure Storage" )
524538
525-     def  write_temporary_credentials (self , host , user , session_parameters , response ):
526-         if  self ._rest ._connection .consent_cache_id_token  and  session_parameters .get (
527-             PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL , False 
539+     def  write_temporary_credentials (
540+         self ,
541+         host : str ,
542+         user : str ,
543+         session_parameters : dict [str , Any ],
544+         response : dict [str , Any ],
545+     ) ->  None :
546+         if  (
547+             self ._rest ._connection .auth_class .consent_cache_id_token 
548+             and  session_parameters .get (
549+                 PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL , False 
550+             )
528551        ):
529552            self ._write_temporary_credential (
530553                host , user , ID_TOKEN , response ["data" ].get ("idToken" )
@@ -534,10 +557,9 @@ def write_temporary_credentials(self, host, user, session_parameters, response):
534557            self ._write_temporary_credential (
535558                host , user , MFA_TOKEN , response ["data" ].get ("mfaToken" )
536559            )
537-         return 
538560
539561
540- def  flush_temporary_credentials ():
562+ def  flush_temporary_credentials ()  ->   None :
541563    """Flush temporary credentials in memory into disk. Need to hold TEMPORARY_CREDENTIAL_LOCK.""" 
542564    global  TEMPORARY_CREDENTIAL 
543565    global  TEMPORARY_CREDENTIAL_FILE 
@@ -566,7 +588,7 @@ def flush_temporary_credentials():
566588        unlock_temporary_credential_file ()
567589
568590
569- def  write_temporary_credential_file (host , cred_name , cred ):
591+ def  write_temporary_credential_file (host , cred_name , cred )  ->   None :
570592    """Writes temporary credential file when OS is Linux.""" 
571593    if  not  CACHE_DIR :
572594        # no cache is enabled 
@@ -581,7 +603,7 @@ def write_temporary_credential_file(host, cred_name, cred):
581603        flush_temporary_credentials ()
582604
583605
584- def  read_temporary_credential_file ():
606+ def  read_temporary_credential_file ()  ->   None :
585607    """Reads temporary credential file when OS is Linux.""" 
586608    if  not  CACHE_DIR :
587609        # no cache is enabled 
@@ -616,10 +638,9 @@ def read_temporary_credential_file():
616638            )
617639        finally :
618640            unlock_temporary_credential_file ()
619-     return  None 
620641
621642
622- def  lock_temporary_credential_file ():
643+ def  lock_temporary_credential_file ()  ->   bool :
623644    global  TEMPORARY_CREDENTIAL_FILE_LOCK 
624645    try :
625646        mkdir (TEMPORARY_CREDENTIAL_FILE_LOCK )
@@ -632,7 +653,7 @@ def lock_temporary_credential_file():
632653        return  False 
633654
634655
635- def  unlock_temporary_credential_file ():
656+ def  unlock_temporary_credential_file ()  ->   bool :
636657    global  TEMPORARY_CREDENTIAL_FILE_LOCK 
637658    try :
638659        rmdir (TEMPORARY_CREDENTIAL_FILE_LOCK )
@@ -709,10 +730,13 @@ def get_token_from_private_key(
709730        format = PrivateFormat .PKCS8 ,
710731        encryption_algorithm = NoEncryption (),
711732    )
712-     auth_instance  =  AuthByKeyPair (private_key , 1440  *  60 )  # token valid for 24 hours 
713-     return  auth_instance .authenticate (
714-         KEY_PAIR_AUTHENTICATOR , None , account , user , key_password 
715-     )
733+     from  . import  AuthByKeyPair 
734+ 
735+     auth_instance  =  AuthByKeyPair (
736+         private_key ,
737+         DAY_IN_SECONDS ,
738+     )  # token valid for 24 hours 
739+     return  auth_instance .prepare (account = account , user = user )
716740
717741
718742def  get_public_key_fingerprint (private_key_file : str , password : str ) ->  str :
@@ -729,4 +753,6 @@ def get_public_key_fingerprint(private_key_file: str, password: str) -> str:
729753    private_key  =  load_der_private_key (
730754        data = private_key , password = None , backend = default_backend ()
731755    )
756+     from  . import  AuthByKeyPair 
757+ 
732758    return  AuthByKeyPair .calculate_public_key_fingerprint (private_key )
0 commit comments