18
18
from abc import ABC , abstractmethod
19
19
from enum import Enum , unique
20
20
from os import getenv
21
- from typing import TYPE_CHECKING , Any
21
+ from typing import TYPE_CHECKING , Any , Iterator
22
22
23
23
from ..errorcode import ER_FAILED_TO_CONNECT_TO_DB
24
24
from ..errors import DatabaseError , Error , OperationalError
25
25
from ..sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
26
- from ..time_util import DecorrelateJitterBackoff
26
+ from ..time_util import TimeoutBackoffCtx
27
27
28
28
if TYPE_CHECKING :
29
29
from .. import SnowflakeConnection
30
30
31
31
logger = logging .getLogger (__name__ )
32
32
33
+ """
34
+ Default value for max retry is 1 because
35
+ Python requests module already tries twice
36
+ by default. Unlike JWT where we need to refresh
37
+ token every 10 seconds, general authenticators
38
+ wait for 60 seconds before connection timeout
39
+ per attempt totaling a 240 sec wait time for a non
40
+ JWT based authenticator which is more than enough.
41
+ This can be changed ofcourse using MAX_CNXN_RETRY_ATTEMPTS
42
+ env variable.
43
+ """
33
44
DEFAULT_MAX_CON_RETRY_ATTEMPTS = 1
34
-
35
-
36
- class AuthRetryCtx :
37
- def __init__ (self ) -> None :
38
- self ._current_retry_count = 0
39
- self ._max_retry_attempts = int (
40
- getenv ("MAX_CON_RETRY_ATTEMPTS" , DEFAULT_MAX_CON_RETRY_ATTEMPTS )
41
- )
42
- self ._backoff = DecorrelateJitterBackoff (1 , 16 )
43
- self ._current_sleep_time = 1
44
-
45
- def get_current_retry_count (self ) -> int :
46
- return self ._current_retry_count
47
-
48
- def increment_retry (self ) -> None :
49
- self ._current_retry_count += 1
50
-
51
- def should_retry (self ) -> bool :
52
- """Decides whether to retry connection.
53
-
54
- Default value for max retry is 1 because
55
- Python requests module already tries twice
56
- by default. Unlike JWT where we need to refresh
57
- token every 10 seconds, general authenticators
58
- wait for 60 seconds before connection timeout
59
- per attempt totaling a 240 sec wait time for a non
60
- JWT based authenticator which is more than enough.
61
- This can be changed ofcourse using MAX_CNXN_RETRY_ATTEMPTS
62
- env variable.
63
- """
64
- return self ._current_retry_count < self ._max_retry_attempts
65
-
66
- def next_sleep_duration (self ) -> int :
67
- self ._current_sleep_time = self ._backoff .next_sleep (
68
- self ._current_retry_count , self ._current_sleep_time
69
- )
70
- logger .debug (f"Sleeping for { self ._current_sleep_time } seconds" )
71
- return self ._current_sleep_time
72
-
73
- def reset (self ) -> None :
74
- self ._current_retry_count = 0
75
- self ._current_sleep_time = 1
45
+ DEFAULT_AUTH_CLASS_TIMEOUT = 120
76
46
77
47
78
48
@unique
@@ -89,18 +59,38 @@ class AuthType(Enum):
89
59
class AuthByPlugin (ABC ):
90
60
"""External Authenticator interface."""
91
61
92
- def __init__ (self ) -> None :
93
- self ._retry_ctx = AuthRetryCtx ()
62
+ def __init__ (
63
+ self ,
64
+ timeout : int | None = None ,
65
+ backoff_generator : Iterator | None = None ,
66
+ ** kwargs ,
67
+ ) -> None :
94
68
self .consent_cache_id_token = False
95
- self ._timeout : int = 120
69
+
70
+ self ._retry_ctx = TimeoutBackoffCtx (
71
+ timeout = timeout if timeout is not None else DEFAULT_AUTH_CLASS_TIMEOUT ,
72
+ max_retry_attempts = kwargs .get (
73
+ "max_retry_attempts" ,
74
+ int (getenv ("MAX_CON_RETRY_ATTEMPTS" , DEFAULT_MAX_CON_RETRY_ATTEMPTS )),
75
+ ),
76
+ backoff_generator = backoff_generator ,
77
+ )
78
+
79
+ # some authenticators may want to override socket level timeout
80
+ # for example, AuthByKeyPair will set this to ensure JWT tokens are refreshed in time
81
+ # if not None, this will override socket_timeout specified in connection
82
+ self ._socket_timeout = None
96
83
97
84
@property
98
85
def timeout (self ) -> int :
99
- return self ._timeout
86
+ """The timeout of _retry_ctx is guaranteed not to be None during AuthByPlugin initialization"""
87
+ return self ._retry_ctx .timeout
100
88
101
89
@timeout .setter
102
- def timeout (self , value : Any ) -> None :
103
- self ._timeout = int (value )
90
+ def timeout (self ) -> None :
91
+ logger .warning (
92
+ "Attempting to mutate timeout of AuthByPlugin. Create a new instance with desired parameters instead."
93
+ )
104
94
105
95
@property
106
96
@abstractmethod
@@ -207,20 +197,23 @@ def handle_timeout(
207
197
time ranges between 1 and 16 seconds.
208
198
"""
209
199
210
- del authenticator , service_name , account , user , password
200
+ # Some authenticators may not want to delete the parameters to this function
201
+ # Currently, the only authenticator where this is the case is AuthByKeyPair
202
+ if kwargs .pop ("delete_params" , True ):
203
+ del authenticator , service_name , account , user , password
204
+
211
205
logger .debug ("Default timeout handler invoked for authenticator" )
212
- if not self ._retry_ctx .should_retry () :
206
+ if not self ._retry_ctx .should_retry :
213
207
error = OperationalError (
214
- msg = f"Could not connect to Snowflake backend after { self ._retry_ctx .get_current_retry_count () } attempt(s)."
208
+ msg = f"Could not connect to Snowflake backend after { self ._retry_ctx .current_retry_count + 1 } attempt(s)."
215
209
"Aborting" ,
216
210
errno = ER_FAILED_TO_CONNECT_TO_DB ,
217
211
)
218
- self ._retry_ctx .reset ()
219
212
raise error
220
213
else :
221
214
logger .debug (
222
- f"Hit connection timeout, attempt number { self ._retry_ctx .get_current_retry_count () } ."
215
+ f"Hit connection timeout, attempt number { self ._retry_ctx .current_retry_count + 1 } ."
223
216
" Will retry in a bit..."
224
217
)
225
- self ._retry_ctx .increment_retry ( )
226
- time . sleep ( self ._retry_ctx .next_sleep_duration () )
218
+ time . sleep ( float ( self ._retry_ctx .current_sleep_time ) )
219
+ self ._retry_ctx .increment ( )
0 commit comments