Skip to content

Commit 2851d3b

Browse files
authored
Snow 916662 design retry timeout config in python connector (#1759)
1 parent cf817a4 commit 2851d3b

29 files changed

+777
-272
lines changed

DESCRIPTION.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
1212

1313
- Added support for `use_logical_type` in `write_pandas`.
1414
- Removed dependencies on pycryptodomex and oscrypto. All connections now go through OpenSSL via the cryptography library, which was already a dependency.
15+
- Added the `backoff_policy` argument to `snowflake.connector.connect` allowing for configurable backoff policy between retries of failed requests. See available implementations in the `backoff_policies` module.
16+
- Added the `socket_timeout` argument to `snowflake.connector.connect` specifying socket read and connect timeout.
17+
- Fixed `login_timeout` and `network_timeout` behaviour. Retries of login and network requests are now properly halted after these timeouts expire.
1518

1619
- v3.3.1(October 16,2023)
1720

src/snowflake/connector/auth/_auth.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,9 @@ def base_auth_data(
127127
internal_application_name,
128128
internal_application_version,
129129
ocsp_mode,
130-
login_timeout,
131-
network_timeout=None,
130+
login_timeout: int | None = None,
131+
network_timeout: int | None = None,
132+
socket_timeout: int | None = None,
132133
):
133134
return {
134135
"data": {
@@ -148,6 +149,7 @@ def base_auth_data(
148149
"TRACING": logger.getEffectiveLevel(),
149150
"LOGIN_TIMEOUT": login_timeout,
150151
"NETWORK_TIMEOUT": network_timeout,
152+
"SOCKET_TIMEOUT": socket_timeout,
151153
},
152154
},
153155
}
@@ -166,10 +168,14 @@ def authenticate(
166168
mfa_callback: Callable[[], None] | None = None,
167169
password_callback: Callable[[], str] | None = None,
168170
session_parameters: dict[Any, Any] | None = None,
169-
timeout: int = 120,
171+
# max time waiting for MFA response, currently unused
172+
timeout: int | None = None,
170173
) -> dict[str, str | int | bool]:
171174
logger.debug("authenticate")
172175

176+
if timeout is None:
177+
timeout = auth_instance.timeout
178+
173179
if session_parameters is None:
174180
session_parameters = {}
175181

@@ -194,6 +200,7 @@ def authenticate(
194200
self._rest._connection._ocsp_mode(),
195201
self._rest._connection._login_timeout,
196202
self._rest._connection._network_timeout,
203+
self._rest._connection._socket_timeout,
197204
)
198205

199206
body = copy.deepcopy(body_template)
@@ -239,20 +246,12 @@ def authenticate(
239246
{k: v for (k, v) in body["data"].items() if k != "PASSWORD"},
240247
)
241248

242-
# accommodate any authenticator specific timeout requirements here.
243-
# login_timeout comes from user configuration.
244-
# Between login timeout and auth specific
245-
# timeout use whichever value is smaller
246-
auth_timeout = min(self._rest._connection.login_timeout, auth_instance.timeout)
247-
logger.debug(f"Timeout set to {auth_timeout}")
248-
249249
try:
250250
ret = self._rest._post_request(
251251
url,
252252
headers,
253253
json.dumps(body),
254-
timeout=auth_timeout,
255-
socket_timeout=auth_timeout,
254+
socket_timeout=auth_instance._socket_timeout,
256255
)
257256
except ForbiddenError as err:
258257
# HTTP 403
@@ -293,7 +292,10 @@ def authenticate(
293292
def post_request_wrapper(self, url, headers, body) -> None:
294293
# get the MFA response
295294
self.ret = self._rest._post_request(
296-
url, headers, body, timeout=self._rest._connection.login_timeout
295+
url,
296+
headers,
297+
body,
298+
socket_timeout=auth_instance._socket_timeout,
297299
)
298300

299301
# send new request to wait until MFA is approved
@@ -307,6 +309,7 @@ def post_request_wrapper(self, url, headers, body) -> None:
307309
while not self.ret or self.ret.get("message") == "Timeout":
308310
next(c)
309311
else:
312+
# _post_request should already terminate on timeout, so this is just a safeguard
310313
t.join(timeout=timeout)
311314

312315
ret = self.ret
@@ -322,8 +325,7 @@ def post_request_wrapper(self, url, headers, body) -> None:
322325
url,
323326
headers,
324327
json.dumps(body),
325-
timeout=self._rest._connection.login_timeout,
326-
socket_timeout=self._rest._connection.login_timeout,
328+
socket_timeout=auth_instance._socket_timeout,
327329
)
328330
elif not ret or not ret["data"] or not ret["data"].get("token"):
329331
# not token is returned.
@@ -363,8 +365,7 @@ def post_request_wrapper(self, url, headers, body) -> None:
363365
url,
364366
headers,
365367
json.dumps(body),
366-
timeout=self._rest._connection.login_timeout,
367-
socket_timeout=self._rest._connection.login_timeout,
368+
socket_timeout=auth_instance._socket_timeout,
368369
)
369370

