1818from abc import ABC , abstractmethod
1919from enum import Enum , unique
2020from os import getenv
21- from typing import TYPE_CHECKING , Any
21+ from typing import TYPE_CHECKING , Any , Iterator
2222
2323from ..errorcode import ER_FAILED_TO_CONNECT_TO_DB
2424from ..errors import DatabaseError , Error , OperationalError
2525from ..sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
26- from ..time_util import DecorrelateJitterBackoff
26+ from ..time_util import TimeoutBackoffCtx
2727
2828if TYPE_CHECKING :
2929 from .. import SnowflakeConnection
3030
3131logger = logging .getLogger (__name__ )
3232
33+ """
34+ Default value for max retry is 1 because
35+ Python requests module already tries twice
36+ by default. Unlike JWT where we need to refresh
37+ token every 10 seconds, general authenticators
38+ wait for 60 seconds before connection timeout
39+ per attempt totaling a 240 sec wait time for a non
40+ JWT based authenticator which is more than enough.
41+ This can be changed ofcourse using MAX_CNXN_RETRY_ATTEMPTS
42+ env variable.
43+ """
3344DEFAULT_MAX_CON_RETRY_ATTEMPTS = 1
34-
35-
36- class AuthRetryCtx :
37- def __init__ (self ) -> None :
38- self ._current_retry_count = 0
39- self ._max_retry_attempts = int (
40- getenv ("MAX_CON_RETRY_ATTEMPTS" , DEFAULT_MAX_CON_RETRY_ATTEMPTS )
41- )
42- self ._backoff = DecorrelateJitterBackoff (1 , 16 )
43- self ._current_sleep_time = 1
44-
45- def get_current_retry_count (self ) -> int :
46- return self ._current_retry_count
47-
48- def increment_retry (self ) -> None :
49- self ._current_retry_count += 1
50-
51- def should_retry (self ) -> bool :
52- """Decides whether to retry connection.
53-
54- Default value for max retry is 1 because
55- Python requests module already tries twice
56- by default. Unlike JWT where we need to refresh
57- token every 10 seconds, general authenticators
58- wait for 60 seconds before connection timeout
59- per attempt totaling a 240 sec wait time for a non
60- JWT based authenticator which is more than enough.
61- This can be changed ofcourse using MAX_CNXN_RETRY_ATTEMPTS
62- env variable.
63- """
64- return self ._current_retry_count < self ._max_retry_attempts
65-
66- def next_sleep_duration (self ) -> int :
67- self ._current_sleep_time = self ._backoff .next_sleep (
68- self ._current_retry_count , self ._current_sleep_time
69- )
70- logger .debug (f"Sleeping for { self ._current_sleep_time } seconds" )
71- return self ._current_sleep_time
72-
73- def reset (self ) -> None :
74- self ._current_retry_count = 0
75- self ._current_sleep_time = 1
45+ DEFAULT_AUTH_CLASS_TIMEOUT = 120
7646
7747
7848@unique
@@ -89,18 +59,38 @@ class AuthType(Enum):
8959class AuthByPlugin (ABC ):
9060 """External Authenticator interface."""
9161
92- def __init__ (self ) -> None :
93- self ._retry_ctx = AuthRetryCtx ()
62+ def __init__ (
63+ self ,
64+ timeout : int | None = None ,
65+ backoff_generator : Iterator | None = None ,
66+ ** kwargs ,
67+ ) -> None :
9468 self .consent_cache_id_token = False
95- self ._timeout : int = 120
69+
70+ self ._retry_ctx = TimeoutBackoffCtx (
71+ timeout = timeout if timeout is not None else DEFAULT_AUTH_CLASS_TIMEOUT ,
72+ max_retry_attempts = kwargs .get (
73+ "max_retry_attempts" ,
74+ int (getenv ("MAX_CON_RETRY_ATTEMPTS" , DEFAULT_MAX_CON_RETRY_ATTEMPTS )),
75+ ),
76+ backoff_generator = backoff_generator ,
77+ )
78+
79+ # some authenticators may want to override socket level timeout
80+ # for example, AuthByKeyPair will set this to ensure JWT tokens are refreshed in time
81+ # if not None, this will override socket_timeout specified in connection
82+ self ._socket_timeout = None
9683
9784 @property
9885 def timeout (self ) -> int :
99- return self ._timeout
86+ """The timeout of _retry_ctx is guaranteed not to be None during AuthByPlugin initialization"""
87+ return self ._retry_ctx .timeout
10088
10189 @timeout .setter
102- def timeout (self , value : Any ) -> None :
103- self ._timeout = int (value )
90+ def timeout (self ) -> None :
91+ logger .warning (
92+ "Attempting to mutate timeout of AuthByPlugin. Create a new instance with desired parameters instead."
93+ )
10494
10595 @property
10696 @abstractmethod
@@ -207,20 +197,23 @@ def handle_timeout(
207197 time ranges between 1 and 16 seconds.
208198 """
209199
210- del authenticator , service_name , account , user , password
200+ # Some authenticators may not want to delete the parameters to this function
201+ # Currently, the only authenticator where this is the case is AuthByKeyPair
202+ if kwargs .pop ("delete_params" , True ):
203+ del authenticator , service_name , account , user , password
204+
211205 logger .debug ("Default timeout handler invoked for authenticator" )
212- if not self ._retry_ctx .should_retry () :
206+ if not self ._retry_ctx .should_retry :
213207 error = OperationalError (
214- msg = f"Could not connect to Snowflake backend after { self ._retry_ctx .get_current_retry_count () } attempt(s)."
208+ msg = f"Could not connect to Snowflake backend after { self ._retry_ctx .current_retry_count + 1 } attempt(s)."
215209 "Aborting" ,
216210 errno = ER_FAILED_TO_CONNECT_TO_DB ,
217211 )
218- self ._retry_ctx .reset ()
219212 raise error
220213 else :
221214 logger .debug (
222- f"Hit connection timeout, attempt number { self ._retry_ctx .get_current_retry_count () } ."
215+ f"Hit connection timeout, attempt number { self ._retry_ctx .current_retry_count + 1 } ."
223216 " Will retry in a bit..."
224217 )
225- self ._retry_ctx .increment_retry ( )
226- time . sleep ( self ._retry_ctx .next_sleep_duration () )
218+ time . sleep ( float ( self ._retry_ctx .current_sleep_time ) )
219+ self ._retry_ctx .increment ( )
0 commit comments