370371
logger.debug("completed authentication")

src/snowflake/connector/auth/by_plugin.py

Lines changed: 50 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -18,61 +18,31 @@
1818
from abc import ABC, abstractmethod
1919
from enum import Enum, unique
2020
from os import getenv
21-
from typing import TYPE_CHECKING, Any
21+
from typing import TYPE_CHECKING, Any, Iterator
2222

2323
from ..errorcode import ER_FAILED_TO_CONNECT_TO_DB
2424
from ..errors import DatabaseError, Error, OperationalError
2525
from ..sqlstate import SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED
26-
from ..time_util import DecorrelateJitterBackoff
26+
from ..time_util import TimeoutBackoffCtx
2727

2828
if TYPE_CHECKING:
2929
from .. import SnowflakeConnection
3030

3131
logger = logging.getLogger(__name__)
3232

33+
"""
34+
Default value for max retry is 1 because
35+
Python requests module already tries twice
36+
by default. Unlike JWT where we need to refresh
37+
token every 10 seconds, general authenticators
38+
wait for 60 seconds before connection timeout
39+
per attempt totaling a 240 sec wait time for a non
40+
JWT based authenticator which is more than enough.
41+
This can be changed ofcourse using MAX_CNXN_RETRY_ATTEMPTS
42+
env variable.
43+
"""
3344
DEFAULT_MAX_CON_RETRY_ATTEMPTS = 1
34-
35-
36-
class AuthRetryCtx:
37-
def __init__(self) -> None:
38-
self._current_retry_count = 0
39-
self._max_retry_attempts = int(
40-
getenv("MAX_CON_RETRY_ATTEMPTS", DEFAULT_MAX_CON_RETRY_ATTEMPTS)
41-
)
42-
self._backoff = DecorrelateJitterBackoff(1, 16)
43-
self._current_sleep_time = 1
44-
45-
def get_current_retry_count(self) -> int:
46-
return self._current_retry_count
47-
48-
def increment_retry(self) -> None:
49-
self._current_retry_count += 1
50-
51-
def should_retry(self) -> bool:
52-
"""Decides whether to retry connection.
53-
54-
Default value for max retry is 1 because
55-
Python requests module already tries twice
56-
by default. Unlike JWT where we need to refresh
57-
token every 10 seconds, general authenticators
58-
wait for 60 seconds before connection timeout
59-
per attempt totaling a 240 sec wait time for a non
60-
JWT based authenticator which is more than enough.
61-
This can be changed ofcourse using MAX_CNXN_RETRY_ATTEMPTS
62-
env variable.
63-
"""
64-
return self._current_retry_count < self._max_retry_attempts
65-
66-
def next_sleep_duration(self) -> int:
67-
self._current_sleep_time = self._backoff.next_sleep(
68-
self._current_retry_count, self._current_sleep_time
69-
)
70-
logger.debug(f"Sleeping for {self._current_sleep_time} seconds")
71-
return self._current_sleep_time
72-
73-
def reset(self) -> None:
74-
self._current_retry_count = 0
75-
self._current_sleep_time = 1
45+
DEFAULT_AUTH_CLASS_TIMEOUT = 120
7646

7747

7848
@unique
@@ -89,18 +59,38 @@ class AuthType(Enum):
8959
class AuthByPlugin(ABC):
9060
"""External Authenticator interface."""
9161

92-
def __init__(self) -> None:
93-
self._retry_ctx = AuthRetryCtx()
62+
def __init__(
63+
self,
64+
timeout: int | None = None,
65+
backoff_generator: Iterator | None = None,
66+
**kwargs,
67+
) -> None:
9468
self.consent_cache_id_token = False
95-
self._timeout: int = 120
69+
70+
self._retry_ctx = TimeoutBackoffCtx(
71+
timeout=timeout if timeout is not None else DEFAULT_AUTH_CLASS_TIMEOUT,
72+
max_retry_attempts=kwargs.get(
73+
"max_retry_attempts",
74+
int(getenv("MAX_CON_RETRY_ATTEMPTS", DEFAULT_MAX_CON_RETRY_ATTEMPTS)),
75+
),
76+
backoff_generator=backoff_generator,
77+
)
78+
79+
# some authenticators may want to override socket level timeout
80+
# for example, AuthByKeyPair will set this to ensure JWT tokens are refreshed in time
81+
# if not None, this will override socket_timeout specified in connection
82+
self._socket_timeout = None
9683

9784
@property
9885
def timeout(self) -> int:
99-
return self._timeout
86+
"""The timeout of _retry_ctx is guaranteed not to be None during AuthByPlugin initialization"""
87+
return self._retry_ctx.timeout
10088

10189
@timeout.setter
102-
def timeout(self, value: Any) -> None:
103-
self._timeout = int(value)
90+
def timeout(self) -> None:
91+
logger.warning(
92+
"Attempting to mutate timeout of AuthByPlugin. Create a new instance with desired parameters instead."
93+
)
10494

10595
@property
10696
@abstractmethod
@@ -207,20 +197,23 @@ def handle_timeout(
207197
time ranges between 1 and 16 seconds.
208198
"""
209199

210-
del authenticator, service_name, account, user, password
200+
# Some authenticators may not want to delete the parameters to this function
201+
# Currently, the only authenticator where this is the case is AuthByKeyPair
202+
if kwargs.pop("delete_params", True):
203+
del authenticator, service_name, account, user, password
204+
211205
logger.debug("Default timeout handler invoked for authenticator")
212-
if not self._retry_ctx.should_retry():
206+
if not self._retry_ctx.should_retry:
213207
error = OperationalError(
214-
msg=f"Could not connect to Snowflake backend after {self._retry_ctx.get_current_retry_count()} attempt(s)."
208+
msg=f"Could not connect to Snowflake backend after {self._retry_ctx.current_retry_count + 1} attempt(s)."
215209
"Aborting",
216210
errno=ER_FAILED_TO_CONNECT_TO_DB,
217211
)
218-
self._retry_ctx.reset()
219212
raise error
220213
else:
221214
logger.debug(
222-
f"Hit connection timeout, attempt number {self._retry_ctx.get_current_retry_count()}."
215+
f"Hit connection timeout, attempt number {self._retry_ctx.current_retry_count + 1}."
223216
" Will retry in a bit..."
224217
)
225-
self._retry_ctx.increment_retry()
226-
time.sleep(self._retry_ctx.next_sleep_duration())
218+
time.sleep(float(self._retry_ctx.current_sleep_time))
219+
self._retry_ctx.increment()

src/snowflake/connector/auth/default.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ def type_(self) -> AuthType:
2121
def assertion_content(self) -> str:
2222
return "*********"
2323

24-
def __init__(self, password: str) -> None:
24+
def __init__(self, password: str, **kwargs) -> None:
2525
"""Initializes an instance with a password."""
26-
super().__init__()
26+
super().__init__(**kwargs)
2727
self._password: str | None = password
2828

2929
def reset_secrets(self) -> None:

src/snowflake/connector/auth/idtoken.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ def __init__(
3636
protocol: str | None,
3737
host: str | None,
3838
port: str | None,
39+
**kwargs,
3940
) -> None:
4041
"""Initialized an instance with an IdToken."""
41-
super().__init__()
42+
super().__init__(**kwargs)
4243
self._id_token: str | None = id_token
4344
self._application = application
4445
self._protocol = protocol
@@ -62,6 +63,8 @@ def reauthenticate(
6263
protocol=self._protocol,
6364
host=self._host,
6465
port=self._port,
66+
timeout=conn.login_timeout,
67+
backoff_generator=conn._backoff_generator,
6568
)
6669
conn._authenticate(conn.auth_class)
6770
conn._auth_class.reset_secrets()

0 commit comments

Comments
 (0)